"""Downsample single- or paired-end FASTQ files to a target coverage.

Reads are picked by reservoir sampling so that the whole input never has to
be held in memory: the reservoir is first filled with as many reads as are
needed to reach the requested coverage, then every later read replaces a
random reservoir slot with the appropriate probability.

The output files are written with the same (gzip / bzip2 / plain) opener
that was detected for the corresponding input file.
"""
from bz2 import open as bzopen
from gzip import open as gzopen

from contextlib import ExitStack
from itertools import zip_longest
from pathlib import Path
from sys import argv

import random
import sys


usage = """
Usage:
    paired-end: forward_in reverse_in forward_out reverse_out coverage genome_size [seed]
    single-end: reads_in reads_out coverage genome_size [seed]
"""

# Seed used when none is given on the command line, so runs are reproducible.
DEFAULT_SEED = "ed2b99d842cddc1ac81d7c01a0bf0555"


def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
    # The same iterator repeated n times makes zip_longest pull n consecutive
    # items into each tuple.
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)


# file compression signatures
magics = {
    b'\x1f\x8b\x08': gzopen,
    b'\x42\x5a\x68': bzopen,
}


def sniff(path):
    "Sniff first three bytes of the file to determine format based on the magic number."
    with open(path, 'rb') as fp:
        magic = fp.read(3)
    # Anything without a known signature is treated as a plain file.
    return magics.get(magic, open)


def _read_length(record):
    """Return the number of bases in a 4-line FASTQ record.

    The sequence is the record's second line. Its line terminator is not a
    base, so it is stripped before measuring. ``rstrip`` works for both
    ``str`` lines (plain files) and ``bytes`` lines (gzip/bz2 opened with
    mode ``'r'`` are binary).
    """
    return len(record[1].rstrip())


def coverage(collection, genome_size):
    "Collection of 1 or 2 tuples, whose 2nd item is the read string"
    # Only the first (forward) read of each entry is counted, as in the
    # original implementation.
    return sum(_read_length(read[0]) for read in collection) / genome_size


def parse_args(args):
    """Interpret the command-line arguments.

    Six or more arguments select the paired-end layout, four or five the
    single-end layout; the first extra argument, if any, is the random seed.

    Returns:
        ``(ins, outs, cov, gen_size, seed)`` where ``ins``/``outs`` are lists
        of one or two paths, ``cov`` is a float and ``gen_size`` an int.

    Raises:
        ValueError: wrong number of arguments, non-numeric coverage or
            genome size, or a genome size that is not positive. The
            exception text is the message to show the user.
    """
    if len(args) >= 6:
        fin, rin, fout, rout, cov, gen_size, *opts = args
        ins, outs = [fin, rin], [fout, rout]
    elif len(args) >= 4:
        fin, fout, cov, gen_size, *opts = args
        ins, outs = [fin], [fout]
    else:
        raise ValueError(usage)

    try:
        cov = float(cov)
        gen_size = int(gen_size)
    except ValueError:
        raise ValueError(
            "Desired coverage and assumed genome size should be numbers\n" + usage
        ) from None
    if gen_size <= 0:
        # Coverage is computed as bases / gen_size; zero would divide by zero.
        raise ValueError(
            "Assumed genome size must be a positive integer\n" + usage
        )

    seed = opts[0] if opts else DEFAULT_SEED
    return ins, outs, cov, gen_size, seed


def fill_reservoir(readpairs, cov, gen_size):
    """Consume read pairs until their coverage exceeds ``cov``.

    ``readpairs`` must be an iterator; it is left positioned just after the
    last read taken so the caller can keep streaming from it. If the input
    runs out first, every read is returned.
    """
    reservoir = []
    bases = 0
    for readpair in readpairs:
        bases += _read_length(readpair[0])
        reservoir.append(readpair)
        if bases / gen_size > cov:
            break
    return reservoir


def update_reservoir(reservoir, readpairs, rng=random):
    """Stream the remaining read pairs through the reservoir (Algorithm R).

    With ``k = len(reservoir)``, the i-th read of the whole stream (0-based,
    so the first one seen here has index ``k``) replaces a uniformly chosen
    slot with probability ``k / (i + 1)``. Counting from ``k`` rather than 0
    is what keeps every read equally likely to end up in the sample.

    ``rng`` only needs a ``randint(a, b)`` method; it defaults to the
    ``random`` module. The reservoir is modified in place.
    """
    k = len(reservoir)
    for i, readpair in enumerate(readpairs, start=k):
        r = rng.randint(0, i)
        if r < k:
            reservoir[r] = readpair


def main(args=None):
    """Run the downsampler; ``args`` defaults to ``sys.argv[1:]``."""
    if args is None:
        args = argv[1:]
    try:
        ins, outs, cov, gen_size, seed = parse_args(args)
    except ValueError as err:
        print(err)
        sys.exit(1)

    random.seed(seed)

    file_openers = [sniff(path) for path in ins]  # output format determined by input format
    with ExitStack() as stack:
        in_files = [
            stack.enter_context(opener(path, 'r'))
            for opener, path in zip(file_openers, ins)
        ]
        # stateful 4-ply iterators over the lines of each input
        record_streams = [grouper(handle, 4) for handle in in_files]
        out_files = [
            stack.enter_context(opener(path, 'w'))
            for opener, path in zip(file_openers, outs)
        ]

        for handle in in_files:
            if hasattr(handle, "name"):
                print(handle.name)

        # https://en.m.wikipedia.org/wiki/Reservoir_sampling
        # Each item is a 1- or 2-tuple of 4-tuples (the 4 lines of a record).
        readpairs = zip(*record_streams)

        reservoir = fill_reservoir(readpairs, cov, gen_size)
        k = len(reservoir)  # about how big the reservoir must be for cov coverage
        random.shuffle(reservoir)

        print(f"{k} reads selected to achieve {coverage(reservoir, gen_size):.3f}X coverage.")

        # If there were too few reads to meet the coverage cutoff the
        # iterators are already exhausted and this does nothing.
        update_reservoir(reservoir, readpairs, random)

        for readpair in reservoir:  # output the sampled reads
            for record, handle in zip(readpair, out_files):
                # defline, sequence, spacer, quality - written in order
                handle.writelines(record)


if __name__ == "__main__":
    main()