from __future__ import annotations

import io
import socket
import ssl
import typing

from ..exceptions import ProxySchemeUnsupported

if typing.TYPE_CHECKING:
    from typing import Literal

    from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT


_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
_WriteBuffer = typing.Union[bytearray, memoryview]
_ReturnValue = typing.TypeVar("_ReturnValue")

# Maximum number of bytes pulled from the underlying socket per recv() call
# inside the I/O loop.
SSL_BLOCKSIZE = 16384


class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        methods.
        """

        if not hasattr(ssl_context, "wrap_bio"):
            raise ProxySchemeUnsupported(
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                "available on non-native SSLContext"
            )

    def __init__(
        self,
        socket: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        """
        Create an SSLTransport around socket using the provided ssl_context.

        :param socket: Object used as the transport; its ``sendall`` and
            ``recv`` methods carry the TLS records.
        :param ssl_context: Context whose ``wrap_bio`` creates the SSL object.
        :param server_hostname: Forwarded to ``wrap_bio`` as ``server_hostname``.
        :param suppress_ragged_eofs: When true, an ``SSL_ERROR_EOF`` raised
            while reading is reported as end-of-stream instead of an error.

        The TLS handshake is performed before the constructor returns.
        """
        # TLS records travel through these in-memory BIOs rather than directly
        # through the socket; _ssl_io_loop shuttles bytes between the two.
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self: _SelfT) -> _SelfT:
        """Return the transport itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *_: typing.Any) -> None:
        """Close the transport when leaving a ``with`` block."""
        self.close()

    def fileno(self) -> int:
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
        """
        Read up to ``len`` decrypted bytes.

        Delegates to ``sslobj.read``; when ``buffer`` is given it is passed
        through as the destination buffer.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
        """
        Read up to ``buflen`` decrypted bytes.

        :raises ValueError: if ``flags`` is non-zero.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(buflen)

    def recv_into(
        self,
        buffer: _WriteBuffer,
        nbytes: int | None = None,
        flags: int = 0,
    ) -> None | int | bytes:
        """
        Read decrypted bytes into ``buffer``.

        ``nbytes`` defaults to ``len(buffer)`` when not supplied.

        :raises ValueError: if ``flags`` is non-zero.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            nbytes = len(buffer)
        return self.read(nbytes, buffer)

    def sendall(self, data: bytes, flags: int = 0) -> None:
        """
        Send all of ``data``, calling :meth:`send` until every byte is written.

        :raises ValueError: if ``flags`` is non-zero.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Slice a byte-oriented memoryview so the unsent tail is not copied on
        # every iteration.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data: bytes, flags: int = 0) -> int:
        """
        Write ``data`` through the SSL object and return ``sslobj.write``'s result.

        :raises ValueError: if ``flags`` is non-zero.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self,
        mode: str,
        buffering: int | None = None,
        *,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.

        :raises ValueError: if ``mode`` contains characters other than
            ``r``, ``w`` and ``b``, or if an unbuffered text stream is requested.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")

        writing = "w" in mode
        # A mode with neither "r" nor "w" is treated as read-only.
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
        # Bump the wrapped socket's I/O reference counter for the new raw
        # stream; _decref_socketios() forwards the matching release.
        self.socket._io_refs += 1  # type: ignore[attr-defined]
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        buffer: typing.BinaryIO
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode  # type: ignore[misc]
        return text

    def unwrap(self) -> None:
        """Run ``sslobj.unwrap`` through the I/O loop."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self) -> None:
        """Close the wrapped socket."""
        self.socket.close()

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None:
        ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
        ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        """Return ``sslobj.getpeercert(binary_form)``."""
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]

    def version(self) -> str | None:
        """Return ``sslobj.version()``."""
        return self.sslobj.version()

    def cipher(self) -> tuple[str, str, int] | None:
        """Return ``sslobj.cipher()``."""
        return self.sslobj.cipher()

    def selected_alpn_protocol(self) -> str | None:
        """Return ``sslobj.selected_alpn_protocol()``."""
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self) -> str | None:
        """Return ``sslobj.selected_npn_protocol()``."""
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
        """Return ``sslobj.shared_ciphers()``."""
        return self.sslobj.shared_ciphers()

    def compression(self) -> str | None:
        """Return ``sslobj.compression()``."""
        return self.sslobj.compression()

    def settimeout(self, value: float | None) -> None:
        """Set the timeout on the wrapped socket."""
        self.socket.settimeout(value)

    def gettimeout(self) -> float | None:
        """Return the timeout of the wrapped socket."""
        return self.socket.gettimeout()

    def _decref_socketios(self) -> None:
        """Forward to the wrapped socket's ``_decref_socketios``."""
        self.socket._decref_socketios()  # type: ignore[attr-defined]

    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
        """
        Read through the I/O loop, optionally translating a ragged EOF.

        When ``suppress_ragged_eofs`` is set and the read fails with
        ``SSL_ERROR_EOF``, end-of-stream is reported in the shape the caller
        asked for: ``0`` bytes written when a ``buffer`` was supplied, or an
        empty ``bytes`` object otherwise. Any other ``ssl.SSLError`` propagates.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # Keep the return type consistent with a successful read so
                # callers of recv()/read() never receive an int where they
                # expect bytes (same convention as ssl.SSLSocket.read).
                return 0 if buffer is not None else b""
            raise

    # func is sslobj.do_handshake or sslobj.unwrap
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
        ...

    # func is sslobj.write, arg1 is data
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
        ...

    # func is sslobj.read, arg1 is len, arg2 is buffer
    @typing.overload
    def _ssl_io_loop(
        self,
        func: typing.Callable[[int, bytearray | None], bytes],
        arg1: int,
        arg2: bytearray | None,
    ) -> bytes:
        ...

    def _ssl_io_loop(
        self,
        func: typing.Callable[..., _ReturnValue],
        arg1: None | bytes | int = None,
        arg2: bytearray | None = None,
    ) -> _ReturnValue:
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        ``func`` is retried until it completes without raising
        ``SSL_ERROR_WANT_READ`` / ``SSL_ERROR_WANT_WRITE``. After every attempt
        the contents of the outgoing BIO are sent on the socket; on
        ``WANT_READ`` more data is received from the socket into the incoming
        BIO first. Any other ``ssl.SSLError`` propagates to the caller.
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                # Only pass the arguments that were actually supplied.
                if arg1 is None and arg2 is None:
                    ret = func()
                elif arg2 is None:
                    ret = func(arg1)
                else:
                    ret = func(arg1, arg2)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush whatever the SSL object queued for the peer.
            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    # Empty recv: signal EOF to the SSL object so the next
                    # attempt fails instead of waiting for more data.
                    self.incoming.write_eof()
        return typing.cast(_ReturnValue, ret)