jpayne@7
|
1 from __future__ import annotations
|
jpayne@7
|
2
|
jpayne@7
|
3 import io
|
jpayne@7
|
4 import socket
|
jpayne@7
|
5 import ssl
|
jpayne@7
|
6 import typing
|
jpayne@7
|
7
|
jpayne@7
|
8 from ..exceptions import ProxySchemeUnsupported
|
jpayne@7
|
9
|
jpayne@7
|
10 if typing.TYPE_CHECKING:
|
jpayne@7
|
11 from typing import Literal
|
jpayne@7
|
12
|
jpayne@7
|
13 from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
|
jpayne@7
|
14
|
jpayne@7
|
15
|
jpayne@7
|
# Maximum number of ciphertext bytes pulled from the underlying socket by a
# single recv() in SSLTransport._ssl_io_loop (16 KiB).
SSL_BLOCKSIZE = 16 * 1024

# Result type of whichever SSLObject method SSLTransport._ssl_io_loop drives.
_ReturnValue = typing.TypeVar("_ReturnValue")
# Writable buffer types accepted by SSLTransport.recv_into.
_WriteBuffer = typing.Union[bytearray, memoryview]
# Lets SSLTransport.__enter__ be typed as returning the concrete (sub)class.
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
|
jpayne@7
|
21
|
jpayne@7
|
22
|
jpayne@7
|
class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    All TLS state lives in an ``ssl.SSLObject`` created by
    ``SSLContext.wrap_bio`` over two ``ssl.MemoryBIO`` objects. Ciphertext that
    OpenSSL produces is drained from ``outgoing`` and sent with
    ``socket.sendall``. Ciphertext received with ``socket.recv`` is pushed into
    ``incoming``. ``_ssl_io_loop`` shuttles data between the two.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        methods.
        """

        if not hasattr(ssl_context, "wrap_bio"):
            raise ProxySchemeUnsupported(
                "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                "available on non-native SSLContext"
            )

    def __init__(
        self,
        socket: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        suppress_ragged_eofs: bool = True,
    ) -> None:
        """
        Create an SSLTransport around socket using the provided ssl_context.

        The TLS handshake is performed before this method returns.

        :param socket: object the TLS records are sent over (``sendall``) and
            received from (``recv``).
        :param ssl_context: context whose ``wrap_bio`` creates the SSL object.
        :param server_hostname: forwarded to ``wrap_bio``.
        :param suppress_ragged_eofs: when true, an abrupt EOF
            (``SSL_ERROR_EOF``) during a read is reported as end-of-stream
            instead of raising.
        """
        # In-memory BIOs: OpenSSL reads ciphertext from `incoming` and writes
        # ciphertext to `outgoing`; the socket I/O is done by _ssl_io_loop.
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self: _SelfT) -> _SelfT:
        return self

    def __exit__(self, *_: typing.Any) -> None:
        self.close()

    def fileno(self) -> int:
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
        """Read up to ``len`` bytes of decrypted data.

        Without ``buffer``, the data is returned as ``bytes``. With ``buffer``,
        the data is written into it and the number of bytes read is returned.
        A suppressed ragged EOF yields ``b""`` or ``0`` respectively.
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
        """Socket-style ``recv``; only ``flags == 0`` is supported."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(buflen)

    def recv_into(
        self,
        buffer: _WriteBuffer,
        nbytes: int | None = None,
        flags: int = 0,
    ) -> None | int | bytes:
        """Socket-style ``recv_into``; reads up to ``nbytes`` (default: the
        whole buffer) into ``buffer`` and returns the count read."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            nbytes = len(buffer)
        return self.read(nbytes, buffer)

    def sendall(self, data: bytes, flags: int = 0) -> None:
        """Encrypt and send all of ``data``, calling ``send`` repeatedly until
        every byte has been accepted."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Cast to unsigned bytes so len() and slicing work per byte even when
        # `data` is a multi-byte-itemsize buffer.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data: bytes, flags: int = 0) -> int:
        """Encrypt ``data`` with one ``SSLObject.write`` and return the number
        of plaintext bytes it consumed."""
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self,
        mode: str,
        buffering: int | None = None,
        *,
        encoding: str | None = None,
        errors: str | None = None,
        newline: str | None = None,
    ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.

        :param mode: any combination of "r", "w" and "b". Reading is implied
            when "w" is absent.
        :param buffering: ``None`` or a negative value selects
            ``io.DEFAULT_BUFFER_SIZE``. ``0`` returns the raw, unbuffered
            ``socket.SocketIO``, which is only allowed in binary mode.
        :raises ValueError: for an unsupported mode or unbuffered text mode.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        # SocketIO wraps this transport, so stream reads/writes go through the
        # TLS layer. The reference is counted on the wrapped socket, which is
        # what _decref_socketios() releases.
        raw = socket.SocketIO(self, rawmode)  # type: ignore[arg-type]
        self.socket._io_refs += 1  # type: ignore[attr-defined]
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        buffer: typing.BinaryIO
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)  # type: ignore[assignment]
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode  # type: ignore[misc]
        return text

    def unwrap(self) -> None:
        """Shut down the TLS layer (``SSLObject.unwrap``) through the I/O loop.

        The underlying socket is left open.
        """
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self) -> None:
        """Close the underlying socket.

        The TLS layer is not shut down first; call ``unwrap`` for that.
        """
        self.socket.close()

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None:
        ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
        ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        return self.sslobj.getpeercert(binary_form)  # type: ignore[return-value]

    # Connection introspection: thin delegations to the SSLObject.
    def version(self) -> str | None:
        return self.sslobj.version()

    def cipher(self) -> tuple[str, str, int] | None:
        return self.sslobj.cipher()

    def selected_alpn_protocol(self) -> str | None:
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self) -> str | None:
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
        return self.sslobj.shared_ciphers()

    def compression(self) -> str | None:
        return self.sslobj.compression()

    # Timeouts live on the wrapped socket, which performs all blocking I/O.
    def settimeout(self, value: float | None) -> None:
        self.socket.settimeout(value)

    def gettimeout(self) -> float | None:
        return self.socket.gettimeout()

    def _decref_socketios(self) -> None:
        """Release a reference taken by ``makefile``; invoked when a
        ``socket.SocketIO`` wrapping this transport is closed."""
        self.socket._decref_socketios()  # type: ignore[attr-defined]

    def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
        """Run ``SSLObject.read`` through the I/O loop, mapping ragged EOFs.

        An abrupt end of the connection surfaces as ``SSL_ERROR_EOF``. When
        ``suppress_ragged_eofs`` is set, it is reported as a normal EOF,
        mirroring ``ssl.SSLSocket.read``: ``0`` when reading into ``buffer``,
        ``b""`` otherwise. This keeps ``recv()``/``read()`` returning ``bytes``.
        Any other ``ssl.SSLError`` propagates.
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # EOF: byte count for buffer reads, empty bytes otherwise.
                return 0 if buffer is not None else b""
            raise

    # func is sslobj.do_handshake or sslobj.unwrap
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
        ...

    # func is sslobj.write, arg1 is data
    @typing.overload
    def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
        ...

    # func is sslobj.read, arg1 is len, arg2 is buffer
    @typing.overload
    def _ssl_io_loop(
        self,
        func: typing.Callable[[int, bytearray | None], bytes],
        arg1: int,
        arg2: bytearray | None,
    ) -> bytes:
        ...

    def _ssl_io_loop(
        self,
        func: typing.Callable[..., _ReturnValue],
        arg1: None | bytes | int = None,
        arg2: bytearray | None = None,
    ) -> _ReturnValue:
        """Performs an I/O loop between incoming/outgoing and the socket.

        ``func`` is called (with ``arg1``/``arg2`` when they are not ``None``)
        until it completes without ``SSL_ERROR_WANT_READ``/``WANT_WRITE``.

        After every attempt, whatever OpenSSL queued in ``outgoing`` is sent
        on the socket. On WANT_READ, up to ``SSL_BLOCKSIZE`` bytes are read
        from the socket into ``incoming``; an empty read marks ``incoming`` as
        EOF. Any other ``ssl.SSLError`` propagates to the caller.
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                if arg1 is None and arg2 is None:
                    ret = func()
                elif arg2 is None:
                    ret = func(arg1)
                else:
                    ret = func(arg1, arg2)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush anything OpenSSL produced, even when func still needs
            # more input: the peer may need it before replying.
            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    # Socket reached EOF; let OpenSSL observe it on the
                    # next attempt.
                    self.incoming.write_eof()
        return typing.cast(_ReturnValue, ret)
|