jpayne@7
|
1 from __future__ import annotations
|
jpayne@7
|
2
|
jpayne@7
|
3 import threading
|
jpayne@7
|
4 import types
|
jpayne@7
|
5 import typing
|
jpayne@7
|
6
|
jpayne@7
|
7 import h2.config # type: ignore[import-untyped]
|
jpayne@7
|
8 import h2.connection # type: ignore[import-untyped]
|
jpayne@7
|
9 import h2.events # type: ignore[import-untyped]
|
jpayne@7
|
10
|
jpayne@7
|
11 import urllib3.connection
|
jpayne@7
|
12 import urllib3.util.ssl_
|
jpayne@7
|
13 from urllib3.response import BaseHTTPResponse
|
jpayne@7
|
14
|
jpayne@7
|
15 from ._collections import HTTPHeaderDict
|
jpayne@7
|
16 from .connection import HTTPSConnection
|
jpayne@7
|
17 from .connectionpool import HTTPSConnectionPool
|
jpayne@7
|
18
|
jpayne@7
|
# The unpatched HTTPS connection class, captured at import time so that
# extract_from_urllib3() can restore it after inject_into_urllib3() has
# swapped in HTTP2Connection.
orig_HTTPSConnection = HTTPSConnection

21 T = typing.TypeVar("T")
|
jpayne@7
|
22
|
jpayne@7
|
23
|
jpayne@7
|
24 class _LockedObject(typing.Generic[T]):
|
jpayne@7
|
25 """
|
jpayne@7
|
26 A wrapper class that hides a specific object behind a lock.
|
jpayne@7
|
27
|
jpayne@7
|
28 The goal here is to provide a simple way to protect access to an object
|
jpayne@7
|
29 that cannot safely be simultaneously accessed from multiple threads. The
|
jpayne@7
|
30 intended use of this class is simple: take hold of it with a context
|
jpayne@7
|
31 manager, which returns the protected object.
|
jpayne@7
|
32 """
|
jpayne@7
|
33
|
jpayne@7
|
34 def __init__(self, obj: T):
|
jpayne@7
|
35 self.lock = threading.RLock()
|
jpayne@7
|
36 self._obj = obj
|
jpayne@7
|
37
|
jpayne@7
|
38 def __enter__(self) -> T:
|
jpayne@7
|
39 self.lock.acquire()
|
jpayne@7
|
40 return self._obj
|
jpayne@7
|
41
|
jpayne@7
|
42 def __exit__(
|
jpayne@7
|
43 self,
|
jpayne@7
|
44 exc_type: type[BaseException] | None,
|
jpayne@7
|
45 exc_val: BaseException | None,
|
jpayne@7
|
46 exc_tb: types.TracebackType | None,
|
jpayne@7
|
47 ) -> None:
|
jpayne@7
|
48 self.lock.release()
|
jpayne@7
|
49
|
jpayne@7
|
50
|
jpayne@7
|
class HTTP2Connection(HTTPSConnection):
    """HTTPS connection that speaks HTTP/2 through the ``h2`` state machine.

    Limitations of this implementation:

    * requests cannot carry a body (:meth:`endheaders` ends the stream and
      :meth:`send` rejects non-empty data);
    * proxies are rejected in ``__init__``;
    * each connection serves a single request: :meth:`getresponse` buffers
      the whole body and then closes the connection.
    """

    def __init__(
        self, host: str, port: int | None = None, **kwargs: typing.Any
    ) -> None:
        # Lock-protected h2 state machine for this connection.
        self._h2_conn = self._new_h2_conn()
        # Stream id of the in-flight request; assigned by putrequest().
        self._h2_stream: int | None = None
        # Pseudo-headers and regular headers queued until endheaders().
        self._h2_headers: list[tuple[bytes, bytes]] = []

        if "proxy" in kwargs or "proxy_config" in kwargs:  # Defensive:
            raise NotImplementedError("Proxies aren't supported with HTTP/2")

        super().__init__(host, port, **kwargs)

    def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]:
        """Return a fresh client-side h2 connection wrapped in a lock."""
        config = h2.config.H2Configuration(client_side=True)
        return _LockedObject(h2.connection.H2Connection(config=config))

    def connect(self) -> None:
        """Open the TLS socket, then send h2's initial connection data."""
        super().connect()

        with self._h2_conn as h2_conn:
            h2_conn.initiate_connection()
            self.sock.sendall(h2_conn.data_to_send())

    def putrequest(
        self,
        method: str,
        url: str,
        skip_host: bool = False,
        skip_accept_encoding: bool = False,
    ) -> None:
        """Allocate a new stream and queue the request pseudo-headers.

        ``skip_host`` and ``skip_accept_encoding`` are accepted for signature
        compatibility but ignored; ``:authority`` is always sent.
        """
        with self._h2_conn as h2_conn:
            self._request_url = url
            self._h2_stream = h2_conn.get_next_available_stream_id()

            # An IPv6 literal must be bracketed inside an authority.
            if ":" in self.host:
                authority = f"[{self.host}]:{self.port or 443}"
            else:
                authority = f"{self.host}:{self.port or 443}"

            self._h2_headers.extend(
                (
                    (b":scheme", b"https"),
                    (b":method", method.encode()),
                    (b":authority", authority.encode()),
                    (b":path", url.encode()),
                )
            )

    def putheader(self, header: str, *values: str) -> None:  # type: ignore[override]
        """Queue one ``(name, value)`` pair per value.

        Names are lower-cased because HTTP/2 field names must be lower-case.
        """
        for value in values:
            self._h2_headers.append(
                (header.encode("utf-8").lower(), value.encode("utf-8"))
            )

    def endheaders(self) -> None:  # type: ignore[override]
        """Send all queued headers and mark the request stream as ended."""
        with self._h2_conn as h2_conn:
            h2_conn.send_headers(
                stream_id=self._h2_stream,
                headers=self._h2_headers,
                end_stream=True,
            )
            if data_to_send := h2_conn.data_to_send():
                self.sock.sendall(data_to_send)

    def send(self, data: bytes) -> None:  # type: ignore[override] # Defensive:
        """Accept only empty payloads; request bodies are not implemented."""
        if not data:
            return
        raise NotImplementedError("Sending data isn't supported yet")

    def getresponse(  # type: ignore[override]
        self,
    ) -> HTTP2Response:
        """Read the response to the current stream, then close the connection.

        The whole body is buffered in memory. Events that belong to other
        streams (for example server-pushed streams) are ignored, so they
        cannot end or pollute this response. The connection is closed
        whether or not reading succeeds.

        Raises:
            ConnectionError: if the server closes the socket, or resets the
                request stream, before the response has ended.
        """
        status = None
        headers = HTTPHeaderDict()
        data = bytearray()
        stream_id = self._h2_stream
        with self._h2_conn as h2_conn:
            try:
                end_stream = False
                while not end_stream:
                    # TODO: Arbitrary read value.
                    received_data = self.sock.recv(65535)
                    if not received_data:
                        # EOF: without this check the loop would spin forever
                        # on a socket the peer has already closed.
                        raise ConnectionError(
                            "Connection closed before the HTTP/2 response ended"
                        )

                    for event in h2_conn.receive_data(received_data):
                        if isinstance(event, h2.events.DataReceived):
                            # Return flow-control credit for every DATA frame,
                            # including frames on streams discarded below.
                            h2_conn.acknowledge_received_data(
                                event.flow_controlled_length, event.stream_id
                            )

                        # Only events for our own stream shape the response.
                        if getattr(event, "stream_id", None) != stream_id:
                            continue

                        if isinstance(event, h2.events.ResponseReceived):
                            headers = HTTPHeaderDict()
                            for header, value in event.headers:
                                if header == b":status":
                                    status = int(value.decode())
                                else:
                                    headers.add(
                                        header.decode("ascii"),
                                        value.decode("ascii"),
                                    )
                        elif isinstance(event, h2.events.DataReceived):
                            data += event.data
                        elif isinstance(event, h2.events.StreamEnded):
                            end_stream = True
                        elif isinstance(event, h2.events.StreamReset):
                            raise ConnectionError(
                                f"HTTP/2 stream {stream_id} was reset "
                                f"(error code {event.error_code!r})"
                            )

                    if data_to_send := h2_conn.data_to_send():
                        self.sock.sendall(data_to_send)
            finally:
                # We always close to not have to handle connection management.
                # Doing it in ``finally`` also discards half-read h2 state
                # when reading fails.
                self.close()

        assert status is not None
        return HTTP2Response(
            status=status,
            headers=headers,
            request_url=self._request_url,
            data=bytes(data),
        )

    def close(self) -> None:
        """Best-effort h2 shutdown, reset of all h2 state, then socket close."""
        with self._h2_conn as h2_conn:
            try:
                h2_conn.close_connection()
                if data := h2_conn.data_to_send():
                    self.sock.sendall(data)
            except Exception:
                # Deliberately best-effort: close() must still reset the state
                # below even when there is no usable socket.
                # NOTE(review): e.g. close() before connect(), or after a
                # failed read in getresponse() - confirm the full set of cases.
                pass

        # Reset all our HTTP/2 connection state.
        self._h2_conn = self._new_h2_conn()
        self._h2_stream = None
        self._h2_headers = []

        super().close()


class HTTP2Response(BaseHTTPResponse):
    """Response whose body was read in full before construction.

    ``data`` is exposed as-is, ``length_remaining`` is always 0, no redirect
    location is ever reported, and ``close()`` has nothing to release.
    """

    # TODO: This is a woefully incomplete response object, but works for non-streaming.
    def __init__(
        self,
        status: int,
        headers: HTTPHeaderDict,
        request_url: str,
        data: bytes,
        decode_content: bool = False,  # TODO: support decoding
    ) -> None:
        super().__init__(
            status=status,
            headers=headers,
            request_url=request_url,
            decode_content=decode_content,
            # Like CPython, encode the HTTP version as major * 10 + minor.
            version=20,
            # HTTP/2 carries no reason phrase.
            reason=None,
        )
        self.length_remaining = 0
        self._data = data

    @property
    def data(self) -> bytes:
        """The complete response body."""
        return self._data

    def get_redirect_location(self) -> None:
        """Report no redirect location, whatever the status or headers."""
        return None

    def close(self) -> None:
        """Do nothing: the body is already held in memory."""


def inject_into_urllib3() -> None:
    """Make urllib3 use :class:`HTTP2Connection` for HTTPS and offer only ``h2``.

    Patches ``HTTPSConnectionPool.ConnectionCls`` and then
    ``urllib3.connection.HTTPSConnection``, and sets the ALPN protocol list
    to ``["h2"]``. :func:`extract_from_urllib3` reverses this.
    """
    replacement = HTTP2Connection
    HTTPSConnectionPool.ConnectionCls = replacement
    urllib3.connection.HTTPSConnection = replacement  # type: ignore[misc]

    # TODO: Offer 'http/1.1' as well, but for testing purposes this is handy.
    urllib3.util.ssl_.ALPN_PROTOCOLS = ["h2"]


def extract_from_urllib3() -> None:
    """Undo :func:`inject_into_urllib3`.

    Puts the original HTTPS connection class back in both patched places and
    sets the ALPN protocol list to ``["http/1.1"]``. That value is fixed; it
    is not a saved copy of whatever was configured before injection.
    """
    restored = orig_HTTPSConnection
    HTTPSConnectionPool.ConnectionCls = restored
    urllib3.connection.HTTPSConnection = restored  # type: ignore[misc]

    urllib3.util.ssl_.ALPN_PROTOCOLS = ["http/1.1"]