jpayne@68: """ jpayne@68: Monkey patching of distutils. jpayne@68: """ jpayne@68: jpayne@68: from __future__ import annotations jpayne@68: jpayne@68: import inspect jpayne@68: import platform jpayne@68: import sys jpayne@68: import types jpayne@68: from typing import Type, TypeVar, cast, overload jpayne@68: jpayne@68: import distutils.filelist jpayne@68: jpayne@68: _T = TypeVar("_T") jpayne@68: _UnpatchT = TypeVar("_UnpatchT", type, types.FunctionType) jpayne@68: jpayne@68: jpayne@68: __all__: list[str] = [] jpayne@68: """ jpayne@68: Everything is private. Contact the project team jpayne@68: if you think you need this functionality. jpayne@68: """ jpayne@68: jpayne@68: jpayne@68: def _get_mro(cls): jpayne@68: """ jpayne@68: Returns the bases classes for cls sorted by the MRO. jpayne@68: jpayne@68: Works around an issue on Jython where inspect.getmro will not return all jpayne@68: base classes if multiple classes share the same name. Instead, this jpayne@68: function will return a tuple containing the class itself, and the contents jpayne@68: of cls.__bases__. See https://github.com/pypa/setuptools/issues/1024. jpayne@68: """ jpayne@68: if platform.python_implementation() == "Jython": jpayne@68: return (cls,) + cls.__bases__ jpayne@68: return inspect.getmro(cls) jpayne@68: jpayne@68: jpayne@68: @overload jpayne@68: def get_unpatched(item: _UnpatchT) -> _UnpatchT: ... jpayne@68: @overload jpayne@68: def get_unpatched(item: object) -> None: ... 
def get_unpatched(
    item: type | types.FunctionType | object,
) -> type | types.FunctionType | None:
    """
    Return the original (pre-setuptools) version of ``item``.

    Classes are resolved with :func:`get_unpatched_class` and plain Python
    functions with :func:`get_unpatched_function`. Any other object yields
    ``None``.
    """
    if isinstance(item, type):
        return get_unpatched_class(item)
    if isinstance(item, types.FunctionType):
        return get_unpatched_function(item)
    return None


def get_unpatched_class(cls: type[_T]) -> type[_T]:
    """Protect against re-patching the distutils if reloaded

    Also ensures that no other distutils extension monkeypatched the distutils
    first.

    Walks the MRO of ``cls`` and returns the first class whose module name
    does not start with ``setuptools``.

    :raises AssertionError: if that class's module name does not start with
        ``distutils``, i.e. something other than distutils supplied it.
    """
    # ``klass`` (not ``cls``) so the generator does not shadow the parameter
    # that is reported in the error message below.
    external_bases = (
        cast(Type[_T], klass)
        for klass in _get_mro(cls)
        if not klass.__module__.startswith('setuptools')
    )
    base = next(external_bases)
    if not base.__module__.startswith('distutils'):
        raise AssertionError(f"distutils has already been patched by {cls!r}")
    return base


def patch_all():
    """
    Install the setuptools replacements into the distutils modules.

    Replaces ``distutils.core.Command``, the ``Distribution`` class in
    ``distutils.dist``, ``distutils.core`` and ``distutils.cmd``, and the
    ``Extension`` class in ``distutils.core``, ``distutils.extension`` and
    (if already imported) ``distutils.command.build_ext``. It also upgrades
    ``DistributionMetadata`` via :func:`_patch_distribution_metadata`.
    """
    import setuptools

    # we can't patch distutils.cmd, alas
    distutils.core.Command = setuptools.Command

    _patch_distribution_metadata()

    # Install Distribution throughout the distutils
    for module in distutils.dist, distutils.core, distutils.cmd:
        module.Distribution = setuptools.dist.Distribution

    # Install the patched Extension
    distutils.core.Extension = setuptools.extension.Extension
    distutils.extension.Extension = setuptools.extension.Extension
    # Only touch build_ext if it is already imported; this does not import it.
    # ``.get`` also skips a ``None`` placeholder entry in sys.modules.
    build_ext = sys.modules.get('distutils.command.build_ext')
    if build_ext is not None:
        build_ext.Extension = setuptools.extension.Extension


def _patch_distribution_metadata():
    """Patch ``distutils.dist.DistributionMetadata`` for higher metadata standards.

    Replaces ``write_pkg_info``, ``write_pkg_file``, ``read_pkg_file``,
    ``get_metadata_version`` and ``get_fullname`` on the class with the
    same-named functions from the sibling ``_core_metadata`` module.
    """
    from . import _core_metadata

    for attr in (
        'write_pkg_info',
        'write_pkg_file',
        'read_pkg_file',
        'get_metadata_version',
        'get_fullname',
    ):
        new_val = getattr(_core_metadata, attr)
        setattr(distutils.dist.DistributionMetadata, attr, new_val)


def patch_func(replacement, target_mod, func_name):
    """
    Patch func_name in target_mod with replacement

    Important - original must be resolved by name to avoid
    patching an already patched function.

    The function currently bound to ``func_name`` is recorded as
    ``replacement.unpatched``. Because ``setdefault`` is used, the first
    recorded original is kept if ``replacement`` is applied again.
    """
    original = getattr(target_mod, func_name)

    # set the 'unpatched' attribute on the replacement to
    # point to the original.
    vars(replacement).setdefault('unpatched', original)

    # replace the function in the original module
    setattr(target_mod, func_name, replacement)


def get_unpatched_function(candidate):
    """
    Return the original function recorded on ``candidate`` by :func:`patch_func`.

    :raises AttributeError: if ``candidate`` has no ``unpatched`` attribute
        (it was never installed via :func:`patch_func`).
    """
    return candidate.unpatched