jpayne@7
|
1 """
|
jpayne@7
|
2 Module for using pyOpenSSL as a TLS backend. This module was relevant before
|
jpayne@7
|
3 the standard library ``ssl`` module supported SNI, but now that we've dropped
|
jpayne@7
|
4 support for Python 2.7 all relevant Python versions support SNI so
|
jpayne@7
|
5 **this module is no longer recommended**.
|
jpayne@7
|
6
|
jpayne@7
|
7 This needs the following packages installed:
|
jpayne@7
|
8
|
jpayne@7
|
9 * `pyOpenSSL`_ (tested with 16.0.0)
|
jpayne@7
|
10 * `cryptography`_ (minimum 1.3.4, from pyopenssl)
|
jpayne@7
|
11 * `idna`_ (minimum 2.0)
|
jpayne@7
|
12
|
jpayne@7
|
13 However, pyOpenSSL depends on cryptography, so while we use all three directly here we
|
jpayne@7
|
14 end up having relatively few packages required.
|
jpayne@7
|
15
|
jpayne@7
|
16 You can install them with the following command:
|
jpayne@7
|
17
|
jpayne@7
|
18 .. code-block:: bash
|
jpayne@7
|
19
|
jpayne@7
|
20 $ python -m pip install pyopenssl cryptography idna
|
jpayne@7
|
21
|
jpayne@7
|
22 To activate certificate checking, call
|
jpayne@7
|
23 :func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code
|
jpayne@7
|
24 before you begin making HTTP requests. This can be done in a ``sitecustomize``
|
jpayne@7
|
25 module, or at any other time before your application begins using ``urllib3``,
|
jpayne@7
|
26 like this:
|
jpayne@7
|
27
|
jpayne@7
|
28 .. code-block:: python
|
jpayne@7
|
29
|
jpayne@7
|
30 try:
|
jpayne@7
|
31 import urllib3.contrib.pyopenssl
|
jpayne@7
|
32 urllib3.contrib.pyopenssl.inject_into_urllib3()
|
jpayne@7
|
33 except ImportError:
|
jpayne@7
|
34 pass
|
jpayne@7
|
35
|
jpayne@7
|
36 .. _pyopenssl: https://www.pyopenssl.org
|
jpayne@7
|
37 .. _cryptography: https://cryptography.io
|
jpayne@7
|
38 .. _idna: https://github.com/kjd/idna
|
jpayne@7
|
39 """
|
jpayne@7
|
40
|
jpayne@7
|
41 from __future__ import annotations
|
jpayne@7
|
42
|
jpayne@7
|
43 import OpenSSL.SSL # type: ignore[import-untyped]
|
jpayne@7
|
44 from cryptography import x509
|
jpayne@7
|
45
|
jpayne@7
|
try:
    from cryptography.x509 import UnsupportedExtension  # type: ignore[attr-defined]
except ImportError:
    # cryptography >= 2.1.0 no longer ships UnsupportedExtension. Define a
    # local stand-in so the ``except`` clause in get_subj_alt_name() stays
    # valid; cryptography never raises this stand-in, so that clause simply
    # never matches it.

    class UnsupportedExtension(Exception):  # type: ignore[no-redef]
        """Placeholder for the exception removed in cryptography 2.1.0."""
|
jpayne@7
|
52
|
jpayne@7
|
53
|
jpayne@7
|
54 import logging
|
jpayne@7
|
55 import ssl
|
jpayne@7
|
56 import typing
|
jpayne@7
|
57 from io import BytesIO
|
jpayne@7
|
58 from socket import socket as socket_cls
|
jpayne@7
|
59 from socket import timeout
|
jpayne@7
|
60
|
jpayne@7
|
61 from .. import util
|
jpayne@7
|
62
|
jpayne@7
|
63 if typing.TYPE_CHECKING:
|
jpayne@7
|
64 from OpenSSL.crypto import X509 # type: ignore[import-untyped]
|
jpayne@7
|
65
|
jpayne@7
|
66
|
jpayne@7
|
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]

# Translate the protocol constants that urllib3 / the stdlib hand us into the
# pyOpenSSL method constants accepted by ``OpenSSL.SSL.Context``.
_openssl_versions: dict[int, int] = {
    util.ssl_.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD,  # type: ignore[attr-defined]
    util.ssl_.PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD,  # type: ignore[attr-defined]
    ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
}

# Version-pinned methods are registered only when both the stdlib and the
# installed pyOpenSSL still expose them.
for _tls_suffix in ("TLSv1_1", "TLSv1_2"):
    _stdlib_name = f"PROTOCOL_{_tls_suffix}"
    _openssl_name = f"{_tls_suffix}_METHOD"
    if hasattr(ssl, _stdlib_name) and hasattr(OpenSSL.SSL, _openssl_name):
        _openssl_versions[getattr(ssl, _stdlib_name)] = getattr(
            OpenSSL.SSL, _openssl_name
        )
del _tls_suffix, _stdlib_name, _openssl_name


# stdlib ``ssl.CERT_*`` verify modes -> pyOpenSSL ``VERIFY_*`` bit flags.
_stdlib_to_openssl_verify = {
    ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
    ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
    ssl.CERT_REQUIRED: (
        OpenSSL.SSL.VERIFY_PEER | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT
    ),
}
# Reverse lookup, used when reading PyOpenSSLContext.verify_mode.
_openssl_to_stdlib_verify = {
    openssl_mode: stdlib_mode
    for stdlib_mode, openssl_mode in _stdlib_to_openssl_verify.items()
}
|
jpayne@7
|
90
|
jpayne@7
|
# Protocol-disabling option bits. The SSLvX flags are the most likely to
# vanish from future pyOpenSSL releases, but every flag is looked up
# defensively and treated as 0 ("no such option") when it is absent.
_OP_NO_SSLv2_OR_SSLv3: int = (
    getattr(OpenSSL.SSL, "OP_NO_SSLv2", 0)
    | getattr(OpenSSL.SSL, "OP_NO_SSLv3", 0)
)
_OP_NO_TLSv1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1", 0)
_OP_NO_TLSv1_1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_1", 0)
_OP_NO_TLSv1_2: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_2", 0)
_OP_NO_TLSv1_3: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_3", 0)

# Options that enforce a *minimum* protocol version. Each entry disables SSLv2/3
# plus every TLS version older than the key. MAXIMUM_SUPPORTED as a minimum is
# treated like TLSv1_3.
_openssl_to_ssl_minimum_version: dict[int, int] = {
    ssl.TLSVersion.MINIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3,
    ssl.TLSVersion.TLSv1: _OP_NO_SSLv2_OR_SSLv3,
    ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1,
    ssl.TLSVersion.TLSv1_2: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1
    ),
    ssl.TLSVersion.TLSv1_3: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2
    ),
    ssl.TLSVersion.MAXIMUM_SUPPORTED: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2
    ),
}

# Options that enforce a *maximum* protocol version. Each entry disables SSLv2/3
# plus every TLS version newer than the key. MINIMUM_SUPPORTED as a maximum
# disables every listed version.
_openssl_to_ssl_maximum_version: dict[int, int] = {
    ssl.TLSVersion.MINIMUM_SUPPORTED: (
        _OP_NO_SSLv2_OR_SSLv3
        | _OP_NO_TLSv1
        | _OP_NO_TLSv1_1
        | _OP_NO_TLSv1_2
        | _OP_NO_TLSv1_3
    ),
    ssl.TLSVersion.TLSv1: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3
    ),
    ssl.TLSVersion.TLSv1_1: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3
    ),
    ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_3,
    ssl.TLSVersion.TLSv1_3: _OP_NO_SSLv2_OR_SSLv3,
    ssl.TLSVersion.MAXIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3,
}
|
jpayne@7
|
129
|
jpayne@7
|
# OpenSSL writes at most 16 KiB (one maximum-size TLS record) per call, so
# sendall() feeds the connection in chunks of this size.
SSL_WRITE_BLOCKSIZE = 16 * 1024

# The stock SSLContext class, remembered so extract_from_urllib3() can
# restore it after inject_into_urllib3() has replaced it.
orig_util_SSLContext = util.ssl_.SSLContext


log = logging.getLogger(__name__)
|
jpayne@7
|
137
|
jpayne@7
|
138
|
jpayne@7
|
def inject_into_urllib3() -> None:
    """Monkey-patch urllib3 with PyOpenSSL-backed SSL-support.

    Replaces ``SSLContext`` with :class:`PyOpenSSLContext` and sets
    ``IS_PYOPENSSL`` to ``True`` in both ``urllib3.util`` and
    ``urllib3.util.ssl_``.

    :raises ImportError: if the installed cryptography / pyOpenSSL lack
        functionality this backend needs.
    """
    _validate_dependencies_met()

    # Patch the top-level util namespace first, then util.ssl_; the four
    # assignments are independent of one another.
    util.SSLContext = PyOpenSSLContext  # type: ignore[assignment]
    util.IS_PYOPENSSL = True

    util.ssl_.SSLContext = PyOpenSSLContext  # type: ignore[assignment]
    util.ssl_.IS_PYOPENSSL = True
|
jpayne@7
|
148
|
jpayne@7
|
149
|
jpayne@7
|
def extract_from_urllib3() -> None:
    """Undo monkey-patching by :func:`inject_into_urllib3`.

    Puts back the ``SSLContext`` class captured at import time and resets
    ``IS_PYOPENSSL`` to ``False`` in ``urllib3.util`` and
    ``urllib3.util.ssl_``.
    """
    util.SSLContext = orig_util_SSLContext
    util.IS_PYOPENSSL = False

    util.ssl_.SSLContext = orig_util_SSLContext
    util.ssl_.IS_PYOPENSSL = False
|
jpayne@7
|
157
|
jpayne@7
|
158
|
jpayne@7
|
159 def _validate_dependencies_met() -> None:
|
jpayne@7
|
160 """
|
jpayne@7
|
161 Verifies that PyOpenSSL's package-level dependencies have been met.
|
jpayne@7
|
162 Throws `ImportError` if they are not met.
|
jpayne@7
|
163 """
|
jpayne@7
|
164 # Method added in `cryptography==1.1`; not available in older versions
|
jpayne@7
|
165 from cryptography.x509.extensions import Extensions
|
jpayne@7
|
166
|
jpayne@7
|
167 if getattr(Extensions, "get_extension_for_class", None) is None:
|
jpayne@7
|
168 raise ImportError(
|
jpayne@7
|
169 "'cryptography' module missing required functionality. "
|
jpayne@7
|
170 "Try upgrading to v1.3.4 or newer."
|
jpayne@7
|
171 )
|
jpayne@7
|
172
|
jpayne@7
|
173 # pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509
|
jpayne@7
|
174 # attribute is only present on those versions.
|
jpayne@7
|
175 from OpenSSL.crypto import X509
|
jpayne@7
|
176
|
jpayne@7
|
177 x509 = X509()
|
jpayne@7
|
178 if getattr(x509, "_x509", None) is None:
|
jpayne@7
|
179 raise ImportError(
|
jpayne@7
|
180 "'pyOpenSSL' module missing required functionality. "
|
jpayne@7
|
181 "Try upgrading to v0.14 or newer."
|
jpayne@7
|
182 )
|
jpayne@7
|
183
|
jpayne@7
|
184
|
jpayne@7
|
185 def _dnsname_to_stdlib(name: str) -> str | None:
|
jpayne@7
|
186 """
|
jpayne@7
|
187 Converts a dNSName SubjectAlternativeName field to the form used by the
|
jpayne@7
|
188 standard library on the given Python version.
|
jpayne@7
|
189
|
jpayne@7
|
190 Cryptography produces a dNSName as a unicode string that was idna-decoded
|
jpayne@7
|
191 from ASCII bytes. We need to idna-encode that string to get it back, and
|
jpayne@7
|
192 then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib
|
jpayne@7
|
193 uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8).
|
jpayne@7
|
194
|
jpayne@7
|
195 If the name cannot be idna-encoded then we return None signalling that
|
jpayne@7
|
196 the name given should be skipped.
|
jpayne@7
|
197 """
|
jpayne@7
|
198
|
jpayne@7
|
199 def idna_encode(name: str) -> bytes | None:
|
jpayne@7
|
200 """
|
jpayne@7
|
201 Borrowed wholesale from the Python Cryptography Project. It turns out
|
jpayne@7
|
202 that we can't just safely call `idna.encode`: it can explode for
|
jpayne@7
|
203 wildcard names. This avoids that problem.
|
jpayne@7
|
204 """
|
jpayne@7
|
205 import idna
|
jpayne@7
|
206
|
jpayne@7
|
207 try:
|
jpayne@7
|
208 for prefix in ["*.", "."]:
|
jpayne@7
|
209 if name.startswith(prefix):
|
jpayne@7
|
210 name = name[len(prefix) :]
|
jpayne@7
|
211 return prefix.encode("ascii") + idna.encode(name)
|
jpayne@7
|
212 return idna.encode(name)
|
jpayne@7
|
213 except idna.core.IDNAError:
|
jpayne@7
|
214 return None
|
jpayne@7
|
215
|
jpayne@7
|
216 # Don't send IPv6 addresses through the IDNA encoder.
|
jpayne@7
|
217 if ":" in name:
|
jpayne@7
|
218 return name
|
jpayne@7
|
219
|
jpayne@7
|
220 encoded_name = idna_encode(name)
|
jpayne@7
|
221 if encoded_name is None:
|
jpayne@7
|
222 return None
|
jpayne@7
|
223 return encoded_name.decode("utf-8")
|
jpayne@7
|
224
|
jpayne@7
|
225
|
jpayne@7
|
def get_subj_alt_name(peer_cert: X509) -> list[tuple[str, str]]:
    """
    Return the subject alternative names of a pyOpenSSL certificate.

    Produces ``("DNS", name)`` entries first, then ``("IP Address", ip)``
    entries, in the shape the stdlib ``getpeercert()`` uses. The result is
    empty when the certificate has no usable SubjectAlternativeName extension.
    """
    cert = peer_cert.to_cryptography()

    # Let cryptography locate the SAN extension; that is faster than
    # scanning the extensions in Python.
    try:
        san = cert.extensions.get_extension_for_class(
            x509.SubjectAlternativeName
        ).value
    except x509.ExtensionNotFound:
        return []
    except (
        x509.DuplicateExtension,
        UnsupportedExtension,
        x509.UnsupportedGeneralNameType,
        UnicodeError,
    ) as e:
        # The certificate is malformed; behave as if there were no SAN field.
        log.warning(
            "A problem was encountered with the certificate that prevented "
            "urllib3 from finding the SubjectAlternativeName field. This can "
            "affect certificate validation. The error was %s",
            e,
        )
        return []

    names: list[tuple[str, str]] = []

    # DNS names are re-encoded to match the stdlib; names that cannot be
    # IDNA-encoded are dropped.
    for dns_name in san.get_values_for_type(x509.DNSName):
        converted = _dnsname_to_stdlib(dns_name)
        if converted is not None:
            names.append(("DNS", converted))

    # match_hostname expects IP addresses as strings, not ipaddress objects.
    for ip in san.get_values_for_type(x509.IPAddress):
        names.append(("IP Address", str(ip)))

    return names
|
jpayne@7
|
272
|
jpayne@7
|
273
|
jpayne@7
|
class WrappedSocket:
    """API-compatibility wrapper for Python OpenSSL's Connection-class.

    Exposes the subset of the :class:`ssl.SSLSocket` interface that urllib3
    uses. pyOpenSSL exceptions are translated into ``OSError``,
    ``socket.timeout`` and ``ssl.SSLError``.
    """

    def __init__(
        self,
        connection: OpenSSL.SSL.Connection,
        socket: socket_cls,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        self.connection = connection
        self.socket = socket
        # When True, a peer that drops TCP without a TLS close_notify is
        # reported as EOF instead of raising OSError.
        self.suppress_ragged_eofs = suppress_ragged_eofs
        # Count of file objects handed out by makefile() that are still
        # open; the real close is deferred until it drops to zero.
        self._io_refs = 0
        self._closed = False

    def fileno(self) -> int:
        return self.socket.fileno()

    # Copy-pasted from Python 3.5 source code
    def _decref_socketios(self) -> None:
        if self._io_refs > 0:
            self._io_refs -= 1
        if self._closed:
            self.close()

    def _read_with_retry(
        self,
        read: typing.Callable[..., typing.Any],
        eof_value: typing.Any,
        args: tuple[typing.Any, ...],
        kwargs: dict[str, typing.Any],
    ) -> typing.Any:
        """Call ``read(*args, **kwargs)``, waiting and retrying on WantReadError.

        Retries happen in a loop rather than by recursion, so a long run of
        records that carry no application data cannot exhaust the stack.
        ``eof_value`` is returned for a suppressed ragged EOF or for a clean
        TLS shutdown by the peer.
        """
        while True:
            try:
                return read(*args, **kwargs)
            except OpenSSL.SSL.SysCallError as e:
                if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
                    return eof_value
                raise OSError(e.args[0], str(e)) from e
            except OpenSSL.SSL.ZeroReturnError:
                if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
                    return eof_value
                raise
            except OpenSSL.SSL.WantReadError as e:
                if not util.wait_for_read(self.socket, self.socket.gettimeout()):
                    raise timeout("The read operation timed out") from e
                # The socket is readable again: loop around and retry.
            # TLS 1.3 post-handshake authentication
            except OpenSSL.SSL.Error as e:
                raise ssl.SSLError(f"read error: {e!r}") from e

    def recv(self, *args: typing.Any, **kwargs: typing.Any) -> bytes:
        """Read from the TLS connection; returns ``b""`` at EOF."""
        return self._read_with_retry(  # type: ignore[no-any-return]
            self.connection.recv, b"", args, kwargs
        )

    def recv_into(self, *args: typing.Any, **kwargs: typing.Any) -> int:
        """Read into a caller-supplied buffer; returns ``0`` at EOF."""
        return self._read_with_retry(  # type: ignore[no-any-return]
            self.connection.recv_into, 0, args, kwargs
        )

    def settimeout(self, timeout: float) -> None:
        return self.socket.settimeout(timeout)

    def _send_until_done(self, data: bytes) -> int:
        """Send one chunk, waiting while OpenSSL reports WantWriteError."""
        while True:
            try:
                return self.connection.send(data)  # type: ignore[no-any-return]
            except OpenSSL.SSL.WantWriteError as e:
                if not util.wait_for_write(self.socket, self.socket.gettimeout()):
                    raise timeout() from e
                continue
            except OpenSSL.SSL.SysCallError as e:
                raise OSError(e.args[0], str(e)) from e

    def sendall(self, data: bytes) -> None:
        """Send all of *data*, at most SSL_WRITE_BLOCKSIZE bytes per call."""
        total_sent = 0
        while total_sent < len(data):
            sent = self._send_until_done(
                data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]
            )
            total_sent += sent

    def shutdown(self) -> None:
        # FIXME rethrow compatible exceptions should we ever use this
        self.connection.shutdown()

    def close(self) -> None:
        """Mark closed; really close once no makefile() objects remain."""
        self._closed = True
        if self._io_refs <= 0:
            self._real_close()

    def _real_close(self) -> None:
        # Best effort: OpenSSL errors while closing are deliberately ignored.
        try:
            return self.connection.close()  # type: ignore[no-any-return]
        except OpenSSL.SSL.Error:
            return

    def getpeercert(
        self, binary_form: bool = False
    ) -> dict[str, list[typing.Any]] | None:
        """Return the peer certificate.

        Returns the DER bytes when *binary_form* is true. Otherwise returns a
        minimal stdlib-style dict holding the subject commonName and the
        subjectAltName entries. Returns the falsy value pyOpenSSL gave back
        when there is no peer certificate.
        """
        x509 = self.connection.get_peer_certificate()

        if not x509:
            return x509  # type: ignore[no-any-return]

        if binary_form:
            return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509)  # type: ignore[no-any-return]

        return {
            "subject": ((("commonName", x509.get_subject().CN),),),  # type: ignore[dict-item]
            "subjectAltName": get_subj_alt_name(x509),
        }

    def version(self) -> str:
        return self.connection.get_protocol_version_name()  # type: ignore[no-any-return]

    # Borrow the stdlib implementation. It tracks open file objects through
    # _io_refs and _decref_socketios(), both provided above.
    makefile = socket_cls.makefile
|
jpayne@7
|
405
|
jpayne@7
|
406
|
jpayne@7
|
class PyOpenSSLContext:
    """
    I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible
    for translating the interface of the standard library ``SSLContext`` object
    to calls into PyOpenSSL.
    """

    def __init__(self, protocol: int) -> None:
        # Raises KeyError for protocols missing from _openssl_versions.
        self.protocol = _openssl_versions[protocol]
        self._ctx = OpenSSL.SSL.Context(self.protocol)
        self._options = 0
        self.check_hostname = False
        self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED
        self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED

    @property
    def options(self) -> int:
        return self._options

    @options.setter
    def options(self, value: int) -> None:
        self._options = value
        self._set_ctx_options()

    @property
    def verify_mode(self) -> int:
        return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()]

    @verify_mode.setter
    def verify_mode(self, value: ssl.VerifyMode) -> None:
        self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)

    def set_default_verify_paths(self) -> None:
        self._ctx.set_default_verify_paths()

    def set_ciphers(self, ciphers: bytes | str) -> None:
        if isinstance(ciphers, str):
            ciphers = ciphers.encode("utf-8")
        self._ctx.set_cipher_list(ciphers)

    def load_verify_locations(
        self,
        cafile: str | None = None,
        capath: str | None = None,
        cadata: bytes | str | None = None,
    ) -> None:
        """Load trusted CA certificates from a file, a directory and/or memory.

        :raises ssl.SSLError: if OpenSSL cannot load or parse the certificates.
        """
        if cafile is not None:
            cafile = cafile.encode("utf-8")  # type: ignore[assignment]
        if capath is not None:
            capath = capath.encode("utf-8")  # type: ignore[assignment]
        try:
            # OpenSSL rejects a call where both the file and the directory are
            # NULL. So skip it when the caller supplied only in-memory cadata.
            # The all-None call is still forwarded, so it keeps raising as before.
            if cadata is None or cafile is not None or capath is not None:
                self._ctx.load_verify_locations(cafile, capath)
            if cadata is not None:
                self._load_cadata(cadata)
        except (OpenSSL.SSL.Error, OpenSSL.crypto.Error) as e:
            raise ssl.SSLError(f"unable to load trusted certificates: {e!r}") from e

    def _load_cadata(self, cadata: bytes | str) -> None:
        """Add every certificate in *cadata* to this context's trust store.

        ``str`` data and ``bytes`` containing PEM armour are read as one or
        more PEM certificates. Any other ``bytes`` is read as a single
        DER-encoded certificate.
        """
        crypto = OpenSSL.crypto
        if isinstance(cadata, str):
            cadata = cadata.encode("ascii")
        store = self._ctx.get_cert_store()
        begin = b"-----BEGIN CERTIFICATE-----"
        if begin not in cadata:
            store.add_cert(crypto.load_certificate(crypto.FILETYPE_ASN1, cadata))
            return
        # Each split piece after the first starts with one certificate body.
        # Trailing text after its END line is ignored by the PEM reader.
        for chunk in cadata.split(begin)[1:]:
            store.add_cert(crypto.load_certificate(crypto.FILETYPE_PEM, begin + chunk))

    def load_cert_chain(
        self,
        certfile: str,
        keyfile: str | None = None,
        password: str | None = None,
    ) -> None:
        try:
            self._ctx.use_certificate_chain_file(certfile)
            if password is not None:
                if not isinstance(password, bytes):
                    password = password.encode("utf-8")  # type: ignore[assignment]
                self._ctx.set_passwd_cb(lambda *_: password)
            self._ctx.use_privatekey_file(keyfile or certfile)
        except OpenSSL.SSL.Error as e:
            raise ssl.SSLError(f"Unable to load certificate chain: {e!r}") from e

    def set_alpn_protocols(self, protocols: list[bytes | str]) -> None:
        protocols = [util.util.to_bytes(p, "ascii") for p in protocols]
        return self._ctx.set_alpn_protos(protocols)  # type: ignore[no-any-return]

    def wrap_socket(
        self,
        sock: socket_cls,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: bytes | str | None = None,
    ) -> WrappedSocket:
        """Wrap *sock* in a client TLS connection and complete the handshake.

        ``server_side`` and ``do_handshake_on_connect`` are accepted for
        ``ssl.SSLContext`` compatibility but ignored. The connection is always
        put in client mode, and the handshake finishes before returning.
        ``suppress_ragged_eofs`` is forwarded to the returned socket.

        :raises socket.timeout: if the handshake stalls past the socket timeout.
        :raises ssl.SSLError: if the handshake fails.
        """
        cnx = OpenSSL.SSL.Connection(self._ctx, sock)

        # If server_hostname is an IP, don't use it for SNI, per RFC6066 Section 3
        if server_hostname and not util.ssl_.is_ipaddress(server_hostname):
            if isinstance(server_hostname, str):
                server_hostname = server_hostname.encode("utf-8")
            cnx.set_tlsext_host_name(server_hostname)

        cnx.set_connect_state()

        while True:
            try:
                cnx.do_handshake()
            except OpenSSL.SSL.WantReadError as e:
                if not util.wait_for_read(sock, sock.gettimeout()):
                    raise timeout("select timed out") from e
                continue
            except OpenSSL.SSL.WantWriteError as e:
                # A socket with a timeout is non-blocking underneath, so the
                # handshake can also stall on a full send buffer.
                if not util.wait_for_write(sock, sock.gettimeout()):
                    raise timeout("select timed out") from e
                continue
            except OpenSSL.SSL.Error as e:
                raise ssl.SSLError(f"bad handshake: {e!r}") from e
            break

        return WrappedSocket(cnx, sock, suppress_ragged_eofs=suppress_ragged_eofs)

    def _set_ctx_options(self) -> None:
        # NOTE: SSL_CTX_set_options only adds bits and never clears previously
        # set ones. Loosening a version bound after tightening it has no effect.
        self._ctx.set_options(
            self._options
            | _openssl_to_ssl_minimum_version[self._minimum_version]
            | _openssl_to_ssl_maximum_version[self._maximum_version]
        )

    @property
    def minimum_version(self) -> int:
        return self._minimum_version

    @minimum_version.setter
    def minimum_version(self, minimum_version: int) -> None:
        self._minimum_version = minimum_version
        self._set_ctx_options()

    @property
    def maximum_version(self) -> int:
        return self._maximum_version

    @maximum_version.setter
    def maximum_version(self, maximum_version: int) -> None:
        self._maximum_version = maximum_version
        self._set_ctx_options()
|
jpayne@7
|
539
|
jpayne@7
|
540
|
jpayne@7
|
541 def _verify_callback(
|
jpayne@7
|
542 cnx: OpenSSL.SSL.Connection,
|
jpayne@7
|
543 x509: X509,
|
jpayne@7
|
544 err_no: int,
|
jpayne@7
|
545 err_depth: int,
|
jpayne@7
|
546 return_code: int,
|
jpayne@7
|
547 ) -> bool:
|
jpayne@7
|
548 return err_no == 0
|