"""
Monkey patching of distutils.

Helpers that swap setuptools implementations into distutils and that
recover the original (unpatched) objects afterwards.
"""

from __future__ import annotations

import inspect
import platform
import sys
import types
from typing import Type, TypeVar, cast, overload

import distutils.filelist

_T = TypeVar("_T")
_UnpatchT = TypeVar("_UnpatchT", type, types.FunctionType)


# Everything is private. Contact the project team
# if you think you need this functionality.
__all__: list[str] = []


def _get_mro(cls):
    """
    Return *cls* followed by its base classes, in MRO order.

    On CPython (and any non-Jython implementation) this is simply
    ``inspect.getmro(cls)``. Jython's ``inspect.getmro`` can drop base
    classes that share a name, so there the result is instead the class
    itself followed by the contents of ``cls.__bases__``.
    See https://github.com/pypa/setuptools/issues/1024.
    """
    if platform.python_implementation() != "Jython":
        return inspect.getmro(cls)
    # Jython workaround: only the class and its direct bases are reported.
    return (cls, *cls.__bases__)


@overload
def get_unpatched(item: _UnpatchT) -> _UnpatchT: ...
@overload
def get_unpatched(item: object) -> None: ...
def get_unpatched(
    item: type | types.FunctionType | object,
) -> type | types.FunctionType | None:
    """Return the original object that *item* replaced, if any.

    Classes are resolved with ``get_unpatched_class`` and plain Python
    functions with ``get_unpatched_function``; any other object yields
    ``None``.
    """
    if isinstance(item, type):
        return get_unpatched_class(item)
    if isinstance(item, types.FunctionType):
        return get_unpatched_function(item)
    return None


def get_unpatched_class(cls: type[_T]) -> type[_T]:
    """Return the distutils class that sits beneath *cls*.

    Protects against re-patching distutils if this module is reloaded, and
    ensures that no other distutils extension monkeypatched distutils first.

    The MRO of *cls* (as reported by ``_get_mro``) is walked, classes from
    ``setuptools*`` modules are skipped, and the first remaining class is
    returned.

    Raises:
        AssertionError: if that first non-setuptools class does not come
            from a ``distutils*`` module, or if no such class exists.
    """
    # Use a distinct loop name so the ``cls`` in the error message below is
    # unambiguously the argument that was passed in.
    external_bases = (
        cast(Type[_T], klass)
        for klass in _get_mro(cls)
        if not klass.__module__.startswith('setuptools')
    )
    # With a None default, an MRO made up only of setuptools classes (possible
    # on the Jython path, which lists only direct bases) is reported as an
    # AssertionError instead of leaking a bare StopIteration.
    base = next(external_bases, None)
    if base is None or not base.__module__.startswith('distutils'):
        msg = "distutils has already been patched by %r" % cls
        raise AssertionError(msg)
    return base


def patch_all():
    """Install setuptools' replacements into the distutils modules.

    Replaces ``Command`` in ``distutils.core``, ``Distribution`` in
    ``distutils.dist``/``core``/``cmd`` and ``Extension`` in
    ``distutils.core``/``extension`` (and in ``distutils.command.build_ext``
    when that module is already loaded). It also upgrades
    ``DistributionMetadata`` via ``_patch_distribution_metadata``.
    """
    import setuptools

    # The module level only imports ``distutils.filelist``; import the
    # submodules patched below explicitly instead of relying on them having
    # been imported elsewhere first.
    import distutils.cmd
    import distutils.core
    import distutils.dist
    import distutils.extension

    # we can't patch distutils.cmd, alas
    distutils.core.Command = setuptools.Command

    _patch_distribution_metadata()

    # Install Distribution throughout the distutils
    for module in distutils.dist, distutils.core, distutils.cmd:
        module.Distribution = setuptools.dist.Distribution

    # Install the patched Extension
    distutils.core.Extension = setuptools.extension.Extension
    distutils.extension.Extension = setuptools.extension.Extension
    # build_ext is only patched if something has already imported it; it is
    # deliberately not imported here.
    build_ext = sys.modules.get('distutils.command.build_ext')
    if build_ext is not None:
        build_ext.Extension = setuptools.extension.Extension


def _patch_distribution_metadata():
    """Patch DistributionMetadata for higher metadata standards.

    Replaces ``write_pkg_info``, ``write_pkg_file``, ``read_pkg_file``,
    ``get_metadata_version`` and ``get_fullname`` on
    ``distutils.dist.DistributionMetadata`` with the same-named functions
    from the sibling ``_core_metadata`` module.
    """
    from . import _core_metadata

    for attr in (
        'write_pkg_info',
        'write_pkg_file',
        'read_pkg_file',
        'get_metadata_version',
        'get_fullname',
    ):
        new_val = getattr(_core_metadata, attr)
        setattr(distutils.dist.DistributionMetadata, attr, new_val)


def patch_func(replacement, target_mod, func_name):
    """
    Patch func_name in target_mod with replacement

    Important - original must be resolved by name to avoid
    patching an already patched function.

    The original is recorded as ``replacement.unpatched`` so that
    ``get_unpatched_function`` can retrieve it later.
    """
    original = getattr(target_mod, func_name)

    # set the 'unpatched' attribute on the replacement to
    # point to the original. setdefault keeps any value that is already
    # recorded, so patching again with the same replacement does not
    # overwrite the first original.
    vars(replacement).setdefault('unpatched', original)

    # replace the function in the original module
    setattr(target_mod, func_name, replacement)


def get_unpatched_function(candidate):
    """Return the original function recorded on *candidate* by ``patch_func``.

    Raises:
        AttributeError: if *candidate* has no ``unpatched`` attribute.
    """
    return candidate.unpatched