jpayne@68: from __future__ import annotations jpayne@68: jpayne@68: import csv jpayne@68: import hashlib jpayne@68: import os.path jpayne@68: import re jpayne@68: import stat jpayne@68: import time jpayne@68: from io import StringIO, TextIOWrapper jpayne@68: from typing import IO, TYPE_CHECKING, Literal jpayne@68: from zipfile import ZIP_DEFLATED, ZipFile, ZipInfo jpayne@68: jpayne@68: from wheel.cli import WheelError jpayne@68: from wheel.util import log, urlsafe_b64decode, urlsafe_b64encode jpayne@68: jpayne@68: if TYPE_CHECKING: jpayne@68: from typing import Protocol, Sized, Union jpayne@68: jpayne@68: from typing_extensions import Buffer jpayne@68: jpayne@68: StrPath = Union[str, os.PathLike[str]] jpayne@68: jpayne@68: class SizedBuffer(Sized, Buffer, Protocol): ... jpayne@68: jpayne@68: jpayne@68: # Non-greedy matching of an optional build number may be too clever (more jpayne@68: # invalid wheel filenames will match). Separate regex for .dist-info? jpayne@68: WHEEL_INFO_RE = re.compile( jpayne@68: r"""^(?P(?P[^\s-]+?)-(?P[^\s-]+?))(-(?P\d[^\s-]*))? jpayne@68: -(?P[^\s-]+?)-(?P[^\s-]+?)-(?P\S+)\.whl$""", jpayne@68: re.VERBOSE, jpayne@68: ) jpayne@68: MINIMUM_TIMESTAMP = 315532800 # 1980-01-01 00:00:00 UTC jpayne@68: jpayne@68: jpayne@68: def get_zipinfo_datetime(timestamp: float | None = None): jpayne@68: # Some applications need reproducible .whl files, but they can't do this without jpayne@68: # forcing the timestamp of the individual ZipInfo objects. See issue #143. jpayne@68: timestamp = int(os.environ.get("SOURCE_DATE_EPOCH", timestamp or time.time())) jpayne@68: timestamp = max(timestamp, MINIMUM_TIMESTAMP) jpayne@68: return time.gmtime(timestamp)[0:6] jpayne@68: jpayne@68: jpayne@68: class WheelFile(ZipFile): jpayne@68: """A ZipFile derivative class that also reads SHA-256 hashes from jpayne@68: .dist-info/RECORD and checks any read files against those. jpayne@68: """ jpayne@68: jpayne@68: _default_algorithm = hashlib.sha256 jpayne@68: jpayne@68: def __init__( jpayne@68: self, jpayne@68: file: StrPath, jpayne@68: mode: Literal["r", "w", "x", "a"] = "r", jpayne@68: compression: int = ZIP_DEFLATED, jpayne@68: ): jpayne@68: basename = os.path.basename(file) jpayne@68: self.parsed_filename = WHEEL_INFO_RE.match(basename) jpayne@68: if not basename.endswith(".whl") or self.parsed_filename is None: jpayne@68: raise WheelError(f"Bad wheel filename {basename!r}") jpayne@68: jpayne@68: ZipFile.__init__(self, file, mode, compression=compression, allowZip64=True) jpayne@68: jpayne@68: self.dist_info_path = "{}.dist-info".format( jpayne@68: self.parsed_filename.group("namever") jpayne@68: ) jpayne@68: self.record_path = self.dist_info_path + "/RECORD" jpayne@68: self._file_hashes: dict[str, tuple[None, None] | tuple[int, bytes]] = {} jpayne@68: self._file_sizes = {} jpayne@68: if mode == "r": jpayne@68: # Ignore RECORD and any embedded wheel signatures jpayne@68: self._file_hashes[self.record_path] = None, None jpayne@68: self._file_hashes[self.record_path + ".jws"] = None, None jpayne@68: self._file_hashes[self.record_path + ".p7s"] = None, None jpayne@68: jpayne@68: # Fill in the expected hashes by reading them from RECORD jpayne@68: try: jpayne@68: record = self.open(self.record_path) jpayne@68: except KeyError: jpayne@68: raise WheelError(f"Missing {self.record_path} file") from None jpayne@68: jpayne@68: with record: jpayne@68: for line in csv.reader( jpayne@68: TextIOWrapper(record, newline="", encoding="utf-8") jpayne@68: ): jpayne@68: path, hash_sum, size = line jpayne@68: if not hash_sum: jpayne@68: continue jpayne@68: jpayne@68: algorithm, hash_sum = hash_sum.split("=") jpayne@68: try: jpayne@68: hashlib.new(algorithm) jpayne@68: except ValueError: jpayne@68: raise WheelError( jpayne@68: f"Unsupported hash algorithm: {algorithm}" jpayne@68: ) from None jpayne@68: jpayne@68: if algorithm.lower() in {"md5", "sha1"}: jpayne@68: raise WheelError( jpayne@68: f"Weak hash algorithm ({algorithm}) is not permitted by " jpayne@68: f"PEP 427" jpayne@68: ) jpayne@68: jpayne@68: self._file_hashes[path] = ( jpayne@68: algorithm, jpayne@68: urlsafe_b64decode(hash_sum.encode("ascii")), jpayne@68: ) jpayne@68: jpayne@68: def open( jpayne@68: self, jpayne@68: name_or_info: str | ZipInfo, jpayne@68: mode: Literal["r", "w"] = "r", jpayne@68: pwd: bytes | None = None, jpayne@68: ) -> IO[bytes]: jpayne@68: def _update_crc(newdata: bytes) -> None: jpayne@68: eof = ef._eof jpayne@68: update_crc_orig(newdata) jpayne@68: running_hash.update(newdata) jpayne@68: if eof and running_hash.digest() != expected_hash: jpayne@68: raise WheelError(f"Hash mismatch for file '{ef_name}'") jpayne@68: jpayne@68: ef_name = ( jpayne@68: name_or_info.filename if isinstance(name_or_info, ZipInfo) else name_or_info jpayne@68: ) jpayne@68: if ( jpayne@68: mode == "r" jpayne@68: and not ef_name.endswith("/") jpayne@68: and ef_name not in self._file_hashes jpayne@68: ): jpayne@68: raise WheelError(f"No hash found for file '{ef_name}'") jpayne@68: jpayne@68: ef = ZipFile.open(self, name_or_info, mode, pwd) jpayne@68: if mode == "r" and not ef_name.endswith("/"): jpayne@68: algorithm, expected_hash = self._file_hashes[ef_name] jpayne@68: if expected_hash is not None: jpayne@68: # Monkey patch the _update_crc method to also check for the hash from jpayne@68: # RECORD jpayne@68: running_hash = hashlib.new(algorithm) jpayne@68: update_crc_orig, ef._update_crc = ef._update_crc, _update_crc jpayne@68: jpayne@68: return ef jpayne@68: jpayne@68: def write_files(self, base_dir: str): jpayne@68: log.info(f"creating '{self.filename}' and adding '{base_dir}' to it") jpayne@68: deferred: list[tuple[str, str]] = [] jpayne@68: for root, dirnames, filenames in os.walk(base_dir): jpayne@68: # Sort the directory names so that `os.walk` will walk them in a jpayne@68: # defined order on the next iteration. jpayne@68: dirnames.sort() jpayne@68: for name in sorted(filenames): jpayne@68: path = os.path.normpath(os.path.join(root, name)) jpayne@68: if os.path.isfile(path): jpayne@68: arcname = os.path.relpath(path, base_dir).replace(os.path.sep, "/") jpayne@68: if arcname == self.record_path: jpayne@68: pass jpayne@68: elif root.endswith(".dist-info"): jpayne@68: deferred.append((path, arcname)) jpayne@68: else: jpayne@68: self.write(path, arcname) jpayne@68: jpayne@68: deferred.sort() jpayne@68: for path, arcname in deferred: jpayne@68: self.write(path, arcname) jpayne@68: jpayne@68: def write( jpayne@68: self, jpayne@68: filename: str, jpayne@68: arcname: str | None = None, jpayne@68: compress_type: int | None = None, jpayne@68: ) -> None: jpayne@68: with open(filename, "rb") as f: jpayne@68: st = os.fstat(f.fileno()) jpayne@68: data = f.read() jpayne@68: jpayne@68: zinfo = ZipInfo( jpayne@68: arcname or filename, date_time=get_zipinfo_datetime(st.st_mtime) jpayne@68: ) jpayne@68: zinfo.external_attr = (stat.S_IMODE(st.st_mode) | stat.S_IFMT(st.st_mode)) << 16 jpayne@68: zinfo.compress_type = compress_type or self.compression jpayne@68: self.writestr(zinfo, data, compress_type) jpayne@68: jpayne@68: def writestr( jpayne@68: self, jpayne@68: zinfo_or_arcname: str | ZipInfo, jpayne@68: data: SizedBuffer | str, jpayne@68: compress_type: int | None = None, jpayne@68: ): jpayne@68: if isinstance(zinfo_or_arcname, str): jpayne@68: zinfo_or_arcname = ZipInfo( jpayne@68: zinfo_or_arcname, date_time=get_zipinfo_datetime() jpayne@68: ) jpayne@68: zinfo_or_arcname.compress_type = self.compression jpayne@68: zinfo_or_arcname.external_attr = (0o664 | stat.S_IFREG) << 16 jpayne@68: jpayne@68: if isinstance(data, str): jpayne@68: data = data.encode("utf-8") jpayne@68: jpayne@68: ZipFile.writestr(self, zinfo_or_arcname, data, compress_type) jpayne@68: fname = ( jpayne@68: zinfo_or_arcname.filename jpayne@68: if isinstance(zinfo_or_arcname, ZipInfo) jpayne@68: else zinfo_or_arcname jpayne@68: ) jpayne@68: log.info(f"adding '{fname}'") jpayne@68: if fname != self.record_path: jpayne@68: hash_ = self._default_algorithm(data) jpayne@68: self._file_hashes[fname] = ( jpayne@68: hash_.name, jpayne@68: urlsafe_b64encode(hash_.digest()).decode("ascii"), jpayne@68: ) jpayne@68: self._file_sizes[fname] = len(data) jpayne@68: jpayne@68: def close(self): jpayne@68: # Write RECORD jpayne@68: if self.fp is not None and self.mode == "w" and self._file_hashes: jpayne@68: data = StringIO() jpayne@68: writer = csv.writer(data, delimiter=",", quotechar='"', lineterminator="\n") jpayne@68: writer.writerows( jpayne@68: ( jpayne@68: (fname, algorithm + "=" + hash_, self._file_sizes[fname]) jpayne@68: for fname, (algorithm, hash_) in self._file_hashes.items() jpayne@68: ) jpayne@68: ) jpayne@68: writer.writerow((format(self.record_path), "", "")) jpayne@68: self.writestr(self.record_path, data.getvalue()) jpayne@68: jpayne@68: ZipFile.close(self)