annotate urllib3/util/ssltransport.py @ 8:832f269deeb0

planemo upload for repository https://toolrepo.galaxytrakr.org/view/jpayne/bioproject_to_srr_2/556cac4fb538
author jpayne
date Sun, 05 May 2024 23:47:10 -0400
parents 5eb2d5e3bf22
children
rev   line source
jpayne@7 1 from __future__ import annotations
jpayne@7 2
jpayne@7 3 import io
jpayne@7 4 import socket
jpayne@7 5 import ssl
jpayne@7 6 import typing
jpayne@7 7
jpayne@7 8 from ..exceptions import ProxySchemeUnsupported
jpayne@7 9
jpayne@7 10 if typing.TYPE_CHECKING:
jpayne@7 11 from typing import Literal
jpayne@7 12
jpayne@7 13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
jpayne@7 14
jpayne@7 15
# Number of bytes requested from the wrapped socket's recv() each time the
# TLS object reports SSL_ERROR_WANT_READ in SSLTransport._ssl_io_loop()
# (presumably sized to the 2**14-byte maximum TLS record; confirm if tuning).
SSL_BLOCKSIZE = 16 * 1024

# Typing helpers used by SSLTransport's annotations.
_ReturnValue = typing.TypeVar("_ReturnValue")  # result of the wrapped SSLObject call
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")  # __enter__ returns the same type
_WriteBuffer = typing.Union[bytearray, memoryview]  # buffers accepted by recv_into()
jpayne@7 21
jpayne@7 22
jpayne@7 23 class SSLTransport:
jpayne@7 24 """
jpayne@7 25 The SSLTransport wraps an existing socket and establishes an SSL connection.
jpayne@7 26
jpayne@7 27 Contrary to Python's implementation of SSLSocket, it allows you to chain
jpayne@7 28 multiple TLS connections together. It's particularly useful if you need to
jpayne@7 29 implement TLS within TLS.
jpayne@7 30
jpayne@7 31 The class supports most of the socket API operations.
jpayne@7 32 """
jpayne@7 33
jpayne@7 34 @staticmethod
jpayne@7 35 def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
jpayne@7 36 """
jpayne@7 37 Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
jpayne@7 38 for TLS in TLS.
jpayne@7 39
jpayne@7 40 The only requirement is that the ssl_context provides the 'wrap_bio'
jpayne@7 41 methods.
jpayne@7 42 """
jpayne@7 43
jpayne@7 44 if not hasattr(ssl_context, "wrap_bio"):
jpayne@7 45 raise ProxySchemeUnsupported(
jpayne@7 46 "TLS in TLS requires SSLContext.wrap_bio() which isn't "
jpayne@7 47 "available on non-native SSLContext"
jpayne@7 48 )
jpayne@7 49
jpayne@7 50 def __init__(
jpayne@7 51 self,
jpayne@7 52 socket: socket.socket,
jpayne@7 53 ssl_context: ssl.SSLContext,
jpayne@7 54 server_hostname: str | None = None,
jpayne@7 55 suppress_ragged_eofs: bool = True,
jpayne@7 56 ) -> None:
jpayne@7 57 """
jpayne@7 58 Create an SSLTransport around socket using the provided ssl_context.
jpayne@7 59 """
jpayne@7 60 self.incoming = ssl.MemoryBIO()
jpayne@7 61 self.outgoing = ssl.MemoryBIO()
jpayne@7 62
jpayne@7 63 self.suppress_ragged_eofs = suppress_ragged_eofs
jpayne@7 64 self.socket = socket
jpayne@7 65
jpayne@7 66 self.sslobj = ssl_context.wrap_bio(
jpayne@7 67 self.incoming, self.outgoing, server_hostname=server_hostname
jpayne@7 68 )
jpayne@7 69
jpayne@7 70 # Perform initial handshake.
jpayne@7 71 self._ssl_io_loop(self.sslobj.do_handshake)
jpayne@7 72
jpayne@7 73 def __enter__(self: _SelfT) -> _SelfT:
jpayne@7 74 return self
jpayne@7 75
jpayne@7 76 def __exit__(self, *_: typing.Any) -> None:
jpayne@7 77 self.close()
jpayne@7 78
jpayne@7 79 def fileno(self) -> int:
jpayne@7 80 return self.socket.fileno()
jpayne@7 81
jpayne@7 82 def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
jpayne@7 83 return self._wrap_ssl_read(len, buffer)
jpayne@7 84
jpayne@7 85 def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
jpayne@7 86 if flags != 0:
jpayne@7 87 raise ValueError("non-zero flags not allowed in calls to recv")
jpayne@7 88 return self._wrap_ssl_read(buflen)
jpayne@7 89
jpayne@7 90 def recv_into(
jpayne@7 91 self,
jpayne@7 92 buffer: _WriteBuffer,
jpayne@7 93 nbytes: int | None = None,
jpayne@7 94 flags: int = 0,
jpayne@7 95 ) -> None | int | bytes:
jpayne@7 96 if flags != 0:
jpayne@7 97 raise ValueError("non-zero flags not allowed in calls to recv_into")
jpayne@7 98 if nbytes is None:
jpayne@7 99 nbytes = len(buffer)
jpayne@7 100 return self.read(nbytes, buffer)
jpayne@7 101
jpayne@7 102 def sendall(self, data: bytes, flags: int = 0) -> None:
jpayne@7 103 if flags != 0:
jpayne@7 104 raise ValueError("non-zero flags not allowed in calls to sendall")
jpayne@7 105 count = 0
jpayne@7 106 with memoryview(data) as view, view.cast("B") as byte_view:
jpayne@7 107 amount = len(byte_view)
jpayne@7 108 while count < amount:
jpayne@7 109 v = self.send(byte_view[count:])
jpayne@7 110 count += v
jpayne@7 111
jpayne@7 112 def send(self, data: bytes, flags: int = 0) -> int:
jpayne@7 113 if flags != 0:
jpayne@7 114 raise ValueError("non-zero flags not allowed in calls to send")
jpayne@7 115 return self._ssl_io_loop(self.sslobj.write, data)
jpayne@7 116
jpayne@7 117 def makefile(
jpayne@7 118 self,
jpayne@7 119 mode: str,
jpayne@7 120 buffering: int | None = None,
jpayne@7 121 *,
jpayne@7 122 encoding: str | None = None,
jpayne@7 123 errors: str | None = None,
jpayne@7 124 newline: str | None = None,
jpayne@7 125 ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
jpayne@7 126 """
jpayne@7 127 Python's httpclient uses makefile and buffered io when reading HTTP
jpayne@7 128 messages and we need to support it.
jpayne@7 129
jpayne@7 130 This is unfortunately a copy and paste of socket.py makefile with small
jpayne@7 131 changes to point to the socket directly.
jpayne@7 132 """
jpayne@7 133 if not set(mode) <= {"r", "w", "b"}:
jpayne@7 134 raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")
jpayne@7 135
jpayne@7 136 writing = "w" in mode
jpayne@7 137 reading = "r" in mode or not writing
jpayne@7 138 assert reading or writing
jpayne@7 139 binary = "b" in mode
jpayne@7 140 rawmode = ""
jpayne@7 141 if reading:
jpayne@7 142 rawmode += "r"
jpayne@7 143 if writing:
jpayne@7 144 rawmode += "w"
jpayne@7 145 raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type]
jpayne@7 146 self.socket._io_refs += 1 # type: ignore[attr-defined]
jpayne@7 147 if buffering is None:
jpayne@7 148 buffering = -1
jpayne@7 149 if buffering < 0:
jpayne@7 150 buffering = io.DEFAULT_BUFFER_SIZE
jpayne@7 151 if buffering == 0:
jpayne@7 152 if not binary:
jpayne@7 153 raise ValueError("unbuffered streams must be binary")
jpayne@7 154 return raw
jpayne@7 155 buffer: typing.BinaryIO
jpayne@7 156 if reading and writing:
jpayne@7 157 buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment]
jpayne@7 158 elif reading:
jpayne@7 159 buffer = io.BufferedReader(raw, buffering)
jpayne@7 160 else:
jpayne@7 161 assert writing
jpayne@7 162 buffer = io.BufferedWriter(raw, buffering)
jpayne@7 163 if binary:
jpayne@7 164 return buffer
jpayne@7 165 text = io.TextIOWrapper(buffer, encoding, errors, newline)
jpayne@7 166 text.mode = mode # type: ignore[misc]
jpayne@7 167 return text
jpayne@7 168
jpayne@7 169 def unwrap(self) -> None:
jpayne@7 170 self._ssl_io_loop(self.sslobj.unwrap)
jpayne@7 171
jpayne@7 172 def close(self) -> None:
jpayne@7 173 self.socket.close()
jpayne@7 174
jpayne@7 175 @typing.overload
jpayne@7 176 def getpeercert(
jpayne@7 177 self, binary_form: Literal[False] = ...
jpayne@7 178 ) -> _TYPE_PEER_CERT_RET_DICT | None:
jpayne@7 179 ...
jpayne@7 180
jpayne@7 181 @typing.overload
jpayne@7 182 def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
jpayne@7 183 ...
jpayne@7 184
jpayne@7 185 def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
jpayne@7 186 return self.sslobj.getpeercert(binary_form) # type: ignore[return-value]
jpayne@7 187
jpayne@7 188 def version(self) -> str | None:
jpayne@7 189 return self.sslobj.version()
jpayne@7 190
jpayne@7 191 def cipher(self) -> tuple[str, str, int] | None:
jpayne@7 192 return self.sslobj.cipher()
jpayne@7 193
jpayne@7 194 def selected_alpn_protocol(self) -> str | None:
jpayne@7 195 return self.sslobj.selected_alpn_protocol()
jpayne@7 196
jpayne@7 197 def selected_npn_protocol(self) -> str | None:
jpayne@7 198 return self.sslobj.selected_npn_protocol()
jpayne@7 199
jpayne@7 200 def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
jpayne@7 201 return self.sslobj.shared_ciphers()
jpayne@7 202
jpayne@7 203 def compression(self) -> str | None:
jpayne@7 204 return self.sslobj.compression()
jpayne@7 205
jpayne@7 206 def settimeout(self, value: float | None) -> None:
jpayne@7 207 self.socket.settimeout(value)
jpayne@7 208
jpayne@7 209 def gettimeout(self) -> float | None:
jpayne@7 210 return self.socket.gettimeout()
jpayne@7 211
jpayne@7 212 def _decref_socketios(self) -> None:
jpayne@7 213 self.socket._decref_socketios() # type: ignore[attr-defined]
jpayne@7 214
jpayne@7 215 def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
jpayne@7 216 try:
jpayne@7 217 return self._ssl_io_loop(self.sslobj.read, len, buffer)
jpayne@7 218 except ssl.SSLError as e:
jpayne@7 219 if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
jpayne@7 220 return 0 # eof, return 0.
jpayne@7 221 else:
jpayne@7 222 raise
jpayne@7 223
jpayne@7 224 # func is sslobj.do_handshake or sslobj.unwrap
jpayne@7 225 @typing.overload
jpayne@7 226 def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
jpayne@7 227 ...
jpayne@7 228
jpayne@7 229 # func is sslobj.write, arg1 is data
jpayne@7 230 @typing.overload
jpayne@7 231 def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
jpayne@7 232 ...
jpayne@7 233
jpayne@7 234 # func is sslobj.read, arg1 is len, arg2 is buffer
jpayne@7 235 @typing.overload
jpayne@7 236 def _ssl_io_loop(
jpayne@7 237 self,
jpayne@7 238 func: typing.Callable[[int, bytearray | None], bytes],
jpayne@7 239 arg1: int,
jpayne@7 240 arg2: bytearray | None,
jpayne@7 241 ) -> bytes:
jpayne@7 242 ...
jpayne@7 243
jpayne@7 244 def _ssl_io_loop(
jpayne@7 245 self,
jpayne@7 246 func: typing.Callable[..., _ReturnValue],
jpayne@7 247 arg1: None | bytes | int = None,
jpayne@7 248 arg2: bytearray | None = None,
jpayne@7 249 ) -> _ReturnValue:
jpayne@7 250 """Performs an I/O loop between incoming/outgoing and the socket."""
jpayne@7 251 should_loop = True
jpayne@7 252 ret = None
jpayne@7 253
jpayne@7 254 while should_loop:
jpayne@7 255 errno = None
jpayne@7 256 try:
jpayne@7 257 if arg1 is None and arg2 is None:
jpayne@7 258 ret = func()
jpayne@7 259 elif arg2 is None:
jpayne@7 260 ret = func(arg1)
jpayne@7 261 else:
jpayne@7 262 ret = func(arg1, arg2)
jpayne@7 263 except ssl.SSLError as e:
jpayne@7 264 if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
jpayne@7 265 # WANT_READ, and WANT_WRITE are expected, others are not.
jpayne@7 266 raise e
jpayne@7 267 errno = e.errno
jpayne@7 268
jpayne@7 269 buf = self.outgoing.read()
jpayne@7 270 self.socket.sendall(buf)
jpayne@7 271
jpayne@7 272 if errno is None:
jpayne@7 273 should_loop = False
jpayne@7 274 elif errno == ssl.SSL_ERROR_WANT_READ:
jpayne@7 275 buf = self.socket.recv(SSL_BLOCKSIZE)
jpayne@7 276 if buf:
jpayne@7 277 self.incoming.write(buf)
jpayne@7 278 else:
jpayne@7 279 self.incoming.write_eof()
jpayne@7 280 return typing.cast(_ReturnValue, ret)