Mercurial > repos > jpayne > bioproject_to_srr_2
diff urllib3/http2.py @ 7:5eb2d5e3bf22
planemo upload for repository https://toolrepo.galaxytrakr.org/view/jpayne/bioproject_to_srr_2/556cac4fb538
author | jpayne |
---|---|
date | Sun, 05 May 2024 23:32:17 -0400 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/urllib3/http2.py Sun May 05 23:32:17 2024 -0400 @@ -0,0 +1,229 @@ +from __future__ import annotations + +import threading +import types +import typing + +import h2.config # type: ignore[import-untyped] +import h2.connection # type: ignore[import-untyped] +import h2.events # type: ignore[import-untyped] + +import urllib3.connection +import urllib3.util.ssl_ +from urllib3.response import BaseHTTPResponse + +from ._collections import HTTPHeaderDict +from .connection import HTTPSConnection +from .connectionpool import HTTPSConnectionPool + +orig_HTTPSConnection = HTTPSConnection + +T = typing.TypeVar("T") + + +class _LockedObject(typing.Generic[T]): + """ + A wrapper class that hides a specific object behind a lock. + + The goal here is to provide a simple way to protect access to an object + that cannot safely be simultaneously accessed from multiple threads. The + intended use of this class is simple: take hold of it with a context + manager, which returns the protected object. + """ + + def __init__(self, obj: T): + self.lock = threading.RLock() + self._obj = obj + + def __enter__(self) -> T: + self.lock.acquire() + return self._obj + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + ) -> None: + self.lock.release() + + +class HTTP2Connection(HTTPSConnection): + def __init__( + self, host: str, port: int | None = None, **kwargs: typing.Any + ) -> None: + self._h2_conn = self._new_h2_conn() + self._h2_stream: int | None = None + self._h2_headers: list[tuple[bytes, bytes]] = [] + + if "proxy" in kwargs or "proxy_config" in kwargs: # Defensive: + raise NotImplementedError("Proxies aren't supported with HTTP/2") + + super().__init__(host, port, **kwargs) + + def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]: + config = h2.config.H2Configuration(client_side=True) + return _LockedObject(h2.connection.H2Connection(config=config)) + + def connect(self) -> None: + super().connect() + + with self._h2_conn as h2_conn: + h2_conn.initiate_connection() + self.sock.sendall(h2_conn.data_to_send()) + + def putrequest( + self, + method: str, + url: str, + skip_host: bool = False, + skip_accept_encoding: bool = False, + ) -> None: + with self._h2_conn as h2_conn: + self._request_url = url + self._h2_stream = h2_conn.get_next_available_stream_id() + + if ":" in self.host: + authority = f"[{self.host}]:{self.port or 443}" + else: + authority = f"{self.host}:{self.port or 443}" + + self._h2_headers.extend( + ( + (b":scheme", b"https"), + (b":method", method.encode()), + (b":authority", authority.encode()), + (b":path", url.encode()), + ) + ) + + def putheader(self, header: str, *values: str) -> None: # type: ignore[override] + for value in values: + self._h2_headers.append( + (header.encode("utf-8").lower(), value.encode("utf-8")) + ) + + def endheaders(self) -> None: # type: ignore[override] + with self._h2_conn as h2_conn: + h2_conn.send_headers( + stream_id=self._h2_stream, + headers=self._h2_headers, + end_stream=True, + ) + if data_to_send := h2_conn.data_to_send(): + self.sock.sendall(data_to_send) + + def send(self, data: bytes) -> None: # type: ignore[override] # Defensive: + if not data: + return + raise NotImplementedError("Sending data isn't supported yet") + + def getresponse( # type: ignore[override] + self, + ) -> HTTP2Response: + status = None + data = bytearray() + with self._h2_conn as h2_conn: + end_stream = False + while not end_stream: + # TODO: Arbitrary read value. + if received_data := self.sock.recv(65535): + events = h2_conn.receive_data(received_data) + for event in events: + if isinstance(event, h2.events.ResponseReceived): + headers = HTTPHeaderDict() + for header, value in event.headers: + if header == b":status": + status = int(value.decode()) + else: + headers.add( + header.decode("ascii"), value.decode("ascii") + ) + + elif isinstance(event, h2.events.DataReceived): + data += event.data + h2_conn.acknowledge_received_data( + event.flow_controlled_length, event.stream_id + ) + + elif isinstance(event, h2.events.StreamEnded): + end_stream = True + + if data_to_send := h2_conn.data_to_send(): + self.sock.sendall(data_to_send) + + # We always close to not have to handle connection management. + self.close() + + assert status is not None + return HTTP2Response( + status=status, + headers=headers, + request_url=self._request_url, + data=bytes(data), + ) + + def close(self) -> None: + with self._h2_conn as h2_conn: + try: + h2_conn.close_connection() + if data := h2_conn.data_to_send(): + self.sock.sendall(data) + except Exception: + pass + + # Reset all our HTTP/2 connection state. + self._h2_conn = self._new_h2_conn() + self._h2_stream = None + self._h2_headers = [] + + super().close() + + +class HTTP2Response(BaseHTTPResponse): + # TODO: This is a woefully incomplete response object, but works for non-streaming. + def __init__( + self, + status: int, + headers: HTTPHeaderDict, + request_url: str, + data: bytes, + decode_content: bool = False, # TODO: support decoding + ) -> None: + super().__init__( + status=status, + headers=headers, + # Following CPython, we map HTTP versions to major * 10 + minor integers + version=20, + # No reason phrase in HTTP/2 + reason=None, + decode_content=decode_content, + request_url=request_url, + ) + self._data = data + self.length_remaining = 0 + + @property + def data(self) -> bytes: + return self._data + + def get_redirect_location(self) -> None: + return None + + def close(self) -> None: + pass + + +def inject_into_urllib3() -> None: + HTTPSConnectionPool.ConnectionCls = HTTP2Connection + urllib3.connection.HTTPSConnection = HTTP2Connection # type: ignore[misc] + + # TODO: Offer 'http/1.1' as well, but for testing purposes this is handy. + urllib3.util.ssl_.ALPN_PROTOCOLS = ["h2"] + + +def extract_from_urllib3() -> None: + HTTPSConnectionPool.ConnectionCls = orig_HTTPSConnection + urllib3.connection.HTTPSConnection = orig_HTTPSConnection # type: ignore[misc] + + urllib3.util.ssl_.ALPN_PROTOCOLS = ["http/1.1"]