"""Grab SRR numbers from Bioprojects and sub-bioprojects via Eutils"""

import csv
import logging
import os
import sys
from functools import cmp_to_key
from time import sleep
from xml.etree import ElementTree as xml

import requests

try:
    from itertools import batched
except ImportError:
    # Older interpreters have no itertools.batched; provide an equivalent generator.
    from itertools import islice

    def batched(iterable, n):
        """Batch data into tuples of length n. The last batch may be shorter.

        batched('ABCDEFG', 3) --> ABC DEF G

        Raises:
            ValueError: if ``n`` is smaller than one.
        """
        if n < 1:
            raise ValueError('n must be at least one')
        it = iter(iterable)
        while batch := tuple(islice(it, n)):
            yield batch


# E-utilities endpoints used throughout this script.
esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
esummary = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi"
elink = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/elink.fcgi"

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger("bio2srr")

# Extra query parameters merged into every E-utilities request (the API key, when set).
extra_params = {}

api_key = os.environ.get("NCBI_API_KEY")


def _mask_key(key):
    """Return ``key`` with everything but its first and last four characters starred out.

    Keys of eight characters or fewer are masked completely; showing the
    first and last four characters of such a key would reveal all of it.
    """
    if len(key) <= 8:
        return "*" * len(key)
    return f"{key[:4]}{'*' * (len(key) - 8)}{key[-4:]}"


if api_key:
    logger.info(f"Using NCBI API key {_mask_key(api_key)}")
    extra_params["api_key"] = api_key


def log(msg):
    """Log ``msg`` at INFO level, masking the API key if it appears in the text."""
    if api_key:
        msg = msg.replace(api_key, _mask_key(api_key))
    logger.info(msg)


def get_tag(root, tag):
    """Return the text of the first element matching ``tag`` under ``root``.

    Args:
        root: an ElementTree element to search.
        tag: an ElementTree path expression.

    Returns:
        The element's text, or ``None`` (after logging a message) when
        nothing matches.
    """
    val = root.find(tag)
    if val is not None:
        return val.text
    log(f"No result for {tag}")
    return None


# Columns that are sorted to the front of the output table, in this order.
# "sra_run_accession" is the key the run rows actually carry; "srr_accession"
# is kept so that callers using the older name still get it prioritised.
_PRIORITY_COLUMNS = (
    "bioproject",
    "sra_run_accession",
    "srr_accession",
    "biosample_accession",
    "organism",
    "taxid",
    "package",
)
_COLUMN_RANK = {name: rank for rank, name in enumerate(_PRIORITY_COLUMNS)}


def header_sort_override(a, b):
    """Three-way comparison of two column names for ordering the table header.

    Priority columns come first in their listed order; all remaining
    columns are ordered by plain ``<`` comparison.

    Returns:
        ``0`` if the names are equal, ``-1`` if ``a`` sorts first, else ``1``.
    """
    if a == b:
        return 0
    unranked = len(_COLUMN_RANK)
    rank_a = _COLUMN_RANK.get(a, unranked)
    rank_b = _COLUMN_RANK.get(b, unranked)
    if rank_a != rank_b:
        return -1 if rank_a < rank_b else 1
    return -1 if a < b else 1


hso = cmp_to_key(header_sort_override)


def _pause():
    """Sleep between requests: 1 s without an API key, 0.1 s with one."""
    sleep(1 if not api_key else 0.1)


def _get_json(url, **params):
    """GET ``url`` as JSON with the shared extra parameters, then pause.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    response = requests.get(url, params=dict(params, format="json", **extra_params))
    response.raise_for_status()
    payload = response.json()
    _pause()
    return payload


def _linksetdbs(reply):
    """Return the ``linksetdbs`` list of the first linkset in an elink reply.

    Returns an empty list when the reply has no linksets or no linksetdbs,
    so callers never index into a missing or empty list.
    """
    linksets = reply.get("linksets") or [{}]
    return linksets[0].get("linksetdbs") or []


def resolve_bioproject_ids_and_links(bioproject_id_list):
    """Follow bioproject and biosample links, yielding biosample summaries.

    Args:
        bioproject_id_list: list of ``(accession, uid)`` tuples to start
            from. The list is used as a work queue and is extended in place
            with every linked bioproject that is discovered.

    Yields:
        ``(bioproject_accession, biosample_uid, biosample_xml)`` tuples, where
        the last item is the ``sampledata`` field of the esummary result.

    Raises:
        requests.HTTPError: if any request returns an error status.
    """
    # UIDs already queued. The queue holds tuples, so membership has to be
    # checked against the UIDs themselves to avoid re-queuing a project.
    seen_ids = {uid for _, uid in bioproject_id_list}
    # The list grows while it is iterated; that is what drives the traversal.
    for i, (bioproject, bioproject_id) in enumerate(bioproject_id_list):
        log(f"Processing {bioproject} ({bioproject_id}) {i+1}/{len(bioproject_id_list)}")
        # bioproject -> bioproject links
        project_dbs = _linksetdbs(
            _get_json(elink, db="bioproject", dbfrom="bioproject", id=bioproject_id)
        )
        # Original author's note: the third entry holds the up-to-down links.
        child_ids = project_dbs[2].get("links", []) if len(project_dbs) >= 3 else []
        for child_id in child_ids:
            if child_id in seen_ids:
                continue
            summary = _get_json(esummary, id=child_id, db="bioproject")
            child_accession = summary["result"][child_id]["project_acc"]
            seen_ids.add(child_id)
            bioproject_id_list.append((child_accession, child_id))
        # bioproject -> biosample links
        sample_dbs = _linksetdbs(
            _get_json(elink, db="biosample", dbfrom="bioproject", id=bioproject_id)
        )
        links = sample_dbs[0].get("links", []) if sample_dbs else []
        log(f"Found {len(links)} biosample links for {bioproject} ({bioproject_id})")
        for ids in batched(links, 200):
            summary = _get_json(esummary, id=",".join(ids), db="biosample")
            for field, value in summary.get("result", {}).items():
                if "uids" in field:
                    # Skip the UID index entry; the other entries are samples.
                    continue
                yield bioproject, field, value["sampledata"]  # this is XML


# Unused sample payload kept for reference.
# NOTE(review): only element text is present here - the XML markup appears to
# have been lost; restore it from a real biosample esummary response.
biosample_example = """


SAMN17131268
CJP19-D996


Pathogen: environmental/food/other sample from Campylobacter jejuni

Campylobacter jejuni



FDA Center for Food Safety and Applied Nutrition


Pathogen.env

Pathogen.env.1.0

CJP19-D996
missing
missing
CDC
missing
missing
CFSAN091032
GenomeTrakr
FDA Center for Food Safety and Applied Nutrition


681235




"""


def flatten_biosample_xml(biosampxml):
    """Flatten one biosample XML document into a single-level dict.

    Args:
        biosampxml: XML text of one biosample record.

    Returns:
        A dict with ``biosample_accession``, ``organism``, ``taxid`` and
        ``package`` (each ``None`` when absent), plus one entry per
        ``Attributes/Attribute`` element keyed by its ``harmonized_name``,
        or by ``attribute_name`` when there is no harmonized name.

    Raises:
        KeyError: if an attribute element has neither name.
        xml.etree.ElementTree.ParseError: if the text is not well-formed XML.
    """
    root = xml.fromstring(biosampxml)
    organism_node = root.find(r".//Organism")
    sampledict = dict(
        biosample_accession=get_tag(root, r'.//Id[@db="BioSample"]'),
        organism=get_tag(root, r".//OrganismName"),
        # A record without an Organism element simply has no taxid.
        taxid=organism_node.attrib.get("taxonomy_id") if organism_node is not None else None,
        package=get_tag(root, r".//Package"),
    )
    for attribute in root.findall("Attributes/Attribute"):
        attrs = attribute.attrib
        # Only fall back to (and require) attribute_name when there is no
        # harmonized name.
        if "harmonized_name" in attrs:
            name = attrs["harmonized_name"]
        else:
            name = attrs["attribute_name"]
        sampledict[name] = attribute.text

    return sampledict


def yield_sra_runs_from_sample(biosample):
    """Yield the SRA records linked to one biosample.

    Args:
        biosample: biosample UID to look up.

    Yields:
        ``(sra_uid, runs)`` tuples, where ``runs`` is the ``runs`` field of
        the SRA esummary result (``None`` when the field is absent).

    Raises:
        requests.HTTPError: if any request returns an error status.
    """
    sleep(1 if not api_key else 0.1)
    response = requests.get(elink, params=dict(id=biosample, dbfrom="biosample", db="sra", format="json", **extra_params))
    response.raise_for_status()
    reply = response.json()
    # Fall back to empty placeholders so a reply with missing or empty
    # linksets / linksetdbs yields nothing instead of raising IndexError.
    linksets = reply.get("linksets") or [{}]
    linksetdbs = linksets[0].get("linksetdbs") or [{}]
    sra_ids = linksetdbs[0].get("links", [])
    for ids in batched(sra_ids, 200):
        sleep(1 if not api_key else 0.1)
        response = requests.get(esummary, params=dict(id=','.join(ids), db="sra", format="json", **extra_params))
        response.raise_for_status()
        for field, value in response.json().get("result", {}).items():
            if "uids" not in field:
                yield field, value.get("runs")


# Unused placeholder kept for reference.
# NOTE(review): the sample content appears to have been lost - confirm.
runs_example = """


"""


def flatten_runs(runxml):
    """Yield one dict per public ``Run`` element found in ``runxml``.

    Args:
        runxml: XML fragment holding zero or more ``Run`` elements. It does
            not need a single root element.

    Yields:
        Dicts with ``sra_run_accession``, ``total_spots`` and ``total_bases``.
        Runs whose ``is_public`` attribute is ``"false"`` are logged and
        skipped.

    Raises:
        xml.etree.ElementTree.ParseError: if the fragment is not well-formed.
        KeyError: if a run lacks one of the attributes that are read.
    """
    # The fragment is not singly-rooted, so wrap it in a synthetic root
    # element before handing it to the parser.
    root = xml.fromstring(f"<data>{runxml}</data>")
    for run in root.findall(".//Run"):
        if run.attrib.get("is_public") == "false":
            logger.warning(f"Skipping non-public run {run.attrib['acc']}")
            continue
        yield dict(
            sra_run_accession=run.attrib["acc"],
            total_spots=run.attrib["total_spots"],
            total_bases=run.attrib["total_bases"],
        )


def main(starting_bioproject):
    """Collect run metadata for a bioproject and write the output files.

    Writes ``metadata.tsv`` (one row per run, or one row per sample that has
    no runs) and ``accessions.txt`` (unique run accessions, one per line) to
    the current directory.

    Args:
        starting_bioproject: bioproject accession to search for.

    Raises:
        requests.HTTPError: if any request returns an error status.
        SystemExit: if the search returns no bioproject.
    """
    response = requests.get(esearch, params=dict(db="bioproject", term=starting_bioproject, field="PRJA", format="json", **extra_params))
    response.raise_for_status()
    search_result = response.json().get("esearchresult", {})
    idlist = search_result.get("idlist") or []
    if not idlist:
        # Look the message up defensively so the error path cannot itself fail.
        messages = search_result.get("warninglist", {}).get("outputmessages")
        logger.error(f"No results found for '{starting_bioproject}'. Error was \"{messages}\"")
        sys.exit(1)
    bioproject_id = idlist[0]
    log(f"Found UID {bioproject_id} for '{starting_bioproject}'")
    sleep(1 if not api_key else 0.1)

    rows = []
    for bioproject, biosample, biosample_xml in resolve_bioproject_ids_and_links([(starting_bioproject, bioproject_id)]):
        try:
            sampledict = flatten_biosample_xml(biosample_xml)
        except KeyError:
            # Dump the offending record before propagating the error.
            log(biosample_xml)
            raise
        sampledict["bioproject"] = bioproject
        noruns = True
        for _sra_uid, runs in yield_sra_runs_from_sample(biosample):
            # The runs field may be missing (None); treat that as no runs.
            for run in flatten_runs((runs or "").strip()):
                noruns = False
                run.update(sampledict)
                rows.append(run)
        if noruns:
            # Keep samples without usable runs in the metadata table.
            rows.append(sampledict)

    log(f"Writing {len(rows)} rows to metadata.tsv")

    header = sorted({key for row in rows for key in row}, key=hso)

    # A missing accession sorts first instead of breaking the comparison.
    rows.sort(key=lambda row: row.get("biosample_accession") or "")

    # newline="" lets the csv module control line endings itself.
    with open("metadata.tsv", "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=header, delimiter="\t", dialect="excel")
        writer.writeheader()
        writer.writerows(rows)

    # check for duplicate runs and unreleased samples
    accessions = [row["sra_run_accession"] for row in rows if row.get("sra_run_accession")]
    raw_length = len(accessions)
    accessions = sorted(set(accessions))

    if raw_length < len(rows):
        logger.warning(f"Bioproject {starting_bioproject} contains unreleased samples. {len(rows) - raw_length} samples will not be included in accessions.txt")

    if len(accessions) < raw_length:
        logger.warning(f"Some SRA runs may have been reached through multiple projects or samples. accessions.txt will be deduplicated but the metadata table is not")

    log(f"Writing {len(accessions)} unique accessions to accessions.txt")

    with open("accessions.txt", "w") as f:
        # writelines() adds no separators, so terminate each accession here.
        f.writelines(f"{accession}\n" for accession in accessions)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        logger.error(f"Usage: {sys.argv[0]} <bioproject accession>")
        sys.exit(1)
    b = sys.argv[1].strip()
    log(f"Starting with {b}")
    try:
        main(b)
    except requests.HTTPError as e:
        logger.error(e)
        sys.exit(1)