comparison urllib3/contrib/pyopenssl.py @ 7:5eb2d5e3bf22

planemo upload for repository https://toolrepo.galaxytrakr.org/view/jpayne/bioproject_to_srr_2/556cac4fb538
author jpayne
date Sun, 05 May 2024 23:32:17 -0400
parents
children
comparison
equal deleted inserted replaced
6:b2745907b1eb 7:5eb2d5e3bf22
1 """
2 Module for using pyOpenSSL as a TLS backend. This module was relevant before
3 the standard library ``ssl`` module supported SNI, but now that we've dropped
4 support for Python 2.7 all relevant Python versions support SNI so
5 **this module is no longer recommended**.
6
7 This needs the following packages installed:
8
9 * `pyOpenSSL`_ (tested with 16.0.0)
10 * `cryptography`_ (minimum 1.3.4, from pyopenssl)
11 * `idna`_ (minimum 2.0)
12
13 However, pyOpenSSL depends on cryptography, so while we use all three directly here we
14 end up having relatively few packages required.
15
16 You can install them with the following command:
17
18 .. code-block:: bash
19
20 $ python -m pip install pyopenssl cryptography idna
21
To make urllib3 use pyOpenSSL as its TLS backend, call
23 :func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code
24 before you begin making HTTP requests. This can be done in a ``sitecustomize``
25 module, or at any other time before your application begins using ``urllib3``,
26 like this:
27
28 .. code-block:: python
29
30 try:
31 import urllib3.contrib.pyopenssl
32 urllib3.contrib.pyopenssl.inject_into_urllib3()
33 except ImportError:
34 pass
35
36 .. _pyopenssl: https://www.pyopenssl.org
37 .. _cryptography: https://cryptography.io
38 .. _idna: https://github.com/kjd/idna
39 """
40
41 from __future__ import annotations
42
43 import OpenSSL.SSL # type: ignore[import-untyped]
44 from cryptography import x509
45
try:
    from cryptography.x509 import UnsupportedExtension  # type: ignore[attr-defined]
except ImportError:
    # cryptography >= 2.1.0 no longer ships UnsupportedExtension. A local
    # stand-in keeps the ``except`` tuple in get_subj_alt_name() valid; nothing
    # in this file raises it.
    class UnsupportedExtension(Exception):  # type: ignore[no-redef]
        """Placeholder for the exception removed from newer cryptography."""
52
53
54 import logging
55 import ssl
56 import typing
57 from io import BytesIO
58 from socket import socket as socket_cls
59 from socket import timeout
60
61 from .. import util
62
63 if typing.TYPE_CHECKING:
64 from OpenSSL.crypto import X509 # type: ignore[import-untyped]
65
66
__all__ = [
    "inject_into_urllib3",
    "extract_from_urllib3",
]

# urllib3 / stdlib ``ssl.PROTOCOL_*`` value -> pyOpenSSL ``*_METHOD`` value.
# PyOpenSSLContext.__init__ looks the requested protocol up here.
_openssl_versions: dict[int, int] = {
    util.ssl_.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD,  # type: ignore[attr-defined]
    util.ssl_.PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD,  # type: ignore[attr-defined]
    ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
}

# TLS 1.1 and 1.2 are only mapped when BOTH the stdlib and pyOpenSSL expose
# the corresponding constant.
for _stdlib_name, _openssl_name in (
    ("PROTOCOL_TLSv1_1", "TLSv1_1_METHOD"),
    ("PROTOCOL_TLSv1_2", "TLSv1_2_METHOD"),
):
    if hasattr(ssl, _stdlib_name) and hasattr(OpenSSL.SSL, _openssl_name):
        _openssl_versions[getattr(ssl, _stdlib_name)] = getattr(
            OpenSSL.SSL, _openssl_name
        )
del _stdlib_name, _openssl_name
81
82
# stdlib ``ssl.CERT_*`` verify mode -> pyOpenSSL ``VERIFY_*`` flags.
_stdlib_to_openssl_verify = {
    ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
    ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
    ssl.CERT_REQUIRED: (
        OpenSSL.SSL.VERIFY_PEER | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT
    ),
}
# Reverse lookup used by the PyOpenSSLContext.verify_mode getter.
_openssl_to_stdlib_verify = {
    openssl_flags: stdlib_mode
    for stdlib_mode, openssl_flags in _stdlib_to_openssl_verify.items()
}
90
# ``OP_NO_*`` option bits that switch off individual protocol versions. Each
# one falls back to 0 ("no effect") when the installed pyOpenSSL does not
# define it; the SSLv2/SSLv3 flags are considered the most likely to vanish.
_OP_NO_SSLv2_OR_SSLv3: int = (
    getattr(OpenSSL.SSL, "OP_NO_SSLv2", 0)
    | getattr(OpenSSL.SSL, "OP_NO_SSLv3", 0)
)
_OP_NO_TLSv1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1", 0)
_OP_NO_TLSv1_1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_1", 0)
_OP_NO_TLSv1_2: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_2", 0)
_OP_NO_TLSv1_3: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_3", 0)
100
# Despite their names, both tables map a stdlib ``ssl.TLSVersion`` to the set
# of pyOpenSSL ``OP_NO_*`` option bits that enforce it. PyOpenSSLContext ORs
# the selected minimum and maximum entries into its context options.
#
# Minimum version: set the "disable" bit for every protocol older than the
# requested one. SSLv2/SSLv3 are always disabled; MAXIMUM_SUPPORTED is
# treated the same as TLSv1_3.
_openssl_to_ssl_minimum_version: dict[int, int] = {
    ssl.TLSVersion.MINIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3,
    ssl.TLSVersion.TLSv1: _OP_NO_SSLv2_OR_SSLv3,
    ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1,
    ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1,
    ssl.TLSVersion.TLSv1_3: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2
    ),
    ssl.TLSVersion.MAXIMUM_SUPPORTED: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2
    ),
}
# Maximum version: set the "disable" bit for every protocol newer than the
# requested one (SSLv2/SSLv3 again always disabled). TLSv1_3 and
# MAXIMUM_SUPPORTED only disable SSLv2/SSLv3.
# NOTE(review): MINIMUM_SUPPORTED used as a *maximum* sets every OP_NO_* bit
# listed here, so no TLS version is left enabled -- confirm this is intended.
_openssl_to_ssl_maximum_version: dict[int, int] = {
    ssl.TLSVersion.MINIMUM_SUPPORTED: (
        _OP_NO_SSLv2_OR_SSLv3
        | _OP_NO_TLSv1
        | _OP_NO_TLSv1_1
        | _OP_NO_TLSv1_2
        | _OP_NO_TLSv1_3
    ),
    ssl.TLSVersion.TLSv1: (
        _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3
    ),
    ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3,
    ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_3,
    ssl.TLSVersion.TLSv1_3: _OP_NO_SSLv2_OR_SSLv3,
    ssl.TLSVersion.MAXIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3,
}
129
# OpenSSL writes at most 16 KiB per call, so WrappedSocket.sendall() hands
# data to pyOpenSSL in chunks of this size.
SSL_WRITE_BLOCKSIZE = 16 * 1024

# The stock SSLContext, remembered so extract_from_urllib3() can restore it.
orig_util_SSLContext = util.ssl_.SSLContext

log = logging.getLogger(__name__)
137
138
def inject_into_urllib3() -> None:
    """
    Make urllib3 use :class:`PyOpenSSLContext` instead of ``ssl.SSLContext``.

    Replaces ``SSLContext`` and sets ``IS_PYOPENSSL`` on both ``urllib3.util``
    and ``urllib3.util.ssl_``. Reverse with :func:`extract_from_urllib3`.

    :raises ImportError: if cryptography or pyOpenSSL is missing or too old.
    """
    _validate_dependencies_met()

    for patched_module in (util, util.ssl_):
        patched_module.SSLContext = PyOpenSSLContext  # type: ignore[assignment]
        patched_module.IS_PYOPENSSL = True
148
149
def extract_from_urllib3() -> None:
    """
    Undo :func:`inject_into_urllib3`.

    Puts the original ``SSLContext`` back and clears ``IS_PYOPENSSL`` on both
    ``urllib3.util`` and ``urllib3.util.ssl_``.
    """
    for patched_module in (util, util.ssl_):
        patched_module.SSLContext = orig_util_SSLContext
        patched_module.IS_PYOPENSSL = False
157
158
159 def _validate_dependencies_met() -> None:
160 """
161 Verifies that PyOpenSSL's package-level dependencies have been met.
162 Throws `ImportError` if they are not met.
163 """
164 # Method added in `cryptography==1.1`; not available in older versions
165 from cryptography.x509.extensions import Extensions
166
167 if getattr(Extensions, "get_extension_for_class", None) is None:
168 raise ImportError(
169 "'cryptography' module missing required functionality. "
170 "Try upgrading to v1.3.4 or newer."
171 )
172
173 # pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509
174 # attribute is only present on those versions.
175 from OpenSSL.crypto import X509
176
177 x509 = X509()
178 if getattr(x509, "_x509", None) is None:
179 raise ImportError(
180 "'pyOpenSSL' module missing required functionality. "
181 "Try upgrading to v0.14 or newer."
182 )
183
184
185 def _dnsname_to_stdlib(name: str) -> str | None:
186 """
187 Converts a dNSName SubjectAlternativeName field to the form used by the
188 standard library on the given Python version.
189
190 Cryptography produces a dNSName as a unicode string that was idna-decoded
191 from ASCII bytes. We need to idna-encode that string to get it back, and
192 then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib
193 uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8).
194
195 If the name cannot be idna-encoded then we return None signalling that
196 the name given should be skipped.
197 """
198
199 def idna_encode(name: str) -> bytes | None:
200 """
201 Borrowed wholesale from the Python Cryptography Project. It turns out
202 that we can't just safely call `idna.encode`: it can explode for
203 wildcard names. This avoids that problem.
204 """
205 import idna
206
207 try:
208 for prefix in ["*.", "."]:
209 if name.startswith(prefix):
210 name = name[len(prefix) :]
211 return prefix.encode("ascii") + idna.encode(name)
212 return idna.encode(name)
213 except idna.core.IDNAError:
214 return None
215
216 # Don't send IPv6 addresses through the IDNA encoder.
217 if ":" in name:
218 return name
219
220 encoded_name = idna_encode(name)
221 if encoded_name is None:
222 return None
223 return encoded_name.decode("utf-8")
224
225
def get_subj_alt_name(peer_cert: X509) -> list[tuple[str, str]]:
    """
    Return the subject alternative names of a pyOpenSSL certificate.

    The result is shaped like the ``subjectAltName`` entry of the stdlib's
    ``getpeercert()``: ``("DNS", name)`` pairs first, then
    ``("IP Address", address)`` pairs. DNS names that cannot be IDNA-encoded
    are skipped. An empty list is returned when the certificate has no SAN
    extension, or when the extension cannot be read (a warning is logged).
    """
    crypto_cert = peer_cert.to_cryptography()

    # Asking cryptography for the extension by class is faster than scanning
    # every extension from Python.
    try:
        san = crypto_cert.extensions.get_extension_for_class(
            x509.SubjectAlternativeName
        ).value
    except x509.ExtensionNotFound:
        return []
    except (
        x509.DuplicateExtension,
        UnsupportedExtension,
        x509.UnsupportedGeneralNameType,
        UnicodeError,
    ) as e:
        # The certificate is malformed; behave as if it had no SAN field.
        log.warning(
            "A problem was encountered with the certificate that prevented "
            "urllib3 from finding the SubjectAlternativeName field. This can "
            "affect certificate validation. The error was %s",
            e,
        )
        return []

    result: list[tuple[str, str]] = []
    # DNS names are IDNA-encoded and UTF-8-decoded to mirror the stdlib;
    # a None conversion result means "cannot be encoded, skip it".
    for raw_name in san.get_values_for_type(x509.DNSName):
        stdlib_name = _dnsname_to_stdlib(raw_name)
        if stdlib_name is not None:
            result.append(("DNS", stdlib_name))
    # match_hostname wants IP addresses as strings.
    result.extend(
        ("IP Address", str(ip)) for ip in san.get_values_for_type(x509.IPAddress)
    )
    return result
272
273
class WrappedSocket:
    """API-compatibility wrapper for Python OpenSSL's Connection-class.

    Exposes the subset of the stdlib ``ssl.SSLSocket`` interface used by
    urllib3 (``recv``, ``recv_into``, ``sendall``, ``getpeercert``, ...) and
    translates pyOpenSSL exceptions into ``OSError``, ``ssl.SSLError`` or
    ``socket.timeout``.
    """

    def __init__(
        self,
        connection: OpenSSL.SSL.Connection,
        socket: socket_cls,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        self.connection = connection
        self.socket = socket
        # When true, SysCallError(-1, "Unexpected EOF") during a read is
        # reported as EOF instead of being raised.
        self.suppress_ragged_eofs = suppress_ragged_eofs
        # Number of live makefile() objects; the real close is deferred until
        # this reaches zero.
        self._io_refs = 0
        self._closed = False

    def fileno(self) -> int:
        """Return the file descriptor of the underlying plain socket."""
        return self.socket.fileno()

    # Copy-pasted from Python 3.5 source code
    def _decref_socketios(self) -> None:
        """Drop one makefile() reference and finish a pending close() at zero."""
        if self._io_refs > 0:
            self._io_refs -= 1
        if self._closed:
            self.close()

    def _read(
        self,
        reader: typing.Callable[..., typing.Any],
        eof_value: typing.Any,
        args: tuple[typing.Any, ...],
        kwargs: dict[str, typing.Any],
    ) -> typing.Any:
        """
        Run a pyOpenSSL read call, mapping its exceptions to stdlib ones.

        Retries in a loop (not recursively) whenever OpenSSL needs more data
        from the socket, waiting up to the socket timeout each time.

        :param reader: ``self.connection.recv`` or ``self.connection.recv_into``.
        :param eof_value: value returned on a clean or suppressed ragged EOF.
        :raises OSError: on an underlying socket error.
        :raises socket.timeout: if the socket does not become readable in time.
        :raises ssl.SSLError: on any other pyOpenSSL error.
        """
        while True:
            try:
                return reader(*args, **kwargs)
            except OpenSSL.SSL.SysCallError as e:
                if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
                    return eof_value
                raise OSError(e.args[0], str(e)) from e
            except OpenSSL.SSL.ZeroReturnError:
                # get_shutdown() is a bitmask of SENT_SHUTDOWN and
                # RECEIVED_SHUTDOWN. Test only the RECEIVED bit so that a clean
                # close_notify is still reported as EOF after our own shutdown().
                if self.connection.get_shutdown() & OpenSSL.SSL.RECEIVED_SHUTDOWN:
                    return eof_value
                raise
            except OpenSSL.SSL.WantReadError as e:
                if not util.wait_for_read(self.socket, self.socket.gettimeout()):
                    raise timeout("The read operation timed out") from e
                # The socket became readable: loop round and retry the read.
            # TLS 1.3 post-handshake authentication
            except OpenSSL.SSL.Error as e:
                raise ssl.SSLError(f"read error: {e!r}") from e

    def recv(self, *args: typing.Any, **kwargs: typing.Any) -> bytes:
        """Read decrypted bytes; ``b""`` signals EOF."""
        return self._read(self.connection.recv, b"", args, kwargs)  # type: ignore[no-any-return]

    def recv_into(self, *args: typing.Any, **kwargs: typing.Any) -> int:
        """Read decrypted bytes into a buffer; returns the count (0 on EOF)."""
        return self._read(self.connection.recv_into, 0, args, kwargs)  # type: ignore[no-any-return]

    def settimeout(self, timeout: float) -> None:
        """Set the timeout on the underlying plain socket."""
        return self.socket.settimeout(timeout)

    def _send_until_done(self, data: bytes) -> int:
        """Send *data* once, waiting while OpenSSL reports WantWriteError.

        Returns the number of bytes pyOpenSSL accepted.
        """
        while True:
            try:
                return self.connection.send(data)  # type: ignore[no-any-return]
            except OpenSSL.SSL.WantWriteError as e:
                if not util.wait_for_write(self.socket, self.socket.gettimeout()):
                    raise timeout() from e
                continue
            except OpenSSL.SSL.SysCallError as e:
                raise OSError(e.args[0], str(e)) from e

    def sendall(self, data: bytes) -> None:
        """Send all of *data*, at most SSL_WRITE_BLOCKSIZE bytes per call."""
        total_sent = 0
        while total_sent < len(data):
            sent = self._send_until_done(
                data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]
            )
            total_sent += sent

    def shutdown(self) -> None:
        """Send a TLS close_notify via pyOpenSSL."""
        # FIXME rethrow compatible exceptions should we ever use this
        self.connection.shutdown()

    def close(self) -> None:
        """Mark closed; really close once no makefile() objects remain."""
        self._closed = True
        if self._io_refs <= 0:
            self._real_close()

    def _real_close(self) -> None:
        """Close the pyOpenSSL connection, ignoring pyOpenSSL errors."""
        try:
            return self.connection.close()  # type: ignore[no-any-return]
        except OpenSSL.SSL.Error:
            return

    def getpeercert(
        self, binary_form: bool = False
    ) -> dict[str, list[typing.Any]] | None:
        """
        Return the peer certificate, loosely following ``SSLSocket.getpeercert``.

        If the peer presented no certificate, the falsy
        ``get_peer_certificate()`` result is returned as-is. With
        *binary_form* the DER encoding is returned. Otherwise the result is a
        dict containing only ``subject`` (commonName) and ``subjectAltName``.
        """
        cert = self.connection.get_peer_certificate()

        if not cert:
            return cert  # type: ignore[no-any-return]

        if binary_form:
            return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, cert)  # type: ignore[no-any-return]

        return {
            "subject": ((("commonName", cert.get_subject().CN),),),  # type: ignore[dict-item]
            "subjectAltName": get_subj_alt_name(cert),
        }

    def version(self) -> str:
        """Return the negotiated protocol name reported by pyOpenSSL."""
        return self.connection.get_protocol_version_name()  # type: ignore[no-any-return]


# Reuse the stdlib socket.makefile(); it is expected to drive ``_io_refs`` and
# ``_decref_socketios`` above (as CPython's implementation does).
WrappedSocket.makefile = socket_cls.makefile  # type: ignore[attr-defined]
405
406
class PyOpenSSLContext:
    """
    I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible
    for translating the interface of the standard library ``SSLContext`` object
    to calls into PyOpenSSL.
    """

    def __init__(self, protocol: int) -> None:
        # Raises KeyError for protocol constants missing from _openssl_versions.
        self.protocol = _openssl_versions[protocol]
        self._ctx = OpenSSL.SSL.Context(self.protocol)
        self._options = 0
        # Not consulted anywhere in this class; kept for SSLContext
        # interface compatibility.
        self.check_hostname = False
        self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED
        self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED

    @property
    def options(self) -> int:
        return self._options

    @options.setter
    def options(self, value: int) -> None:
        self._options = value
        self._set_ctx_options()

    @property
    def verify_mode(self) -> int:
        """The stdlib ``ssl.CERT_*`` value matching the context's verify flags."""
        return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()]

    @verify_mode.setter
    def verify_mode(self, value: ssl.VerifyMode) -> None:
        self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)

    def set_default_verify_paths(self) -> None:
        """Trust OpenSSL's default CA locations."""
        self._ctx.set_default_verify_paths()

    def set_ciphers(self, ciphers: bytes | str) -> None:
        """Set the OpenSSL cipher list; str input is UTF-8 encoded."""
        if isinstance(ciphers, str):
            ciphers = ciphers.encode("utf-8")
        self._ctx.set_cipher_list(ciphers)

    def load_verify_locations(
        self,
        cafile: str | None = None,
        capath: str | None = None,
        cadata: bytes | str | None = None,
    ) -> None:
        """
        Load trusted CA certificates from a file, a directory and/or memory.

        :param cafile: path to a PEM file of CA certificates.
        :param capath: path to a directory of CA certificates.
        :param cadata: in-memory CA certificates: one or more PEM certificates
            (``str`` or ``bytes``), or a single DER-encoded certificate.
        :raises ssl.SSLError: if the certificates cannot be loaded.
        """
        if cafile is not None:
            cafile = cafile.encode("utf-8")  # type: ignore[assignment]
        if capath is not None:
            capath = capath.encode("utf-8")  # type: ignore[assignment]
        try:
            # OpenSSL's file/dir loader fails when both paths are NULL, so it
            # is skipped when cadata is the only source. With nothing given at
            # all it is still called, preserving the previous error.
            if cafile is not None or capath is not None or cadata is None:
                self._ctx.load_verify_locations(cafile, capath)
            if cadata is not None:
                self._load_verify_data(cadata)
        except (OpenSSL.SSL.Error, OpenSSL.crypto.Error) as e:
            raise ssl.SSLError(f"unable to load trusted certificates: {e!r}") from e

    @staticmethod
    def _split_pem_certificates(data: bytes) -> list[bytes]:
        """Return every complete ``BEGIN/END CERTIFICATE`` block in *data*."""
        begin = b"-----BEGIN CERTIFICATE-----"
        end = b"-----END CERTIFICATE-----"
        blocks = []
        # The last split piece follows the final END marker (or is the whole
        # input when there is none), so it never holds a complete block.
        for chunk in data.split(end)[:-1]:
            start = chunk.find(begin)
            if start != -1:
                blocks.append(chunk[start:] + end)
        return blocks

    def _load_verify_data(self, cadata: bytes | str) -> None:
        """Add in-memory CA certificates to the context's certificate store.

        pyOpenSSL's ``load_verify_locations`` only accepts filesystem paths,
        so the data is parsed here and each certificate is added directly.
        """
        if isinstance(cadata, str):
            cadata = cadata.encode("ascii")
        pem_blocks = self._split_pem_certificates(cadata)
        if pem_blocks:
            certs = [
                OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, block)
                for block in pem_blocks
            ]
        else:
            # No PEM markers: treat the data as a single DER certificate.
            certs = [
                OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, cadata)
            ]
        # Everything is parsed before anything is added, so a bad certificate
        # leaves the store untouched.
        store = self._ctx.get_cert_store()
        for cert in certs:
            store.add_cert(cert)

    def load_cert_chain(
        self,
        certfile: str,
        keyfile: str | None = None,
        password: str | None = None,
    ) -> None:
        """
        Load a certificate chain and its private key.

        :param certfile: PEM certificate chain file; also used as the key file
            when *keyfile* is not given.
        :param password: passphrase for an encrypted key (str is UTF-8 encoded).
        :raises ssl.SSLError: if the chain or key cannot be loaded.
        """
        try:
            self._ctx.use_certificate_chain_file(certfile)
            if password is not None:
                if not isinstance(password, bytes):
                    password = password.encode("utf-8")  # type: ignore[assignment]
                self._ctx.set_passwd_cb(lambda *_: password)
            self._ctx.use_privatekey_file(keyfile or certfile)
        except OpenSSL.SSL.Error as e:
            raise ssl.SSLError(f"Unable to load certificate chain: {e!r}") from e

    def set_alpn_protocols(self, protocols: list[bytes | str]) -> None:
        """Advertise *protocols* via ALPN; str entries are ASCII-encoded."""
        protocols = [util.util.to_bytes(p, "ascii") for p in protocols]
        return self._ctx.set_alpn_protos(protocols)  # type: ignore[no-any-return]

    def wrap_socket(
        self,
        sock: socket_cls,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: bytes | str | None = None,
    ) -> WrappedSocket:
        """
        Wrap *sock* in a client-side TLS connection and complete the handshake.

        ``server_side`` and ``do_handshake_on_connect`` are accepted for
        ``ssl.SSLContext`` compatibility but ignored. The connection is always
        put in connect (client) state, and the handshake always runs before
        returning.

        :raises socket.timeout: if the socket is not ready within its timeout.
        :raises ssl.SSLError: if the handshake fails.
        """
        cnx = OpenSSL.SSL.Connection(self._ctx, sock)

        # If server_hostname is an IP, don't use it for SNI, per RFC6066 Section 3
        if server_hostname and not util.ssl_.is_ipaddress(server_hostname):
            if isinstance(server_hostname, str):
                server_hostname = server_hostname.encode("utf-8")
            cnx.set_tlsext_host_name(server_hostname)

        cnx.set_connect_state()

        while True:
            try:
                cnx.do_handshake()
            except OpenSSL.SSL.WantReadError as e:
                if not util.wait_for_read(sock, sock.gettimeout()):
                    raise timeout("select timed out") from e
                continue
            except OpenSSL.SSL.WantWriteError as e:
                # The handshake can also stall on a full send buffer.
                if not util.wait_for_write(sock, sock.gettimeout()):
                    raise timeout("select timed out") from e
                continue
            except OpenSSL.SSL.Error as e:
                raise ssl.SSLError(f"bad handshake: {e!r}") from e
            break

        return WrappedSocket(cnx, sock, suppress_ragged_eofs=suppress_ragged_eofs)

    def _set_ctx_options(self) -> None:
        """Push user options plus the min/max-version OP_NO_* bits to OpenSSL."""
        # NOTE(review): pyOpenSSL's set_options presumably only ORs bits in
        # (SSL_CTX_set_options), so bits from an earlier, stricter min/max
        # version are not cleared by relaxing it later -- confirm.
        self._ctx.set_options(
            self._options
            | _openssl_to_ssl_minimum_version[self._minimum_version]
            | _openssl_to_ssl_maximum_version[self._maximum_version]
        )

    @property
    def minimum_version(self) -> int:
        return self._minimum_version

    @minimum_version.setter
    def minimum_version(self, minimum_version: int) -> None:
        self._minimum_version = minimum_version
        self._set_ctx_options()

    @property
    def maximum_version(self) -> int:
        return self._maximum_version

    @maximum_version.setter
    def maximum_version(self, maximum_version: int) -> None:
        self._maximum_version = maximum_version
        self._set_ctx_options()
539
540
541 def _verify_callback(
542 cnx: OpenSSL.SSL.Connection,
543 x509: X509,
544 err_no: int,
545 err_depth: int,
546 return_code: int,
547 ) -> bool:
548 return err_no == 0