from __future__ import annotations

import typing
from collections import OrderedDict
from enum import Enum, auto
from threading import RLock

if typing.TYPE_CHECKING:
    # We can only import Protocol if TYPE_CHECKING because it's a development
    # dependency, and is not available at runtime.
    from typing import Protocol

    from typing_extensions import Self

    class HasGettableStringKeys(Protocol):
        def keys(self) -> typing.Iterator[str]:
            ...

        def __getitem__(self, key: str) -> str:
            ...


__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]


# Key type
_KT = typing.TypeVar("_KT")
# Value type
_VT = typing.TypeVar("_VT")
# Default type
_DT = typing.TypeVar("_DT")

ValidHTTPHeaderSource = typing.Union[
    "HTTPHeaderDict",
    typing.Mapping[str, str],
    typing.Iterable[typing.Tuple[str, str]],
    "HasGettableStringKeys",
]


class _Sentinel(Enum):
    # Marker for "no default was passed" (used by HTTPHeaderDict.getlist), so
    # that ``None`` remains a legal explicit default.
    not_passed = auto()


def ensure_can_construct_http_header_dict(
    potential: object,
) -> ValidHTTPHeaderSource | None:
    """Narrow *potential* to something ``HTTPHeaderDict`` can be built from.

    :param potential:
        Any object, typically the right-hand operand of ``==`` or ``|``.

    :returns:
        *potential* itself (cast to the matching ``ValidHTTPHeaderSource``
        member) when it is an ``HTTPHeaderDict``, a ``Mapping``, a non-string
        ``Iterable``, or an object exposing both ``keys`` and ``__getitem__``;
        otherwise ``None``. Only the outer type is checked, not the contents.
    """
    if isinstance(potential, HTTPHeaderDict):
        return potential
    elif isinstance(potential, typing.Mapping):
        # Full runtime checking of the contents of a Mapping is expensive, so for the
        # purposes of typechecking, we assume that any Mapping is the right shape.
        return typing.cast(typing.Mapping[str, str], potential)
    elif isinstance(potential, (str, bytes, bytearray)):
        # Text and byte strings are iterable, but they yield single characters
        # (or ints) rather than (name, value) pairs, so they can never describe
        # headers. Rejecting them keeps ``headers == "x"`` / ``"x" == headers``
        # from raising, stops ``HTTPHeaderDict() == ""`` from being True, and
        # turns ``headers | "x"`` into the usual operator TypeError.
        return None
    elif isinstance(potential, typing.Iterable):
        # Similarly to Mapping, full runtime checking of the contents of an Iterable is
        # expensive, so for the purposes of typechecking, we assume that any Iterable
        # is the right shape.
        return typing.cast(typing.Iterable[typing.Tuple[str, str]], potential)
    elif hasattr(potential, "keys") and hasattr(potential, "__getitem__"):
        return typing.cast("HasGettableStringKeys", potential)
    else:
        return None


class RecentlyUsedContainer(typing.Generic[_KT, _VT], typing.MutableMapping[_KT, _VT]):
    """
    Provides a thread-safe dict-like container which maintains up to
    ``maxsize`` keys while throwing away the least-recently-used keys beyond
    ``maxsize``.

    :param maxsize:
        Maximum number of recent elements to retain.

    :param dispose_func:
        Optional callback invoked as ``dispose_func(value)`` for every value
        that leaves the container: when it is evicted as least recently used,
        when it is replaced by a new value for the same key, when its key is
        deleted, and for every value removed by :meth:`clear`. It is always
        called after the internal lock has been released.
    """

    # Insertion order doubles as recency order: oldest first, newest last.
    _container: typing.OrderedDict[_KT, _VT]
    _maxsize: int
    dispose_func: typing.Callable[[_VT], None] | None
    # Re-entrant lock guarding every access to ``_container``.
    lock: RLock

    def __init__(
        self,
        maxsize: int = 10,
        dispose_func: typing.Callable[[_VT], None] | None = None,
    ) -> None:
        super().__init__()
        self._maxsize = maxsize
        self.dispose_func = dispose_func
        self._container = OrderedDict()
        self.lock = RLock()

    def __getitem__(self, key: _KT) -> _VT:
        """Return the value for *key* and mark it most recently used.

        :raises KeyError: if *key* is not present.
        """
        # Re-insert the item, moving it to the end of the eviction line.
        with self.lock:
            item = self._container.pop(key)
            self._container[key] = item
            return item

    def __setitem__(self, key: _KT, value: _VT) -> None:
        """Store *value* under *key*, evicting the LRU entry if over capacity."""
        evicted_item = None
        with self.lock:
            # Possibly evict the existing value of 'key'
            try:
                # If the key exists, we'll overwrite it, which won't change the
                # size of the pool. Because accessing a key should move it to
                # the end of the eviction line, we pop it out first.
                evicted_item = key, self._container.pop(key)
                self._container[key] = value
            except KeyError:
                # When the key does not exist, we insert the value first so that
                # evicting works in all cases, including when self._maxsize is 0
                self._container[key] = value
                if len(self._container) > self._maxsize:
                    # If we didn't evict an existing value, and we've hit our maximum
                    # size, then we have to evict the least recently used item from
                    # the beginning of the container.
                    evicted_item = self._container.popitem(last=False)

        # After releasing the lock on the pool, dispose of any evicted value.
        if evicted_item is not None and self.dispose_func:
            _, evicted_value = evicted_item
            self.dispose_func(evicted_value)

    def __delitem__(self, key: _KT) -> None:
        """Remove *key* and dispose of its value.

        :raises KeyError: if *key* is not present.
        """
        with self.lock:
            value = self._container.pop(key)

        if self.dispose_func:
            self.dispose_func(value)

    def __len__(self) -> int:
        with self.lock:
            return len(self._container)

    def __iter__(self) -> typing.NoReturn:
        raise NotImplementedError(
            "Iteration over this class is unlikely to be threadsafe."
        )

    def clear(self) -> None:
        """Remove every entry, then dispose of each removed value."""
        with self.lock:
            # Copy pointers to all values, then wipe the mapping
            values = list(self._container.values())
            self._container.clear()

        if self.dispose_func:
            for value in values:
                self.dispose_func(value)

    def keys(self) -> set[_KT]:  # type: ignore[override]
        """Return a snapshot ``set`` of the current keys, taken under the lock."""
        with self.lock:
            return set(self._container.keys())


class HTTPHeaderDictItemView(typing.Set[typing.Tuple[str, str]]):
    """
    HTTPHeaderDict is unusual for a Mapping[str, str] in that it has two modes of
    address.

    If we directly try to get an item with a particular name, we will get a string
    back that is the concatenated version of all the values:

    >>> d = HTTPHeaderDict({"A": "1", "B": "foo"})
    >>> d.add("A", "2", combine=True)
    >>> d.add("B", "bar")
    >>> d["B"]
    'foo, bar'

    However, if we iterate over an HTTPHeaderDict's items, values are only merged
    when they were added with ``combine=True``; every other value is yielded as
    its own ``(name, value)`` pair:

    >>> list(d.items())
    [('A', '1, 2'), ('B', 'foo'), ('B', 'bar')]

    This view satisfies the ``items()`` contract of the MutableMapping ABC (it is
    a ``Set`` of pairs) while also giving us the nonstandard iteration behavior we
    want: items with duplicate keys, ordered by time of first insertion.
    """

    _headers: HTTPHeaderDict

    def __init__(self, headers: HTTPHeaderDict) -> None:
        self._headers = headers

    def __len__(self) -> int:
        # Count the (name, value) pairs without materialising them in a list.
        return sum(1 for _ in self._headers.iteritems())

    def __iter__(self) -> typing.Iterator[tuple[str, str]]:
        return self._headers.iteritems()

    def __contains__(self, item: object) -> bool:
        """True if *item* is a ``(str, str)`` pair stored in the headers.

        The name is matched case-insensitively; the value must equal one stored
        value exactly (a ``combine=True`` value is one combined string).
        """
        if isinstance(item, tuple) and len(item) == 2:
            passed_key, passed_val = item
            if isinstance(passed_key, str) and isinstance(passed_val, str):
                return self._headers._has_value_for_header(passed_key, passed_val)
        return False


class HTTPHeaderDict(typing.MutableMapping[str, str]):
    """
    :param headers:
        An ``HTTPHeaderDict``, a mapping, an iterable of field-value pairs, or an
        object with ``keys()`` and ``__getitem__``. Must not contain multiple
        field names when compared case-insensitively.

    :param kwargs:
        Additional field-value pairs to pass in to ``dict.update``.

    A ``dict`` like container for storing HTTP Headers.

    Field names are stored and compared case-insensitively in compliance with
    RFC 7230. Iteration provides the first case-sensitive key seen for each
    case-insensitive pair (``__setitem__`` replaces the stored spelling along
    with the values).

    Using ``__setitem__`` syntax overwrites fields that compare equal
    case-insensitively in order to maintain ``dict``'s api. For fields that
    compare equal, instead create a new ``HTTPHeaderDict`` and use ``.add``
    in a loop.

    If multiple fields that are equal case-insensitively are passed to the
    constructor or ``.update``, the behavior is undefined and some will be
    lost.

    >>> headers = HTTPHeaderDict()
    >>> headers.add('Set-Cookie', 'foo=bar')
    >>> headers.add('set-cookie', 'baz=quxx')
    >>> headers['content-length'] = '7'
    >>> headers['SET-cookie']
    'foo=bar, baz=quxx'
    >>> headers['Content-Length']
    '7'
    """

    # Lower-cased name -> [name as originally cased, value1, value2, ...].
    _container: typing.MutableMapping[str, list[str]]

    def __init__(
        self, headers: ValidHTTPHeaderSource | None = None, **kwargs: str
    ) -> None:
        super().__init__()
        self._container = {}  # 'dict' is insert-ordered
        if headers is not None:
            if isinstance(headers, HTTPHeaderDict):
                self._copy_from(headers)
            else:
                self.extend(headers)
        if kwargs:
            self.extend(kwargs)

    def __setitem__(self, key: str, val: str) -> None:
        # avoid a bytes/str comparison by decoding before httplib
        if isinstance(key, bytes):
            key = key.decode("latin-1")
        self._container[key.lower()] = [key, val]

    def __getitem__(self, key: str) -> str:
        """Return all values for *key*, joined with ``", "``.

        :raises KeyError: if *key* is not present.
        """
        val = self._container[key.lower()]
        return ", ".join(val[1:])

    def __delitem__(self, key: str) -> None:
        del self._container[key.lower()]

    def __contains__(self, key: object) -> bool:
        if isinstance(key, str):
            return key.lower() in self._container
        return False

    def setdefault(self, key: str, default: str = "") -> str:
        return super().setdefault(key, default)

    def __eq__(self, other: object) -> bool:
        """Compare merged values case-insensitively by field name.

        Returns ``False`` (rather than raising) for objects that cannot describe
        headers, including ``str`` and ``bytes``.
        """
        maybe_constructable = ensure_can_construct_http_header_dict(other)
        if maybe_constructable is None:
            return False
        else:
            other_as_http_header_dict = type(self)(maybe_constructable)

        return {k.lower(): v for k, v in self.itermerged()} == {
            k.lower(): v for k, v in other_as_http_header_dict.itermerged()
        }

    def __ne__(self, other: object) -> bool:
        return not self.__eq__(other)

    def __len__(self) -> int:
        return len(self._container)

    def __iter__(self) -> typing.Iterator[str]:
        # Only provide the originally cased names
        for vals in self._container.values():
            yield vals[0]

    def discard(self, key: str) -> None:
        """Remove *key* if present; do nothing otherwise."""
        try:
            del self[key]
        except KeyError:
            pass

    def add(self, key: str, val: str, *, combine: bool = False) -> None:
        """Adds a (name, value) pair, doesn't overwrite the value if it already
        exists.

        If this is called with combine=True, instead of adding a new header value
        as a distinct item during iteration, this will instead append the value to
        any existing header value with a comma. If no existing header value exists
        for the key, then the value will simply be added, ignoring the combine parameter.

        >>> headers = HTTPHeaderDict(foo='bar')
        >>> headers.add('Foo', 'baz')
        >>> headers['foo']
        'bar, baz'
        >>> list(headers.items())
        [('foo', 'bar'), ('foo', 'baz')]
        >>> headers.add('foo', 'quz', combine=True)
        >>> list(headers.items())
        [('foo', 'bar, baz, quz')]
        """
        # avoid a bytes/str comparison by decoding before httplib
        if isinstance(key, bytes):
            key = key.decode("latin-1")
        key_lower = key.lower()
        new_vals = [key, val]
        # Keep the common case aka no item present as fast as possible
        vals = self._container.setdefault(key_lower, new_vals)
        if new_vals is not vals:
            # if there are values here, then there is at least the initial
            # key/value pair
            assert len(vals) >= 2
            if combine:
                vals[-1] = vals[-1] + ", " + val
            else:
                vals.append(val)

    def extend(self, *args: ValidHTTPHeaderSource, **kwargs: str) -> None:
        """Generic import function for any type of header-like object.
        Adapted version of MutableMapping.update in order to insert items
        with self.add instead of self.__setitem__
        """
        if len(args) > 1:
            raise TypeError(
                f"extend() takes at most 1 positional arguments ({len(args)} given)"
            )
        other = args[0] if len(args) >= 1 else ()

        if isinstance(other, HTTPHeaderDict):
            for key, val in other.iteritems():
                self.add(key, val)
        elif isinstance(other, typing.Mapping):
            for key, val in other.items():
                self.add(key, val)
        elif isinstance(other, typing.Iterable):
            other = typing.cast(typing.Iterable[typing.Tuple[str, str]], other)
            for key, value in other:
                self.add(key, value)
        elif hasattr(other, "keys") and hasattr(other, "__getitem__"):
            # THIS IS NOT A TYPESAFE BRANCH
            # In this branch, the object has a `keys` attr but is not a Mapping or any of
            # the other types indicated in the method signature. We do some stuff with
            # it as though it partially implements the Mapping interface, but we're not
            # doing that stuff safely AT ALL.
            for key in other.keys():
                self.add(key, other[key])

        for key, value in kwargs.items():
            self.add(key, value)

    @typing.overload
    def getlist(self, key: str) -> list[str]:
        ...

    @typing.overload
    def getlist(self, key: str, default: _DT) -> list[str] | _DT:
        ...

    def getlist(
        self, key: str, default: _Sentinel | _DT = _Sentinel.not_passed
    ) -> list[str] | _DT:
        """Returns a list of all the values for the named field. If the key
        doesn't exist, returns ``default`` when one was passed, otherwise an
        empty list."""
        try:
            vals = self._container[key.lower()]
        except KeyError:
            if default is _Sentinel.not_passed:
                # _DT is unbound; empty list is instance of List[str]
                return []
            # _DT is bound; default is instance of _DT
            return default
        else:
            # _DT may or may not be bound; vals[1:] is instance of List[str], which
            # meets our external interface requirement of `Union[List[str], _DT]`.
            return vals[1:]

    def _prepare_for_method_change(self) -> Self:
        """
        Remove content-specific header fields before changing the request
        method to GET or HEAD according to RFC 9110, Section 15.4.
        """
        content_specific_headers = [
            "Content-Encoding",
            "Content-Language",
            "Content-Location",
            "Content-Type",
            "Content-Length",
            "Digest",
            "Last-Modified",
        ]
        for header in content_specific_headers:
            self.discard(header)
        return self

    # Backwards compatibility for httplib
    getheaders = getlist
    getallmatchingheaders = getlist
    iget = getlist

    # Backwards compatibility for http.cookiejar
    get_all = getlist

    def __repr__(self) -> str:
        return f"{type(self).__name__}({dict(self.itermerged())})"

    def _copy_from(self, other: HTTPHeaderDict) -> None:
        # getlist returns a fresh slice, so the copy never shares value lists.
        for key in other:
            val = other.getlist(key)
            self._container[key.lower()] = [key, *val]

    def copy(self) -> HTTPHeaderDict:
        clone = type(self)()
        clone._copy_from(self)
        return clone

    def iteritems(self) -> typing.Iterator[tuple[str, str]]:
        """Iterate over all header lines, including duplicate ones."""
        for key in self:
            vals = self._container[key.lower()]
            for val in vals[1:]:
                yield vals[0], val

    def itermerged(self) -> typing.Iterator[tuple[str, str]]:
        """Iterate over all headers, merging duplicate ones together."""
        for key in self:
            val = self._container[key.lower()]
            yield val[0], ", ".join(val[1:])

    def items(self) -> HTTPHeaderDictItemView:  # type: ignore[override]
        return HTTPHeaderDictItemView(self)

    def _has_value_for_header(self, header_name: str, potential_value: str) -> bool:
        if header_name in self:
            return potential_value in self._container[header_name.lower()][1:]
        return False

    def __ior__(self, other: object) -> HTTPHeaderDict:
        # Supports extending a header dict in-place using operator |=
        # combining items with add instead of __setitem__
        maybe_constructable = ensure_can_construct_http_header_dict(other)
        if maybe_constructable is None:
            return NotImplemented
        self.extend(maybe_constructable)
        return self

    def __or__(self, other: object) -> HTTPHeaderDict:
        # Supports merging header dicts using operator |
        # combining items with add instead of __setitem__
        maybe_constructable = ensure_can_construct_http_header_dict(other)
        if maybe_constructable is None:
            return NotImplemented
        result = self.copy()
        result.extend(maybe_constructable)
        return result

    def __ror__(self, other: object) -> HTTPHeaderDict:
        # Supports merging header dicts using operator | when other is on left side
        # combining items with add instead of __setitem__
        maybe_constructable = ensure_can_construct_http_header_dict(other)
        if maybe_constructable is None:
            return NotImplemented
        result = type(self)(maybe_constructable)
        result.extend(self)
        return result