view lexmapr/ontology_reasoner.py @ 1:5244e7465767

"planemo upload"
author kkonganti
date Wed, 31 Aug 2022 14:32:14 -0400
parents f5c39d0447be
children
line wrap: on
line source
"""Ontology finder and visualizer"""

import copy, json, logging, requests, time
import pygraphviz as pgv

# Keep urllib3 (the HTTP library underneath `requests`, used for every OLS
# query in this module) from emitting log records below WARNING.
logging.getLogger(name='urllib3').setLevel('WARNING')


# TODO: figure out what to do with root Thing:Thing
class Ontology_accession:
    '''Base class for defining attributes and behavior of single ontology accessions.

    Accessions are strings of the form ``<label>:<PREFIX>_<local id>``. The label
    may contain whitespace, punctuation and even colons, because only the text
    after the last colon is treated as the id. Create instances through
    ``make_instance`` so that each accession string maps to one shared object.
    '''
    # Class-wide registry: accession string -> the single instance built for it.
    existing_ontologies = {}
    # Relative types that get_family fetches and caches on attributes of the same name.
    _FAMILY_MEMBERS = ('parents', 'children', 'ancestors', 'descendants')
    # Sentinel held by the family/graph attributes until they are computed.
    _UNASSIGNED = 'not assigned yet'

    @staticmethod
    def make_instance(acc):
        '''Return the shared instance for ``acc``, creating it on first use.

        Use this instead of the constructor so there is exactly one object per
        accession string; identity is then what set/dict membership relies on.
        '''
        try:
            return Ontology_accession.existing_ontologies[acc]
        except KeyError:
            new_instance = Ontology_accession(acc)
            Ontology_accession.existing_ontologies[acc] = new_instance
            return new_instance

    def __init__(self, acc):
        '''Parse ``acc`` into label, id and ontology prefix.

        When no label precedes the id (e.g. ``"PREFIX_123"``), the label is
        looked up in OLS (network access) and becomes 'unk' if not retrievable.
        '''
        def_split = acc.split(':')
        self.label = ':'.join(def_split[:-1])
        self.id = def_split[-1].replace('_', ':')
        # Filled lazily by get_family / visualize_term.
        self.parents = self._UNASSIGNED
        self.children = self._UNASSIGNED
        self.ancestors = self._UNASSIGNED
        self.descendants = self._UNASSIGNED
        self.graph_nodes = self._UNASSIGNED
        # fill_out / stop_terms that produced the cached graph_nodes.
        self.graph_fill = False
        self._graph_stop = False
        # Use the last colon-separated field: the label may itself contain ':'
        # and may be absent altogether.
        self.ontology = def_split[-1].split('_')[0]
        if self.label == '':
            self._get_label()

    def _api_results(self, input_list, return_list):
        '''Append an instance for every non-obsolete OLS term in ``input_list``.

        Obsolete terms are skipped; 'term_replaced_by' is not followed.
        Returns ``return_list`` (modified in place).
        '''
        for x_term in input_list:
            if x_term['is_obsolete']:
                continue
            new_term = x_term['label'] + ':' + x_term['short_form']
            return_list.append(Ontology_accession.make_instance(new_term))
        return return_list

    def _relative_edges(self, family_member):
        '''Return (relatives, edges) for already-fetched parents or children.

        Returns None when the relatives are ['none found'] or ``family_member``
        is neither 'parents' nor 'children'. Edges point from the upper term to
        the lower term: (parent, self) or (self, child).
        '''
        if family_member == 'parents':
            relatives = self.parents
        elif family_member == 'children':
            relatives = self.children
        else:
            return None
        if relatives == ['none found']:
            return None
        if len(relatives) > 5:
            time.sleep(0.05)  # brief pause, presumably to avoid hammering OLS
        own_label = self._graph_label()
        if family_member == 'parents':
            edges = [(y._graph_label(), own_label) for y in relatives]
        else:
            edges = [(own_label, y._graph_label()) for y in relatives]
        return relatives, edges

    def _add_edges(self, family_member, family_list, edge_set, round_num):
        '''Extend ``edge_set`` by up to ``round_num`` further generations.

        family_member: 'parents' (walk up) or 'children' (walk down).
        family_list: relatives whose own relatives are fetched this round.
        edge_set: (upper_label, lower_label) tuples so far; an empty list is
            returned unchanged.
        round_num: generations still to add; recursion stops at 0.
        Returns a new list; edges already present are not repeated.
        '''
        if edge_set == [] or round_num <= 0:
            return edge_set
        for x in family_list:
            if not isinstance(x, Ontology_accession):
                continue
            x.get_family(family_member)
            found = x._relative_edges(family_member)
            if found is None:
                continue
            relatives, new_edges = found
            edge_set = edge_set + [z for z in new_edges if z not in edge_set]
            edge_set = x._add_edges(family_member, relatives, edge_set, round_num - 1)
        return edge_set

    def _draw_graph(self, o_file, node_color, edge_color):
        '''Render ``graph_nodes`` to ``o_file`` with this term highlighted.'''
        ontol_graph = pgv.AGraph(name='ontology_graph')
        ontol_graph.add_node(self._graph_label())
        for x in self.graph_nodes:
            ontol_graph.add_edge(x[0], x[1])
        ontol_graph.node_attr.update(shape='box',
                                     style='rounded,filled',
                                     fillcolor='lightgrey',
                                     color=node_color)
        ontol_graph.edge_attr.update(shape='normal',
                                     color=edge_color,
                                     dir='back')
        ontol_graph.get_node(self._graph_label()).attr.update(fillcolor='lightblue')
        # TODO: determine best algorithm: neato, fdp, nop, twopi; tried circo; not dot, sfdp
        ontol_graph.draw(o_file, prog='twopi')

    def _expand_edge(self, family_member, family_list, edge_set, old_set='', stop_terms=False):
        '''Extend ``edge_set`` with every generation of relatives (no depth limit).

        Relatives listed in ``stop_terms`` (a list of instances) are not expanded
        any further, but their siblings still are, so the result does not depend
        on list order. ``old_set`` is unused and kept for call compatibility.
        Returns a new list; edges already present are not repeated.
        '''
        for x in family_list:
            if x == 'none found':
                break
            if isinstance(stop_terms, list) and x in stop_terms:
                continue
            x.get_family(family_member)
            found = x._relative_edges(family_member)
            if found is None:
                continue
            relatives, new_edges = found
            edge_set = edge_set + [z for z in new_edges if z not in edge_set]
            edge_set = x._expand_edge(family_member, relatives, edge_set, old_set, stop_terms)
        return edge_set

    def _get_label(self, max_retries=5):
        '''Fetch this id's label from OLS and store it in ``self.label``.

        ``self.label`` becomes 'unk' (with a warning) when the request fails, the
        response has no label, or the body is still not JSON after
        ``max_retries`` attempts.
        '''
        query_url = 'http://www.ebi.ac.uk/ols/api/terms?obo_id={}'.format(self.id)
        for _ in range(max(max_retries, 1)):
            ols_resp = self._get_request(query_url)
            if ols_resp is None:
                logging.warning(f'Did not retrieve PURL for {self.id}')
                self.label = 'unk'
                return
            try:
                self.label = ols_resp.json()['_embedded']['terms'][0]['label']
                return
            except (KeyError, IndexError):
                logging.warning(f'Did not find label for {self.id} in OLS')
                self.label = 'unk'
                return
            except ValueError:
                # JSON decode errors (json's and requests' both subclass
                # ValueError): retry the request after a short pause.
                time.sleep(0.05)
        logging.warning(f'Did not find label for {self.id} in OLS')
        self.label = 'unk'

    def _get_request(self, request_url, max_retries=5):
        '''GET ``request_url``, retrying up to ``max_retries`` times on network errors.

        Returns the ``requests`` response (whatever its status code), or None
        when every attempt raised a ``requests`` exception.
        '''
        for _ in range(max_retries):
            try:
                return requests.get(request_url)
            except requests.exceptions.RequestException:
                time.sleep(0.05)
        return None

    def _graph_label(self):
        '''Format a graph label: id and label separated by a literal ``\\n`` for graphviz.'''
        return self.id + '\\n' + self.label

    def _next_page(self, url_link, return_list):
        '''Fetch one follow-up page of OLS search results.

        Returns (next page URL or False, ``return_list`` with this page's
        non-obsolete terms appended).
        '''
        next_resp = self._get_request(url_link)
        if next_resp is None:
            logging.warning(f'Did not retrieve URL for {url_link} during API search')
            return False, return_list
        try:
            payload = next_resp.json()
        except ValueError:
            logging.warning(f'Did not retrieve URL for {url_link} during API search')
            return False, return_list
        try:
            next_link = payload['_links']['next']['href']
        except KeyError:
            next_link = False
        terms = payload.get('_embedded', {}).get('terms', [])
        return next_link, self._api_results(terms, return_list)

    def check_label(self):
        '''Check the current label against OLS.

        Returns True/False for whether the label held before the call matches
        OLS, or the string 'unk' when OLS has no label. Side effect:
        ``self.label`` is replaced by the OLS label (or 'unk').
        '''
        given_label = self.label
        self._get_label()
        if self.label == 'unk':
            return self.label
        return self.label == given_label

    def _family_query_url(self, family_member):
        '''Build the OLS URL that lists this term's ``family_member``; None if unavailable.

        GAZ terms are first looked up by IRI and the 'hierarchical<Family>' link
        of that response is used; other ontologies use the direct endpoint.
        '''
        prefix = self.id.split(':')[0].lower()
        if prefix != 'gaz':
            query_url = 'http://www.ebi.ac.uk/ols/api/ontologies/{}/{}?id={}'
            return query_url.format(prefix, family_member, self.id)
        query_url = 'https://www.ebi.ac.uk/ols/api/ontologies/gaz/terms?iri='
        query_url += 'http://purl.obolibrary.org/obo/' + self.id.replace(':', '_')
        ols_resp = self._get_request(query_url)
        if ols_resp is None:
            return None
        try:
            return ols_resp.json()['_embedded']['terms'][0]['_links']\
                                  ['hierarchical' + family_member.title()]['href']
        except (KeyError, IndexError, ValueError):
            return None

    def _parse_family_response(self, ols_resp, family_member):
        '''Collect every page of terms from a successful OLS family response.

        Returns ['none found'] when the response holds no terms or does not
        have the expected JSON structure.
        '''
        try:
            payload = ols_resp.json()
            page_info = payload['page']
            if page_info['totalElements'] <= 0:
                return ['none found']
            result_list = self._api_results(payload['_embedded']['terms'], [])
            if page_info['totalPages'] > 1:
                next_url = payload['_links']['next']['href']
            else:
                next_url = False
        except (KeyError, IndexError, TypeError, ValueError):
            logging.warning(f'Unexpected OLS response for {self.id} during search for {family_member}')
            return ['none found']
        while next_url:
            next_url, result_list = self._next_page(next_url, result_list)
        return result_list

    def get_family(self, family_member):
        '''Return this term's relatives of type ``family_member``.

        family_member: 'parents', 'children', 'ancestors' or 'descendants'
            (other values are queried but not cached).
        Returns a list of Ontology_accession instances, or ['none found'] when
        OLS is unreachable, answers with an error status, or lists nothing.
        For the four known types the de-duplicated result is cached on the
        attribute of the same name and a copy is returned, so callers may
        modify the returned list without corrupting the cache.
        '''
        is_known = family_member in self._FAMILY_MEMBERS
        if is_known and getattr(self, family_member) != self._UNASSIGNED:
            return list(getattr(self, family_member))

        qry_url = self._family_query_url(family_member)
        ols_resp = None if qry_url is None else self._get_request(qry_url)
        if ols_resp is None:
            logging.warning(f'Did not get URL for {qry_url or self.id} during search for {family_member}')
            result_list = ['none found']
        elif ols_resp.status_code > 200:
            result_list = ['none found']
        else:
            result_list = self._parse_family_response(ols_resp, family_member)

        if not is_known:
            return result_list
        setattr(self, family_member, list(set(result_list)))
        return list(getattr(self, family_member))

    def bin_term(self, bin_package):
        '''Return labels of this term and its ancestors that ``bin_package`` contains.

        Labels are rebuilt in accession form ``<label>:<PREFIX>_<id>`` and kept
        when found in ``bin_package.ontologies`` (presumably a collection of such
        strings -- verify against callers). Order: this term first, then ancestors.
        '''
        self.get_family('ancestors')
        lineage = [self]
        if self.ancestors != ['none found']:
            lineage = lineage + self.ancestors
        ancestor_labels = [x.label + ':' + x.id.replace(':', '_') for x in lineage]
        return [x for x in ancestor_labels if x in bin_package.ontologies]

    def _collect_edges(self, fill_out, stop_terms):
        '''Return the de-duplicated edge list used by visualize_term.'''
        self.get_family('parents')
        self.get_family('children')
        own_label = self._graph_label()
        edge_set1, edge_set2 = [], []
        if self.parents != ['none found']:
            edge_set1 = [(x._graph_label(), own_label) for x in self.parents]
        if self.children != ['none found']:
            edge_set2 = [(own_label, x._graph_label()) for x in self.children]
        if type(fill_out) is int:
            # Exact type check on purpose: bool True means "expand fully" below.
            edge_set1 = self._add_edges('parents', self.parents, edge_set1, fill_out - 1)
            edge_set2 = self._add_edges('children', self.children, edge_set2, fill_out - 1)
        elif fill_out == True:  # noqa: E712 -- any value equal to True
            edge_set1 = self._expand_edge('parents', self.parents, edge_set1, '', stop_terms)
            edge_set2 = self._expand_edge('children', self.children, edge_set2, '', stop_terms)
        return list(set(edge_set1 + edge_set2))

    def visualize_term(self, o_file, node_color='black', edge_color='black',
                       fill_out=False, stop_terms=False, draw_graph=True):
        '''Build (and optionally draw) the graph of edges around this term.

        o_file: output path passed to pygraphviz when drawing.
        fill_out: False -> direct parents/children only; int n -> up to n
            generations each way; True -> every generation.
        stop_terms: list of instances not expanded past when fill_out is True.
        draw_graph: when False the edges are only stored in ``graph_nodes``.
        The edge list is reused only while fill_out and stop_terms match the
        call that produced it.
        '''
        is_cached = (self.graph_nodes != self._UNASSIGNED
                     and type(self.graph_fill) is type(fill_out)
                     and self.graph_fill == fill_out
                     and self._graph_stop == stop_terms)
        if not is_cached:
            self.graph_nodes = self._collect_edges(fill_out, stop_terms)
            self.graph_fill = fill_out
            self._graph_stop = list(stop_terms) if isinstance(stop_terms, list) else stop_terms
        if draw_graph:
            self._draw_graph(o_file, node_color, edge_color)


class Ontology_package:
    '''Associate or package Ontology_accession objects together.

    ``ontologies`` is the list of member accessions. The instance caches results
    computed over them:
    - ``lcp``: lowest common parents.
    - ``hcc``: highest common children.
    - ``bins``.
    - ``graph_nodes``: edges as (upper_label, lower_label) tuples.

    ``lcp``, ``hcc`` and ``graph_nodes`` hold the string 'not assigned yet'
    until they are computed.
    '''
    def __init__(self, package_label, ontol_list):
        self.label = package_label
        self.ontologies = ontol_list
        self.bins = []
        self.lcp = 'not assigned yet'
        self.hcc = 'not assigned yet'
        # (incl_terms, excl_terms) that produced the cached lcp / hcc.
        self._lcp_state = (True, [])
        self._hcc_state = (True, [])
        # bin_package that produced the cached bins.
        self._bin_state = []
        self.graph_nodes = 'not assigned yet'
        self.graph_state = False
        # fill_out of the most recent visualize_terms call.
        self.graph_fill = False

    @staticmethod
    def _assigned_terms(value):
        '''Return the accession objects held in an lcp/hcc value.

        Gives [] while the value is still the 'not assigned yet' string and drops
        'none found' placeholders, so every returned element has methods.
        '''
        if isinstance(value, str):
            return []
        return [x for x in value if not isinstance(x, str)]

    def _common_family(self, family_member, incl_terms, excl_terms):
        '''Find relatives shared by every non-excluded member of the package.

        family_member: 'parents' searches upwards, 'children' downwards.
        incl_terms: forwarded to _common_list (members may themselves be results).
        excl_terms: ids of members left out of the comparison.
        The search starts from direct relatives and widens one generation at a
        time until something is shared. Returns [] when there is nothing to
        compare, or when no member gains new relatives and nothing is shared.
        '''
        terms = list(dict.fromkeys(x for x in self.ontologies if x.id not in excl_terms))
        if not terms:
            return []
        # Copy: get_family may return a list that is also its cache.
        family_candidates = {t: list(t.get_family(family_member)) for t in terms}
        seen = {t: set(family_candidates[t]) for t in terms}
        frontier = {t: list(family_candidates[t]) for t in terms}
        common_members = self._common_list(family_candidates, incl_terms)
        while not common_members:
            grew = False
            for ontol_term in terms:
                if len(self.ontologies) > 30:
                    time.sleep(0.05)
                next_frontier = []
                for family_ontol in frontier[ontol_term]:
                    if len(family_candidates[ontol_term]) > 30:
                        time.sleep(0.05)
                    try:
                        relatives = family_ontol.get_family(family_member)
                    except AttributeError:
                        # 'none found' placeholders have no relatives of their own.
                        relatives = ['none found']
                    for relative in relatives:
                        if relative not in seen[ontol_term]:
                            seen[ontol_term].add(relative)
                            family_candidates[ontol_term].append(relative)
                            next_frontier.append(relative)
                            grew = True
                frontier[ontol_term] = next_frontier
            if not grew:
                break  # nothing new anywhere: no common relative exists
            common_members = self._common_list(family_candidates, incl_terms)
        return common_members

    def _common_list(self, input_dic, incl_terms):
        '''Return the members shared by every entry of ``input_dic``.

        Each key is compared together with its own list of relatives. When
        ``incl_terms`` is true and some keys are themselves shared, all such keys
        are returned. Otherwise the shared non-key members are returned.
        An empty ``input_dic`` gives [].
        '''
        if not input_dic:
            return []
        common_set = set.intersection(*({key, *members} for key, members in input_dic.items()))
        if incl_terms:
            common_keys = [key for key in common_set if key in input_dic]
            if common_keys:
                return common_keys
        return list(common_set - set(input_dic))

    def _draw_graph(self, o_file, node_color, edge_color, show_lcp, show_hcc):
        '''Render ``graph_nodes`` to ``o_file``.

        Package members are filled light blue. When requested, lcp/hcc terms
        that appear in the graph are filled beige.
        '''
        ontol_graph = pgv.AGraph(name='ontology_graph')
        for x in self.ontologies:
            ontol_graph.add_node(x._graph_label())
        for x in self.graph_nodes:
            ontol_graph.add_edge(x[0], x[1])
        ontol_graph.node_attr.update(shape='box', style='rounded,filled',
                                     fillcolor='lightgrey', color=node_color)
        ontol_graph.edge_attr.update(shape='normal', color=edge_color, dir='back')
        highlighted = []
        if show_lcp:
            highlighted += self._assigned_terms(self.lcp)
        if show_hcc:
            highlighted += self._assigned_terms(self.hcc)
        for x in highlighted:
            node_label = x._graph_label()
            # get_node raises for absent nodes, e.g. an lcp beyond the drawn depth.
            if ontol_graph.has_node(node_label):
                ontol_graph.get_node(node_label).attr.update(fillcolor='beige')
        for x in self.ontologies:
            ontol_graph.get_node(x._graph_label()).attr.update(fillcolor='lightblue')
        ontol_graph.draw(o_file, prog='dot')

    def _list_hierarchy(self, input_list, input_position):
        '''Keep only the lowest or highest terms of ``input_list``.

        'lowest' drops any term that is an ancestor of another listed term, and
        'highest' drops any term that is a descendant of one. ``['none found']``
        is returned unchanged; otherwise a new list is returned and the input is
        not modified.
        Raises ValueError for any other ``input_position``.
        '''
        if input_list == ['none found']:
            return input_list
        relations = {'lowest': 'ancestors', 'highest': 'descendants'}
        if input_position not in relations:
            raise ValueError(f"input_position must be 'lowest' or 'highest', not {input_position!r}")
        family_lists = {}
        for input_term in input_list:
            if len(input_list) > 30:
                time.sleep(0.05)
            if input_term == 'none found':
                family_lists[input_term] = ['none found']
            else:
                family_lists[input_term] = input_term.get_family(relations[input_position])
        remaining = list(input_list)
        while True:
            remove_terms = []
            for input_term in remaining:
                if any(input_term in family for family in family_lists.values()):
                    del family_lists[input_term]
                    remove_terms.append(input_term)
            if not remove_terms:
                break
            for x_term in remove_terms:
                remaining.remove(x_term)
        return remaining

    def _trim_tips(self):
        '''Repeatedly drop dangling edges whose outer node is not a member or lcp.

        Removes edges whose upper node has no parent in the graph, and edges
        whose lower node has no child, unless that node belongs to
        ``self.ontologies`` or ``self.lcp``. Stops when nothing changes.
        '''
        tip_nodes = {x._graph_label() for x in self.ontologies}
        tip_nodes.update(x._graph_label() for x in self._assigned_terms(self.lcp))
        old_nodes = []
        while old_nodes != self.graph_nodes:
            old_nodes = self.graph_nodes
            left_nodes = {edge[0] for edge in self.graph_nodes}
            right_nodes = {edge[1] for edge in self.graph_nodes}
            top_nodes = (left_nodes - right_nodes) - tip_nodes
            bot_nodes = (right_nodes - left_nodes) - tip_nodes
            self.graph_nodes = [edge for edge in self.graph_nodes
                                if edge[0] not in top_nodes and edge[1] not in bot_nodes]

    def get_lcp(self, incl_terms=True, excl_terms=None):
        '''Find lowest common parent(s) and save them in ``lcp``.

        incl_terms: allow package members themselves to be the lcp.
        excl_terms: obo ids of members to ignore (default: none).
        A cached result is reused for identical arguments. ``lcp`` is left
        unchanged when nothing common is found.
        '''
        excl_terms = [] if excl_terms is None else list(excl_terms)
        if self._lcp_state == (incl_terms, excl_terms) and self.lcp != 'not assigned yet':
            return
        common_members = self._common_family('parents', incl_terms, excl_terms)
        common_members = self._list_hierarchy(common_members, 'lowest')
        if common_members:
            self.lcp = common_members
            self._lcp_state = (incl_terms, excl_terms)

    def get_hcc(self, incl_terms=True, excl_terms=None):
        '''Find highest common child(ren) and save them in ``hcc``.

        incl_terms: allow package members themselves to be the hcc.
        excl_terms: obo ids of members to ignore (default: none).
        A cached result is reused for identical arguments. ``hcc`` is left
        unchanged when nothing common is found.
        '''
        excl_terms = [] if excl_terms is None else list(excl_terms)
        if self._hcc_state == (incl_terms, excl_terms) and self.hcc != 'not assigned yet':
            return
        common_members = self._common_family('children', incl_terms, excl_terms)
        common_members = self._list_hierarchy(common_members, 'highest')
        if common_members:
            self.hcc = common_members
            self._hcc_state = (incl_terms, excl_terms)

    def set_lcp(self, lcp_acc, incl_terms=True, excl_terms=None):
        '''Set ``lcp`` directly; it counts as computed for these arguments.'''
        self.lcp = lcp_acc
        self._lcp_state = (incl_terms, [] if excl_terms is None else list(excl_terms))

    def set_hcc(self, hcc_acc, incl_terms=True, excl_terms=None):
        '''Set ``hcc`` directly; it counts as computed for these arguments.'''
        self.hcc = hcc_acc
        self._hcc_state = (incl_terms, [] if excl_terms is None else list(excl_terms))

    def bin_terms(self, bin_package):
        '''Categorize members by the terms in ``bin_package`` and save them in ``bins``.

        Nothing is recomputed when called again with the same ``bin_package``.
        '''
        if self._bin_state == bin_package:
            return
        package_bins = []
        for x in self.ontologies:
            package_bins.extend(x.bin_term(bin_package))
        self.bins = list(set(package_bins))
        self._bin_state = bin_package

    def visualize_terms(self, o_file, fill_out=False, show_lcp=False, show_hcc=False,
                                      node_color='black', edge_color='black',
                                      lcp_stop=False, hcc_stop=False, trim_nodes=False):
        '''Visualize all package members in one graph.

        fill_out: forwarded to each member's visualize_term.
        lcp_stop / hcc_stop: do not expand past (nor start from) lcp / hcc
            terms. These are ignored while lcp / hcc is still unassigned.
        trim_nodes: prune dangling branches first (see _trim_tips).
        Graphs with more than 150 edges are logged as a tab-separated edge list
        instead of being drawn.
        '''
        stop_terms = []
        if lcp_stop:
            stop_terms += self._assigned_terms(self.lcp)
        if hcc_stop:
            stop_terms += self._assigned_terms(self.hcc)
        # Member graphs are cached per term, so rebuilding the union is cheap
        # and cannot carry over edges from a call with other options.
        self.graph_nodes = []
        self.graph_fill = fill_out
        seen_edges = set()
        for x in self.ontologies:
            if lcp_stop or hcc_stop:
                if x in stop_terms:
                    continue
                x.visualize_term(o_file, fill_out=fill_out,
                                 stop_terms=stop_terms, draw_graph=False)
            else:
                x.visualize_term(o_file, fill_out=fill_out, draw_graph=False)
            for edge in x.graph_nodes:
                if edge not in seen_edges:
                    seen_edges.add(edge)
                    self.graph_nodes.append(edge)
        if trim_nodes:
            self._trim_tips()
        if len(self.graph_nodes) > 150:
            edge_string = 'Parent node\tChild node' + ''.join(
                '\n' + '\t'.join(edge_tuple) for edge_tuple in self.graph_nodes)
            logging.info(f'Not drawing graph with {len(self.graph_nodes)} edges:\
                           \n\n{edge_string}\n')
        else:
            self._draw_graph(o_file, node_color, edge_color, show_lcp, show_hcc)