Mercurial > repos > jpayne > bioproject_to_srr_2
comparison urllib3/util/ssltransport.py @ 7:5eb2d5e3bf22
planemo upload for repository https://toolrepo.galaxytrakr.org/view/jpayne/bioproject_to_srr_2/556cac4fb538
author | jpayne |
---|---|
date | Sun, 05 May 2024 23:32:17 -0400 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
6:b2745907b1eb | 7:5eb2d5e3bf22 |
---|---|
1 from __future__ import annotations | |
2 | |
3 import io | |
4 import socket | |
5 import ssl | |
6 import typing | |
7 | |
8 from ..exceptions import ProxySchemeUnsupported | |
9 | |
10 if typing.TYPE_CHECKING: | |
11 from typing import Literal | |
12 | |
13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT | |
14 | |
15 | |
# Typing helpers for the annotations in SSLTransport.
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")  # __enter__ returns self
_ReturnValue = typing.TypeVar("_ReturnValue")  # result of the func given to _ssl_io_loop
_WriteBuffer = typing.Union[bytearray, memoryview]  # writable targets for recv_into

# Maximum number of bytes pulled from the socket per recv() while feeding
# ciphertext into the incoming memory BIO.
SSL_BLOCKSIZE = 16 * 1024
21 | |
22 | |
class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    Internally the TLS state machine is an ``ssl.SSLObject`` created with
    ``SSLContext.wrap_bio()`` over two ``ssl.MemoryBIO`` buffers;
    ``_ssl_io_loop`` moves ciphertext between those buffers and the wrapped
    socket.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        method.
        """
        if not hasattr(ssl_context, "wrap_bio"):
            raise ProxySchemeUnsupported(
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                "available on non-native SSLContext"
            )

    def __init__(
        self,
        socket: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        """
        Create an SSLTransport around socket using the provided ssl_context.

        The TLS handshake is performed immediately, so this call blocks until
        the handshake completes or an error is raised.

        :param socket: The transport the TLS records are sent over and read from.
        :param ssl_context: Context whose ``wrap_bio()`` builds the SSL object.
        :param server_hostname: Forwarded to ``wrap_bio()`` (used by the ssl
            module for SNI and hostname checking).
        :param suppress_ragged_eofs: If true, an unexpected EOF from the peer
            is reported by the read methods as end-of-stream (``b""`` or ``0``)
            instead of raising ``ssl.SSLError``.
        """
        self.incoming = ssl.MemoryBIO()  # ciphertext received from the socket
        self.outgoing = ssl.MemoryBIO()  # ciphertext waiting to be sent

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self: _SelfT) -> _SelfT:
        return self

    def __exit__(self, *_: typing.Any) -> None:
        self.close()

    def fileno(self) -> int:
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
        """
        Read up to ``len`` bytes of decrypted application data.

        Returns the data as ``bytes`` when ``buffer`` is None; otherwise the
        data is written into ``buffer`` and the number of bytes written is
        returned. A suppressed ragged EOF yields ``b""`` or ``0`` respectively.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
        """Receive up to ``buflen`` decrypted bytes; ``flags`` must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(buflen)

    def recv_into(
        self,
        buffer: _WriteBuffer,
        nbytes: int | None = None,
        flags: int = 0,
    ) -> None | int | bytes:
        """
        Receive decrypted data into ``buffer`` and return the number of bytes.

        ``nbytes`` defaults to the size of ``buffer`` in bytes; ``flags`` must
        be 0.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            # Use the byte size rather than len(): for a memoryview with a
            # multi-byte item format len() counts items, while sslobj.read()
            # counts bytes.
            with memoryview(buffer) as view:
                nbytes = view.nbytes
        return self.read(nbytes, buffer)

    def sendall(self, data: bytes, flags: int = 0) -> None:
        """Encrypt and send all of ``data``; ``flags`` must be 0."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Cast to unsigned bytes so lengths and slice offsets are in bytes even
        # for buffers with a multi-byte item format.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data: bytes, flags: int = 0) -> int:
        """Encrypt ``data``, flush it to the socket and return bytes consumed."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self,
        mode: str = "r",
        buffering: int | None = None,
        *,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.

        ``mode`` defaults to ``"r"``, matching ``socket.socket.makefile``. The
        returned file object does its I/O through this transport and bumps the
        wrapped socket's ``_io_refs`` counter; closing the raw ``SocketIO``
        releases it via ``_decref_socketios``.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
        self.socket._io_refs += 1  # type: ignore[attr-defined]
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        buffer: typing.BinaryIO
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode  # type: ignore[misc]
        return text

    def unwrap(self) -> None:
        """Run the TLS shutdown handshake over the wrapped socket."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self) -> None:
        """Close the wrapped socket (no TLS shutdown is performed here)."""
        self.socket.close()

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None:
        ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
        ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        """Return the peer certificate as reported by the SSL object."""
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]

    def version(self) -> str | None:
        """Return the negotiated protocol version string, if any."""
        return self.sslobj.version()

    def cipher(self) -> tuple[str, str, int] | None:
        """Return the negotiated cipher tuple, if any."""
        return self.sslobj.cipher()

    def selected_alpn_protocol(self) -> str | None:
        """Return the protocol selected via ALPN, if any."""
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self) -> str | None:
        """Return the protocol selected via NPN, if any."""
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
        """Return the ciphers shared with the peer, if available."""
        return self.sslobj.shared_ciphers()

    def compression(self) -> str | None:
        """Return the compression algorithm in use, if any."""
        return self.sslobj.compression()

    def settimeout(self, value: float | None) -> None:
        """Set the timeout on the wrapped socket."""
        self.socket.settimeout(value)

    def gettimeout(self) -> float | None:
        """Return the timeout of the wrapped socket."""
        return self.socket.gettimeout()

    def _decref_socketios(self) -> None:
        # Called by socket.SocketIO.close() for files created by makefile().
        self.socket._decref_socketios()  # type: ignore[attr-defined]

    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
        """
        Read via the SSL object, translating a suppressed ragged EOF.

        Returns ``bytes`` when ``buffer`` is None, otherwise the number of bytes
        written into ``buffer``. On ``SSL_ERROR_EOF`` with
        ``suppress_ragged_eofs`` set, returns ``b""`` (no buffer) or ``0``
        (buffer given), mirroring ``ssl.SSLSocket.read``.

        :raises ssl.SSLError: for any other SSL error, or for EOF when
            ``suppress_ragged_eofs`` is false.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # Plain reads must report EOF as b"" (like socket.recv);
                # readinto-style reads report 0 bytes written.
                return 0 if buffer is not None else b""
            raise

    # func is sslobj.do_handshake or sslobj.unwrap
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
        ...

    # func is sslobj.write, arg1 is data
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
        ...

    # func is sslobj.read, arg1 is len, arg2 is buffer
    @typing.overload
    def _ssl_io_loop(
        self,
        func: typing.Callable[[int, bytearray | None], bytes],
        arg1: int,
        arg2: bytearray | None,
    ) -> bytes:
        ...

    def _ssl_io_loop(
        self,
        func: typing.Callable[..., _ReturnValue],
        arg1: None | bytes | int = None,
        arg2: bytearray | None = None,
    ) -> _ReturnValue:
        """Performs an I/O loop between incoming/outgoing and the socket.

        ``func`` is retried until it completes without WANT_READ/WANT_WRITE.
        After every attempt, pending outgoing ciphertext is flushed to the
        socket; on WANT_READ, one chunk of up to ``SSL_BLOCKSIZE`` bytes is
        read from the socket into the incoming BIO (or EOF is signalled if the
        socket returned nothing).
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                if arg1 is None and arg2 is None:
                    ret = func()
                elif arg2 is None:
                    ret = func(arg1)
                else:
                    ret = func(arg1, arg2)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush whatever the SSL object produced to the real socket.
            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    # Empty recv(): the peer closed; let the SSL object see EOF.
                    self.incoming.write_eof()
        return typing.cast(_ReturnValue, ret)