from __future__ import annotations

import threading
import types
import typing

import h2.config  # type: ignore[import-untyped]
import h2.connection  # type: ignore[import-untyped]
import h2.events  # type: ignore[import-untyped]

import urllib3.connection
import urllib3.util.ssl_
from urllib3.response import BaseHTTPResponse

from ._collections import HTTPHeaderDict
from .connection import HTTPSConnection
from .connectionpool import HTTPSConnectionPool

# Keep a handle on the original class so extract_from_urllib3() can restore
# it after inject_into_urllib3() has swapped in HTTP2Connection.
orig_HTTPSConnection = HTTPSConnection

T = typing.TypeVar("T")


class _LockedObject(typing.Generic[T]):
    """
    A wrapper class that hides a specific object behind a lock.

    The goal here is to provide a simple way to protect access to an object
    that cannot safely be simultaneously accessed from multiple threads. The
    intended use of this class is simple: take hold of it with a context
    manager, which returns the protected object.
    """

    def __init__(self, obj: T):
        # An RLock rather than a Lock so the owning thread may re-enter:
        # HTTP2Connection.getresponse() calls close() while already holding it.
        self.lock = threading.RLock()
        self._obj = obj

    def __enter__(self) -> T:
        """Acquire the lock and return the protected object."""
        self.lock.acquire()
        return self._obj

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Release the lock. Returns None, so exceptions are not suppressed."""
        self.lock.release()


class HTTP2Connection(HTTPSConnection):
    """
    An HTTPSConnection that speaks HTTP/2 through the ``h2`` library.

    The implementation is intentionally minimal:

    * requests cannot carry a body (headers are sent with END_STREAM set and
      ``send()`` rejects non-empty data);
    * the whole response body is buffered in memory;
    * the connection is closed after every response, so each connection
      serves exactly one request;
    * proxies are rejected.
    """

    def __init__(
        self, host: str, port: int | None = None, **kwargs: typing.Any
    ) -> None:
        """
        :param host: Server host name or IP address.
        :param port: Server port, passed through to HTTPSConnection.
        :param kwargs: Passed through to HTTPSConnection.
        :raises NotImplementedError: if ``proxy`` or ``proxy_config`` is given.
        """
        self._h2_conn = self._new_h2_conn()
        self._h2_stream: int | None = None
        self._h2_headers: list[tuple[bytes, bytes]] = []

        if "proxy" in kwargs or "proxy_config" in kwargs:  # Defensive:
            raise NotImplementedError("Proxies aren't supported with HTTP/2")

        super().__init__(host, port, **kwargs)

    def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]:
        """Create a fresh client-side h2 state machine wrapped in a lock."""
        config = h2.config.H2Configuration(client_side=True)
        return _LockedObject(h2.connection.H2Connection(config=config))

    def connect(self) -> None:
        """Open the TLS socket, then send the HTTP/2 connection preamble."""
        super().connect()

        with self._h2_conn as h2_conn:
            h2_conn.initiate_connection()
            self.sock.sendall(h2_conn.data_to_send())

    def putrequest(
        self,
        method: str,
        url: str,
        skip_host: bool = False,
        skip_accept_encoding: bool = False,
    ) -> None:
        """
        Allocate a new stream and queue the request pseudo-headers.

        Nothing is written to the socket until :meth:`endheaders`.
        ``skip_host`` and ``skip_accept_encoding`` are accepted for signature
        compatibility with HTTPSConnection but are ignored.
        """
        with self._h2_conn as h2_conn:
            self._request_url = url
            self._h2_stream = h2_conn.get_next_available_stream_id()

            # A colon in the host means an IPv6 literal, which must be
            # bracketed inside :authority. Port defaults to 443 when unset.
            if ":" in self.host:
                authority = f"[{self.host}]:{self.port or 443}"
            else:
                authority = f"{self.host}:{self.port or 443}"

            self._h2_headers.extend(
                (
                    (b":scheme", b"https"),
                    (b":method", method.encode()),
                    (b":authority", authority.encode()),
                    (b":path", url.encode()),
                )
            )

    def putheader(self, header: str, *values: str) -> None:  # type: ignore[override]
        """Queue one header line per value; names are lower-cased for HTTP/2."""
        for value in values:
            self._h2_headers.append(
                (header.encode("utf-8").lower(), value.encode("utf-8"))
            )

    def endheaders(self) -> None:  # type: ignore[override]
        """
        Send the queued headers as a HEADERS frame.

        END_STREAM is set because request bodies are not supported.
        """
        with self._h2_conn as h2_conn:
            h2_conn.send_headers(
                stream_id=self._h2_stream,
                headers=self._h2_headers,
                end_stream=True,
            )
            if data_to_send := h2_conn.data_to_send():
                self.sock.sendall(data_to_send)

    def send(self, data: bytes) -> None:  # type: ignore[override]  # Defensive:
        """
        Accept only an empty body.

        :raises NotImplementedError: if ``data`` is non-empty.
        """
        if not data:
            return
        raise NotImplementedError("Sending data isn't supported yet")

    def getresponse(  # type: ignore[override]
        self,
    ) -> HTTP2Response:
        """
        Read the complete response for the current stream.

        Blocks until the stream ends, buffering the whole body in memory, and
        then closes the connection.

        :raises ConnectionError: if the peer closes the socket before the
            stream has ended.
        """
        status = None
        # Bound up front so it is never unbound at the return statement,
        # even if the assert below is stripped by ``python -O``.
        headers = HTTPHeaderDict()
        data = bytearray()
        with self._h2_conn as h2_conn:
            end_stream = False
            while not end_stream:
                # TODO: Arbitrary read value.
                received_data = self.sock.recv(65535)
                if not received_data:
                    # recv() returns b"" once the peer has closed the
                    # connection. No StreamEnded event can arrive after that,
                    # so continuing to loop would spin forever.
                    raise ConnectionError(
                        "Connection closed before the HTTP/2 response was complete"
                    )

                events = h2_conn.receive_data(received_data)
                for event in events:
                    if isinstance(event, h2.events.ResponseReceived):
                        headers = HTTPHeaderDict()
                        for header, value in event.headers:
                            if header == b":status":
                                status = int(value.decode())
                            else:
                                headers.add(
                                    header.decode("ascii"), value.decode("ascii")
                                )

                    elif isinstance(event, h2.events.DataReceived):
                        data += event.data
                        # Return flow-control credit so the server keeps
                        # sending the rest of the body.
                        h2_conn.acknowledge_received_data(
                            event.flow_controlled_length, event.stream_id
                        )

                    elif isinstance(event, h2.events.StreamEnded):
                        end_stream = True

                # Flush anything h2 queued in response (e.g. window updates).
                if data_to_send := h2_conn.data_to_send():
                    self.sock.sendall(data_to_send)

            # We always close to not have to handle connection management.
            self.close()

        assert status is not None
        return HTTP2Response(
            status=status,
            headers=headers,
            request_url=self._request_url,
            data=bytes(data),
        )

    def close(self) -> None:
        """
        Ask h2 to close the connection, reset all HTTP/2 state, and close the
        underlying socket.
        """
        with self._h2_conn as h2_conn:
            try:
                h2_conn.close_connection()
                if data := h2_conn.data_to_send():
                    self.sock.sendall(data)
            except Exception:
                # Deliberately best-effort: the socket may already be closed
                # (or not connected yet), and closing must still proceed.
                pass

        # Reset all our HTTP/2 connection state.
        self._h2_conn = self._new_h2_conn()
        self._h2_stream = None
        self._h2_headers = []

        super().close()


class HTTP2Response(BaseHTTPResponse):
    """
    A fully buffered HTTP/2 response.

    The body is held in memory and exposed via :attr:`data`. There is no
    streaming, no content decoding, and no redirect support.
    """

    # TODO: This is a woefully incomplete response object, but works for non-streaming.
    def __init__(
        self,
        status: int,
        headers: HTTPHeaderDict,
        request_url: str,
        data: bytes,
        decode_content: bool = False,  # TODO: support decoding
    ) -> None:
        super().__init__(
            status=status,
            headers=headers,
            # Following CPython, we map HTTP versions to major * 10 + minor integers
            version=20,
            # No reason phrase in HTTP/2
            reason=None,
            decode_content=decode_content,
            request_url=request_url,
        )
        self._data = data
        # The body is already fully buffered, so nothing remains to be read.
        self.length_remaining = 0

    @property
    def data(self) -> bytes:
        """The complete, undecoded response body."""
        return self._data

    def get_redirect_location(self) -> None:
        """Redirects are not supported; always returns None."""
        return None

    def close(self) -> None:
        """No-op: there is no underlying stream to release."""
        pass


def inject_into_urllib3() -> None:
    """Make urllib3 use HTTP2Connection for HTTPS and offer only ``h2`` via ALPN."""
    HTTPSConnectionPool.ConnectionCls = HTTP2Connection
    urllib3.connection.HTTPSConnection = HTTP2Connection  # type: ignore[misc]

    # TODO: Offer 'http/1.1' as well, but for testing purposes this is handy.
    urllib3.util.ssl_.ALPN_PROTOCOLS = ["h2"]


def extract_from_urllib3() -> None:
    """Restore the original HTTPS connection class and the ``http/1.1`` ALPN list."""
    HTTPSConnectionPool.ConnectionCls = orig_HTTPSConnection
    urllib3.connection.HTTPSConnection = orig_HTTPSConnection  # type: ignore[misc]

    urllib3.util.ssl_.ALPN_PROTOCOLS = ["http/1.1"]