diff urllib3/util/ssltransport.py @ 7:5eb2d5e3bf22

planemo upload for repository https://toolrepo.galaxytrakr.org/view/jpayne/bioproject_to_srr_2/556cac4fb538
author jpayne
date Sun, 05 May 2024 23:32:17 -0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/urllib3/util/ssltransport.py	Sun May 05 23:32:17 2024 -0400
@@ -0,0 +1,280 @@
+from __future__ import annotations
+
+import io
+import socket
+import ssl
+import typing
+
+from ..exceptions import ProxySchemeUnsupported
+
+if typing.TYPE_CHECKING:
+    from typing import Literal
+
+    from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
+
+
+_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
+_WriteBuffer = typing.Union[bytearray, memoryview]
+_ReturnValue = typing.TypeVar("_ReturnValue")
+
+SSL_BLOCKSIZE = 16384
+
+
+class SSLTransport:
+    """
+    The SSLTransport wraps an existing socket and establishes an SSL connection.
+
+    Contrary to Python's implementation of SSLSocket, it allows you to chain
+    multiple TLS connections together. It's particularly useful if you need to
+    implement TLS within TLS.
+
+    The class supports most of the socket API operations.
+    """
+
+    @staticmethod
+    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
+        """
+        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
+        for TLS in TLS.
+
+        The only requirement is that the ssl_context provides the 'wrap_bio'
+        methods.
+        """
+
+        if not hasattr(ssl_context, "wrap_bio"):
+            raise ProxySchemeUnsupported(
+                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
+                "available on non-native SSLContext"
+            )
+
+    def __init__(
+        self,
+        socket: socket.socket,
+        ssl_context: ssl.SSLContext,
+        server_hostname: str | None = None,
+        suppress_ragged_eofs: bool = True,
+    ) -> None:
+        """
+        Create an SSLTransport around socket using the provided ssl_context.
+        """
+        self.incoming = ssl.MemoryBIO()
+        self.outgoing = ssl.MemoryBIO()
+
+        self.suppress_ragged_eofs = suppress_ragged_eofs
+        self.socket = socket
+
+        self.sslobj = ssl_context.wrap_bio(
+            self.incoming, self.outgoing, server_hostname=server_hostname
+        )
+
+        # Perform initial handshake.
+        self._ssl_io_loop(self.sslobj.do_handshake)
+
+    def __enter__(self: _SelfT) -> _SelfT:
+        return self
+
+    def __exit__(self, *_: typing.Any) -> None:
+        self.close()
+
+    def fileno(self) -> int:
+        return self.socket.fileno()
+
+    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
+        return self._wrap_ssl_read(len, buffer)
+
+    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
+        if flags != 0:
+            raise ValueError("non-zero flags not allowed in calls to recv")
+        return self._wrap_ssl_read(buflen)
+
+    def recv_into(
+        self,
+        buffer: _WriteBuffer,
+        nbytes: int | None = None,
+        flags: int = 0,
+    ) -> None | int | bytes:
+        if flags != 0:
+            raise ValueError("non-zero flags not allowed in calls to recv_into")
+        if nbytes is None:
+            nbytes = len(buffer)
+        return self.read(nbytes, buffer)
+
+    def sendall(self, data: bytes, flags: int = 0) -> None:
+        if flags != 0:
+            raise ValueError("non-zero flags not allowed in calls to sendall")
+        count = 0
+        with memoryview(data) as view, view.cast("B") as byte_view:
+            amount = len(byte_view)
+            while count < amount:
+                v = self.send(byte_view[count:])
+                count += v
+
+    def send(self, data: bytes, flags: int = 0) -> int:
+        if flags != 0:
+            raise ValueError("non-zero flags not allowed in calls to send")
+        return self._ssl_io_loop(self.sslobj.write, data)
+
+    def makefile(
+        self,
+        mode: str,
+        buffering: int | None = None,
+        *,
+        encoding: str | None = None,
+        errors: str | None = None,
+        newline: str | None = None,
+    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
+        """
+        Python's httpclient uses makefile and buffered io when reading HTTP
+        messages and we need to support it.
+
+        This is unfortunately a copy and paste of socket.py makefile with small
+        changes to point to the socket directly.
+        """
+        if not set(mode) <= {"r", "w", "b"}:
+            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")
+
+        writing = "w" in mode
+        reading = "r" in mode or not writing
+        assert reading or writing
+        binary = "b" in mode
+        rawmode = ""
+        if reading:
+            rawmode += "r"
+        if writing:
+            rawmode += "w"
+        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
+        self.socket._io_refs += 1  # type: ignore[attr-defined]
+        if buffering is None:
+            buffering = -1
+        if buffering < 0:
+            buffering = io.DEFAULT_BUFFER_SIZE
+        if buffering == 0:
+            if not binary:
+                raise ValueError("unbuffered streams must be binary")
+            return raw
+        buffer: typing.BinaryIO
+        if reading and writing:
+            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
+        elif reading:
+            buffer = io.BufferedReader(raw, buffering)
+        else:
+            assert writing
+            buffer = io.BufferedWriter(raw, buffering)
+        if binary:
+            return buffer
+        text = io.TextIOWrapper(buffer, encoding, errors, newline)
+        text.mode = mode  # type: ignore[misc]
+        return text
+
+    def unwrap(self) -> None:
+        self._ssl_io_loop(self.sslobj.unwrap)
+
+    def close(self) -> None:
+        self.socket.close()
+
+    @typing.overload
+    def getpeercert(
+        self, binary_form: Literal[False] = ...
+    ) -> _TYPE_PEER_CERT_RET_DICT | None:
+        ...
+
+    @typing.overload
+    def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
+        ...
+
+    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
+        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]
+
+    def version(self) -> str | None:
+        return self.sslobj.version()
+
+    def cipher(self) -> tuple[str, str, int] | None:
+        return self.sslobj.cipher()
+
+    def selected_alpn_protocol(self) -> str | None:
+        return self.sslobj.selected_alpn_protocol()
+
+    def selected_npn_protocol(self) -> str | None:
+        return self.sslobj.selected_npn_protocol()
+
+    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
+        return self.sslobj.shared_ciphers()
+
+    def compression(self) -> str | None:
+        return self.sslobj.compression()
+
+    def settimeout(self, value: float | None) -> None:
+        self.socket.settimeout(value)
+
+    def gettimeout(self) -> float | None:
+        return self.socket.gettimeout()
+
+    def _decref_socketios(self) -> None:
+        self.socket._decref_socketios()  # type: ignore[attr-defined]
+
+    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
+        try:
+            return self._ssl_io_loop(self.sslobj.read, len, buffer)
+        except ssl.SSLError as e:
+            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
+                return 0  # eof, return 0.
+            else:
+                raise
+
+    # func is sslobj.do_handshake or sslobj.unwrap
+    @typing.overload
+    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
+        ...
+
+    # func is sslobj.write, arg1 is data
+    @typing.overload
+    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
+        ...
+
+    # func is sslobj.read, arg1 is len, arg2 is buffer
+    @typing.overload
+    def _ssl_io_loop(
+        self,
+        func: typing.Callable[[int, bytearray | None], bytes],
+        arg1: int,
+        arg2: bytearray | None,
+    ) -> bytes:
+        ...
+
+    def _ssl_io_loop(
+        self,
+        func: typing.Callable[..., _ReturnValue],
+        arg1: None | bytes | int = None,
+        arg2: bytearray | None = None,
+    ) -> _ReturnValue:
+        """Performs an I/O loop between incoming/outgoing and the socket."""
+        should_loop = True
+        ret = None
+
+        while should_loop:
+            errno = None
+            try:
+                if arg1 is None and arg2 is None:
+                    ret = func()
+                elif arg2 is None:
+                    ret = func(arg1)
+                else:
+                    ret = func(arg1, arg2)
+            except ssl.SSLError as e:
+                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
+                    # WANT_READ, and WANT_WRITE are expected, others are not.
+                    raise e
+                errno = e.errno
+
+            buf = self.outgoing.read()
+            self.socket.sendall(buf)
+
+            if errno is None:
+                should_loop = False
+            elif errno == ssl.SSL_ERROR_WANT_READ:
+                buf = self.socket.recv(SSL_BLOCKSIZE)
+                if buf:
+                    self.incoming.write(buf)
+                else:
+                    self.incoming.write_eof()
+        return typing.cast(_ReturnValue, ret)