jpayne@68
|
1 """
|
jpayne@68
|
2 Monkey patching of distutils.
|
jpayne@68
|
3 """
|
jpayne@68
|
4
|
jpayne@68
|
5 from __future__ import annotations
|
jpayne@68
|
6
|
jpayne@68
|
7 import inspect
|
jpayne@68
|
8 import platform
|
jpayne@68
|
9 import sys
|
jpayne@68
|
10 import types
|
jpayne@68
|
11 from typing import Type, TypeVar, cast, overload
|
jpayne@68
|
12
|
jpayne@68
|
13 import distutils.filelist
|
jpayne@68
|
14
|
jpayne@68
|
# Generic type variable used by get_unpatched_class: the class passed in and
# the class handed back are annotated with the same type.
_T = TypeVar("_T")
# Constrained type variable used by the first get_unpatched overload: the
# result has the same kind (a class or a plain function) as the argument.
_UnpatchT = TypeVar("_UnpatchT", type, types.FunctionType)


# Everything is private. Contact the project team
# if you think you need this functionality.
__all__: list[str] = []
|
jpayne@68
|
24
|
jpayne@68
|
25
|
jpayne@68
|
26 def _get_mro(cls):
|
jpayne@68
|
27 """
|
jpayne@68
|
28 Returns the bases classes for cls sorted by the MRO.
|
jpayne@68
|
29
|
jpayne@68
|
30 Works around an issue on Jython where inspect.getmro will not return all
|
jpayne@68
|
31 base classes if multiple classes share the same name. Instead, this
|
jpayne@68
|
32 function will return a tuple containing the class itself, and the contents
|
jpayne@68
|
33 of cls.__bases__. See https://github.com/pypa/setuptools/issues/1024.
|
jpayne@68
|
34 """
|
jpayne@68
|
35 if platform.python_implementation() == "Jython":
|
jpayne@68
|
36 return (cls,) + cls.__bases__
|
jpayne@68
|
37 return inspect.getmro(cls)
|
jpayne@68
|
38
|
jpayne@68
|
39
|
jpayne@68
|
@overload
def get_unpatched(item: _UnpatchT) -> _UnpatchT: ...
@overload
def get_unpatched(item: object) -> None: ...
def get_unpatched(
    item: type | types.FunctionType | object,
) -> type | types.FunctionType | None:
    """Return the original, pre-patch version of *item*.

    Plain functions are resolved through get_unpatched_function and classes
    through get_unpatched_class. Any other object yields None.
    """
    # A FunctionType instance is never a class, so the order of these two
    # checks does not matter.
    if isinstance(item, types.FunctionType):
        return get_unpatched_function(item)
    if isinstance(item, type):
        return get_unpatched_class(item)
    return None
|
jpayne@68
|
52
|
jpayne@68
|
53
|
jpayne@68
|
def get_unpatched_class(cls: type[_T]) -> type[_T]:
    """Protect against re-patching the distutils if reloaded.

    Also ensures that no other distutils extension monkeypatched the
    distutils first.

    Walks the MRO of *cls* (as returned by _get_mro) and skips every class
    whose ``__module__`` starts with ``'setuptools'``. The first remaining
    class is returned, provided its ``__module__`` starts with
    ``'distutils'``.

    Raises AssertionError if that class does not come from a distutils
    module. It is also raised if every class in the MRO comes from
    setuptools, so that no class is left to return.
    """
    # A distinct loop name (instead of re-using ``cls``) keeps it obvious
    # that the error message below reports the argument, not an MRO entry.
    base = next(
        (
            candidate
            for candidate in _get_mro(cls)
            if not candidate.__module__.startswith('setuptools')
        ),
        # On Jython _get_mro returns only (cls,) + cls.__bases__, which need
        # not include ``object``, so the filtered sequence may be empty.
        # Report that as a patching failure instead of leaking StopIteration.
        None,
    )
    if base is None or not base.__module__.startswith('distutils'):
        msg = "distutils has already been patched by %r" % cls
        raise AssertionError(msg)
    return cast(Type[_T], base)
|
jpayne@68
|
70
|
jpayne@68
|
71
|
jpayne@68
|
def patch_all():
    """Install setuptools' replacement classes into the distutils modules.

    Performs the following rebindings:

    - ``distutils.core.Command`` is rebound to ``setuptools.Command``.
    - The metadata methods of ``distutils.dist.DistributionMetadata`` are
      patched via _patch_distribution_metadata.
    - ``Distribution`` in distutils.dist, distutils.core and distutils.cmd is
      rebound to ``setuptools.dist.Distribution``.
    - ``Extension`` in distutils.core and distutils.extension is rebound to
      ``setuptools.extension.Extension``. The same happens in
      ``distutils.command.build_ext``, if that module has already been
      imported.
    """
    import setuptools

    # Note carried over from the original code: distutils.cmd cannot be
    # patched, so only the Command re-export in distutils.core is replaced.
    distutils.core.Command = setuptools.Command

    _patch_distribution_metadata()

    # Install Distribution throughout the distutils.
    dist_cls = setuptools.dist.Distribution
    for target in (distutils.dist, distutils.core, distutils.cmd):
        target.Distribution = dist_cls

    # Install the patched Extension.
    ext_cls = setuptools.extension.Extension
    distutils.core.Extension = ext_cls
    distutils.extension.Extension = ext_cls
    # build_ext is only touched when it is already present in sys.modules;
    # this function never imports it.
    build_ext_key = 'distutils.command.build_ext'
    if build_ext_key in sys.modules:
        sys.modules[build_ext_key].Extension = ext_cls
|
jpayne@68
|
91
|
jpayne@68
|
92
|
jpayne@68
|
93 def _patch_distribution_metadata():
|
jpayne@68
|
94 from . import _core_metadata
|
jpayne@68
|
95
|
jpayne@68
|
96 """Patch write_pkg_file and read_pkg_file for higher metadata standards"""
|
jpayne@68
|
97 for attr in (
|
jpayne@68
|
98 'write_pkg_info',
|
jpayne@68
|
99 'write_pkg_file',
|
jpayne@68
|
100 'read_pkg_file',
|
jpayne@68
|
101 'get_metadata_version',
|
jpayne@68
|
102 'get_fullname',
|
jpayne@68
|
103 ):
|
jpayne@68
|
104 new_val = getattr(_core_metadata, attr)
|
jpayne@68
|
105 setattr(distutils.dist.DistributionMetadata, attr, new_val)
|
jpayne@68
|
106
|
jpayne@68
|
107
|
jpayne@68
|
def patch_func(replacement, target_mod, func_name):
    """
    Replace the attribute func_name of target_mod with replacement.

    Before the swap, the object currently bound to that name is recorded as
    ``replacement.unpatched``, but only if replacement does not already carry
    that attribute in its own ``__dict__``. Because of this, a function that
    is patched into several places keeps pointing at the first original.

    Important - the original must be resolved by name (at call time) to avoid
    patching an already patched function.
    """
    current = getattr(target_mod, func_name)

    # Record the original only once; a later call must not overwrite it.
    own_attrs = vars(replacement)
    if 'unpatched' not in own_attrs:
        own_attrs['unpatched'] = current

    setattr(target_mod, func_name, replacement)
|
jpayne@68
|
123
|
jpayne@68
|
124
|
jpayne@68
|
def get_unpatched_function(candidate):
    """Return the original function recorded on *candidate* by patch_func.

    Reads ``candidate.unpatched``. If candidate was never passed through
    patch_func (and has no such attribute), AttributeError propagates.
    """
    original = candidate.unpatched
    return original
|