comparison urllib3/util/ssltransport.py @ 7:5eb2d5e3bf22

planemo upload for repository https://toolrepo.galaxytrakr.org/view/jpayne/bioproject_to_srr_2/556cac4fb538
author jpayne
date Sun, 05 May 2024 23:32:17 -0400
parents
children
comparison
equal deleted inserted replaced
6:b2745907b1eb 7:5eb2d5e3bf22
1 from __future__ import annotations
2
3 import io
4 import socket
5 import ssl
6 import typing
7
8 from ..exceptions import ProxySchemeUnsupported
9
10 if typing.TYPE_CHECKING:
11 from typing import Literal
12
13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
14
15
# ``__enter__`` hands back the same (sub)class instance it was called on.
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
# Result type of whichever callable ``SSLTransport._ssl_io_loop`` drives.
_ReturnValue = typing.TypeVar("_ReturnValue")
# Mutable buffers that ``SSLTransport.recv_into`` may fill in place.
_WriteBuffer = typing.Union[bytearray, memoryview]

# Number of bytes requested from the underlying socket per ``recv()`` while the
# TLS object is waiting for more input (16 KiB; presumably sized to match the
# 2**14-byte maximum TLS record payload — confirm before changing).
SSL_BLOCKSIZE = 16 * 1024
21
22
23 class SSLTransport:
24 """
25 The SSLTransport wraps an existing socket and establishes an SSL connection.
26
27 Contrary to Python's implementation of SSLSocket, it allows you to chain
28 multiple TLS connections together. It's particularly useful if you need to
29 implement TLS within TLS.
30
31 The class supports most of the socket API operations.
32 """
33
34 @staticmethod
35 def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
36 """
37 Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
38 for TLS in TLS.
39
40 The only requirement is that the ssl_context provides the 'wrap_bio'
41 methods.
42 """
43
44 if not hasattr(ssl_context, "wrap_bio"):
45 raise ProxySchemeUnsupported(
46 "TLS in TLS requires SSLContext.wrap_bio() which isn't "
47 "available on non-native SSLContext"
48 )
49
50 def __init__(
51 self,
52 socket: socket.socket,
53 ssl_context: ssl.SSLContext,
54 server_hostname: str | None = None,
55 suppress_ragged_eofs: bool = True,
56 ) -> None:
57 """
58 Create an SSLTransport around socket using the provided ssl_context.
59 """
60 self.incoming = ssl.MemoryBIO()
61 self.outgoing = ssl.MemoryBIO()
62
63 self.suppress_ragged_eofs = suppress_ragged_eofs
64 self.socket = socket
65
66 self.sslobj = ssl_context.wrap_bio(
67 self.incoming, self.outgoing, server_hostname=server_hostname
68 )
69
70 # Perform initial handshake.
71 self._ssl_io_loop(self.sslobj.do_handshake)
72
73 def __enter__(self: _SelfT) -> _SelfT:
74 return self
75
76 def __exit__(self, *_: typing.Any) -> None:
77 self.close()
78
79 def fileno(self) -> int:
80 return self.socket.fileno()
81
82 def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
83 return self._wrap_ssl_read(len, buffer)
84
85 def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
86 if flags != 0:
87 raise ValueError("non-zero flags not allowed in calls to recv")
88 return self._wrap_ssl_read(buflen)
89
90 def recv_into(
91 self,
92 buffer: _WriteBuffer,
93 nbytes: int | None = None,
94 flags: int = 0,
95 ) -> None | int | bytes:
96 if flags != 0:
97 raise ValueError("non-zero flags not allowed in calls to recv_into")
98 if nbytes is None:
99 nbytes = len(buffer)
100 return self.read(nbytes, buffer)
101
102 def sendall(self, data: bytes, flags: int = 0) -> None:
103 if flags != 0:
104 raise ValueError("non-zero flags not allowed in calls to sendall")
105 count = 0
106 with memoryview(data) as view, view.cast("B") as byte_view:
107 amount = len(byte_view)
108 while count < amount:
109 v = self.send(byte_view[count:])
110 count += v
111
112 def send(self, data: bytes, flags: int = 0) -> int:
113 if flags != 0:
114 raise ValueError("non-zero flags not allowed in calls to send")
115 return self._ssl_io_loop(self.sslobj.write, data)
116
117 def makefile(
118 self,
119 mode: str,
120 buffering: int | None = None,
121 *,
122 encoding: str | None = None,
123 errors: str | None = None,
124 newline: str | None = None,
125 ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
126 """
127 Python's httpclient uses makefile and buffered io when reading HTTP
128 messages and we need to support it.
129
130 This is unfortunately a copy and paste of socket.py makefile with small
131 changes to point to the socket directly.
132 """
133 if not set(mode) <= {"r", "w", "b"}:
134 raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")
135
136 writing = "w" in mode
137 reading = "r" in mode or not writing
138 assert reading or writing
139 binary = "b" in mode
140 rawmode = ""
141 if reading:
142 rawmode += "r"
143 if writing:
144 rawmode += "w"
145 raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type]
146 self.socket._io_refs += 1 # type: ignore[attr-defined]
147 if buffering is None:
148 buffering = -1
149 if buffering < 0:
150 buffering = io.DEFAULT_BUFFER_SIZE
151 if buffering == 0:
152 if not binary:
153 raise ValueError("unbuffered streams must be binary")
154 return raw
155 buffer: typing.BinaryIO
156 if reading and writing:
157 buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment]
158 elif reading:
159 buffer = io.BufferedReader(raw, buffering)
160 else:
161 assert writing
162 buffer = io.BufferedWriter(raw, buffering)
163 if binary:
164 return buffer
165 text = io.TextIOWrapper(buffer, encoding, errors, newline)
166 text.mode = mode # type: ignore[misc]
167 return text
168
169 def unwrap(self) -> None:
170 self._ssl_io_loop(self.sslobj.unwrap)
171
172 def close(self) -> None:
173 self.socket.close()
174
175 @typing.overload
176 def getpeercert(
177 self, binary_form: Literal[False] = ...
178 ) -> _TYPE_PEER_CERT_RET_DICT | None:
179 ...
180
181 @typing.overload
182 def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
183 ...
184
185 def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
186 return self.sslobj.getpeercert(binary_form) # type: ignore[return-value]
187
188 def version(self) -> str | None:
189 return self.sslobj.version()
190
191 def cipher(self) -> tuple[str, str, int] | None:
192 return self.sslobj.cipher()
193
194 def selected_alpn_protocol(self) -> str | None:
195 return self.sslobj.selected_alpn_protocol()
196
197 def selected_npn_protocol(self) -> str | None:
198 return self.sslobj.selected_npn_protocol()
199
200 def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
201 return self.sslobj.shared_ciphers()
202
203 def compression(self) -> str | None:
204 return self.sslobj.compression()
205
206 def settimeout(self, value: float | None) -> None:
207 self.socket.settimeout(value)
208
209 def gettimeout(self) -> float | None:
210 return self.socket.gettimeout()
211
212 def _decref_socketios(self) -> None:
213 self.socket._decref_socketios() # type: ignore[attr-defined]
214
215 def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
216 try:
217 return self._ssl_io_loop(self.sslobj.read, len, buffer)
218 except ssl.SSLError as e:
219 if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
220 return 0 # eof, return 0.
221 else:
222 raise
223
224 # func is sslobj.do_handshake or sslobj.unwrap
225 @typing.overload
226 def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
227 ...
228
229 # func is sslobj.write, arg1 is data
230 @typing.overload
231 def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
232 ...
233
234 # func is sslobj.read, arg1 is len, arg2 is buffer
235 @typing.overload
236 def _ssl_io_loop(
237 self,
238 func: typing.Callable[[int, bytearray | None], bytes],
239 arg1: int,
240 arg2: bytearray | None,
241 ) -> bytes:
242 ...
243
244 def _ssl_io_loop(
245 self,
246 func: typing.Callable[..., _ReturnValue],
247 arg1: None | bytes | int = None,
248 arg2: bytearray | None = None,
249 ) -> _ReturnValue:
250 """Performs an I/O loop between incoming/outgoing and the socket."""
251 should_loop = True
252 ret = None
253
254 while should_loop:
255 errno = None
256 try:
257 if arg1 is None and arg2 is None:
258 ret = func()
259 elif arg2 is None:
260 ret = func(arg1)
261 else:
262 ret = func(arg1, arg2)
263 except ssl.SSLError as e:
264 if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
265 # WANT_READ, and WANT_WRITE are expected, others are not.
266 raise e
267 errno = e.errno
268
269 buf = self.outgoing.read()
270 self.socket.sendall(buf)
271
272 if errno is None:
273 should_loop = False
274 elif errno == ssl.SSL_ERROR_WANT_READ:
275 buf = self.socket.recv(SSL_BLOCKSIZE)
276 if buf:
277 self.incoming.write(buf)
278 else:
279 self.incoming.write_eof()
280 return typing.cast(_ReturnValue, ret)