From 1420aaefddc8f7ad081766da6d43354e50b2ddb1 Mon Sep 17 00:00:00 2001 From: Kevin Vizhalil Date: Mon, 9 Sep 2024 15:46:39 -0400 Subject: [PATCH] #2352 Improve quality of xCRG Paths --- code/ARAX/ARAXQuery/ARAX_infer.py | 24 +++++- .../ARAXQuery/Infer/scripts/creativeCRG.py | 84 +++++++++++++++---- code/ARAX/BiolinkHelper/biolink_helper.py | 48 +++++++++-- 3 files changed, 129 insertions(+), 27 deletions(-) diff --git a/code/ARAX/ARAXQuery/ARAX_infer.py b/code/ARAX/ARAXQuery/ARAX_infer.py index b2be0e7a9..73d358624 100644 --- a/code/ARAX/ARAXQuery/ARAX_infer.py +++ b/code/ARAX/ARAXQuery/ARAX_infer.py @@ -27,6 +27,9 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs) from openapi_server.models.edge import Edge from openapi_server.models.attribute import Attribute as EdgeAttribute from openapi_server.models.node import Node +from openapi_server.models.qualifier import Qualifier +from openapi_server.models.qualifier_constraint import QualifierConstraint as QConstraint + sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'NodeSynonymizer'])) from node_synonymizer import NodeSynonymizer @@ -36,6 +39,7 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs) # from creativeDTD import creativeDTD from creativeCRG import creativeCRG from ExplianableDTD_db import ExplainableDTD + # from ExplianableCRG import ExplianableCRG # sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code'])) @@ -615,7 +619,7 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False): f"The `n_result_curies` value must be a positive integer. 
The provided value was {self.parameters['n_result_curies']}.", error_code="ValueError") else: - self.parameters['n_result_curies'] = 10 + self.parameters['n_result_curies'] = 30 if 'n_paths' in self.parameters: if isinstance(self.parameters['n_paths'], str): @@ -678,9 +682,26 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False): if not preferred_subject_curie and not preferred_object_curie: self.response.error(f"Both parameters 'subject_curie' and 'object_curie' are not provided. Please provide the curie for either one of them") return self.response + qedges = message.query_graph.edges + else: self.response.error(f"The 'query_graph' is detected. One of 'subject_qnode_id' or 'object_qnode_id' should be specified.") + + if self.parameters['regulation_type'] == 'increase': + edge_qualifier_direction = 'increased' + else: + edge_qualifier_direction = 'decreased' + edge_qualifier_list = [ + Qualifier(qualifier_type_id='biolink:object_aspect_qualifier', qualifier_value='activity_or_abundance'), + Qualifier(qualifier_type_id='biolink:object_direction_qualifier', qualifier_value=edge_qualifier_direction)] + + for qedge in qedges: + edge = message.query_graph.edges[qedge] + edge.knowledge_type = "inferred" + edge.predicates = ["biolink:affects"] + edge.qualifier_constraints = [QConstraint(qualifier_set=edge_qualifier_list)] + else: if 'subject_curie' in parameters or 'object_curie' in parameters: @@ -763,6 +784,7 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False): iu = InferUtilities() qedge_id = self.parameters.get('qedge_id') + self.response, self.kedge_global_iter, self.qedge_global_iter, self.qnode_global_iter, self.option_global_iter = iu.genrete_regulate_subgraphs(self.response, None, normalized_object_curie, top_predictions, top_paths, qedge_id, self.parameters['regulation_type'], self.kedge_global_iter, self.qedge_global_iter, self.qnode_global_iter, self.option_global_iter) return self.response diff --git 
a/code/ARAX/ARAXQuery/Infer/scripts/creativeCRG.py b/code/ARAX/ARAXQuery/Infer/scripts/creativeCRG.py index d94f7d3e0..75c40d043 100644 --- a/code/ARAX/ARAXQuery/Infer/scripts/creativeCRG.py +++ b/code/ARAX/ARAXQuery/Infer/scripts/creativeCRG.py @@ -8,7 +8,6 @@ import requests # import graph_tool.all as gt from tqdm import tqdm, trange - pathlist = os.getcwd().split(os.path.sep) RTXindex = pathlist.index("RTX") sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery'])) @@ -23,7 +22,9 @@ from RTXConfiguration import RTXConfiguration RTXConfig = RTXConfiguration() sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery',''])) - +sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'BiolinkHelper',''])) +from biolink_helper import BiolinkHelper +def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs) def call_plover(curies: List, respect_predicate_symmetry: bool=False): json = {} plover_url = RTXConfig.plover_url @@ -164,6 +165,16 @@ def __init__(self, response: ARAXResponse, data_path: str): ## set up parameters self.response = response + self.bh = BiolinkHelper() + self.predicate_depth_map = self.bh.get_predicate_depth_map() + self.relevant_node_categories = ['biolink:Drug', 'biolink:PathologicalProcess', 'biolink:GeneOrGeneProduct', 'biolink:ChemicalEntity', + 'biolink:SmallMolecule', 'biolink:Gene', 'biolink:BiologicalProcess', 'biolink:Pathway', 'biolink:Disease', + 'biolink:Transcript', 'biolink:Cell', 'biolink:GeneFamily', 'biolink:GeneProduct', 'biolink:Exon', + 'biolink:DiseaseOrPhenotypicFeature', 'biolink:PhenotypicFeature', 'biolink:MolecularActivity', 'biolink:GeneGroupingMixin', + 'biolink:CellularComponent', 'biolink:RNAProduct', 'biolink:Protein', 'biolink:BiologicalProcessOrActivity', 'biolink:PhysiologicalProcess', + 'biolink:NoncodingRNAProduct', 'biolink:ProteinFamily', 'biolink:ProteinDomain'] + self.relevant_node_categories = 
self.bh.get_descendants(self.relevant_node_categories) + self.data_path = data_path self.chemical_type = ['biolink:ChemicalEntity', 'biolink:ChemicalMixture','biolink:SmallMolecule'] self.gene_type = ['biolink:Gene','biolink:Protein'] @@ -212,6 +223,7 @@ def get_tf_neighbors(self): for edge in edges.keys(): c1 = edges[edge][0] c2 = edges[edge][1] + depth = self.predicate_depth_map[edges[edge][2]] if 'subclass' in edges[edge][2]: continue if c1 == c2: @@ -219,11 +231,11 @@ def get_tf_neighbors(self): if c1 in self.tf_list: curie = c2 tf = c1 - answer_tf_neigbor_data.append({"edge_id": edge, "transcription_factor":tf, "neighbour": curie}) + answer_tf_neigbor_data.append({"edge_id": edge, "transcription_factor":tf, "neighbour": curie, "depth": depth}) if c2 in self.tf_list: curie = c1 tf = c2 - query_tf_neighbor_data.append({"edge_id": edge, "transcription_factor":tf, "neighbour": curie}) + query_tf_neighbor_data.append({"edge_id": edge, "transcription_factor":tf, "neighbour": curie, "depth": depth}) return query_tf_neighbor_data,answer_tf_neigbor_data, edges def add_node_ids_to_path(self, paths, tf_edges,chemical_edges, gene_edges): @@ -507,62 +519,90 @@ def _check_params(query_chemical: Optional[str], query_gene: Optional[str], mode else: top_paths = dict() + gene_neighbors = call_plover([preferred_query_gene]) answers = res['chemical_id'].tolist() self.preferred_curies = self.get_preferred_curies(answers) valid_chemicals = [item for item in self.preferred_curies.values() if item] chemical_neighbors = call_plover(valid_chemicals) query_tf_neighbors, answer_tf_neigbors, tf_edges = self.get_tf_neighbors() - paths = self.get_paths(preferred_query_gene, res['chemical_id'].tolist(), gene_neighbors, chemical_neighbors, query_tf_neighbors, answer_tf_neigbors,self.tf_list, M) final_paths = self.add_node_ids_to_path(paths, tf_edges, chemical_neighbors, gene_neighbors) return final_paths def get_paths(self, query_curie, answer_curies, query_neighbors, answer_neighbors, 
query_tf_neighbors, answer_tf_neighbors, tf_list,n_paths): - query_neighbors_curies = list(query_neighbors['nodes']['n01'].keys()) query_tf_neighbors_dict = {} answer_tf_neighbors_dict = {} query_path = {} answer_path = {} combined_path = dict() - one_hop_from_query = set(tf_list).intersection(query_neighbors_curies) + valid_query_tf_list = [] + valid_answer_tf_list = {} + for answer in answer_curies: + valid_answer_tf_list[answer] = [] + for record in query_tf_neighbors: - query_tf_neighbors_dict[record['neighbour']] = query_tf_neighbors_dict.get(record['neighbour'],[]) + [(record['edge_id'],record['transcription_factor'])] + query_tf_neighbors_dict[record['neighbour']] = query_tf_neighbors_dict.get(record['neighbour'],[]) + [(record['edge_id'],record['transcription_factor'], record['depth'])] for record in answer_tf_neighbors: - answer_tf_neighbors_dict[record['neighbour']] = answer_tf_neighbors_dict.get(record['neighbour'],[]) + [(record['edge_id'],record['transcription_factor'])] + answer_tf_neighbors_dict[record['neighbour']] = answer_tf_neighbors_dict.get(record['neighbour'],[]) + [(record['edge_id'],record['transcription_factor'], record['depth'])] # one hop from query for edge_id, edge in query_neighbors['edges']['e00'].items(): if edge[1] in tf_list and edge[1] not in query_path: - query_path[edge[1]] = [edge_id] - + valid_query_tf_list.append(edge[1]) + query_path[edge[1]] = [edge_id,self.predicate_depth_map[edge[2]]] + elif edge[1] in tf_list and edge[1] in query_path: + if query_path[edge[1]][-1] < self.predicate_depth_map[edge[2]]: + query_path[edge[1]] = [edge_id,self.predicate_depth_map[edge[2]]] # two hop from query for edge_id, edge in query_neighbors['edges']['e00'].items(): if edge[0] != query_curie: continue + relevant_node = False neighbor = edge[1] + neighbor_category = query_neighbors['nodes']['n01'][neighbor][1] + if neighbor_category in self.relevant_node_categories: + relevant_node = True + for item in 
query_tf_neighbors_dict.get(neighbor,[]): + if item[1] not in valid_query_tf_list and relevant_node: + valid_query_tf_list.append(item[1]) if item[1] not in query_path: - query_path[item[1]] = [edge_id,item[0]] + query_path[item[1]] = [edge_id,item[0], item[2]] + elif query_path[item[1]][-1] < min(item[2],self.predicate_depth_map[edge[2]]) and ((item[1] not in valid_query_tf_list) or relevant_node) : + query_path[item[1]] = [edge_id,item[0], min(item[2],self.predicate_depth_map[edge[2]])] + for edge_id, edge in answer_neighbors['edges']['e00'].items(): if edge[1] not in self.preferred_curies.values(): continue + relevant_node = False answer = edge[1] neighbor = edge[0] + neighbor_category = answer_neighbors['nodes']['n01'][neighbor][1] + if neighbor_category in self.relevant_node_categories: + relevant_node = True # one hop from answer if answer not in answer_path: answer_path[answer] = dict() - if neighbor in tf_list: - answer_path[answer][neighbor] = [edge_id] + if neighbor in tf_list and neighbor not in answer_path[answer]: + valid_answer_tf_list[answer].append(neighbor) + answer_path[answer][neighbor] = [edge_id, self.predicate_depth_map[edge[2]]] + elif neighbor in tf_list and neighbor in answer_path[answer]: + if answer_path[answer][neighbor][-1] < self.predicate_depth_map[edge[2]]: + answer_path[answer][neighbor] = [edge_id, self.predicate_depth_map[edge[2]]] # two hop from answer for item in answer_tf_neighbors_dict.get(neighbor,[]): + neighbor_category = answer_neighbors['nodes']['n01'][neighbor][1] + if relevant_node and item[1] not in valid_answer_tf_list[answer]: + valid_answer_tf_list[answer].append(item[1]) if item[1] not in answer_path[answer]: - answer_path[answer][item[1]] = [item[0],edge_id] - + answer_path[answer][item[1]] = [item[0],edge_id, item[2]] + elif answer_path[answer][item[1]][-1] < item[2] and ((item[1] not in valid_answer_tf_list[answer]) or relevant_node): + answer_path[answer][item[1]] = [item[0], edge_id, 
min(item[2],self.predicate_depth_map[edge[2]])] # joining paths for answer in answer_curies: combined_path[(query_curie,answer)] = list() @@ -573,13 +613,21 @@ def get_paths(self, query_curie, answer_curies, query_neighbors, answer_neighbor continue path_counter = 0 - for tf in tf_list: + relevant_tf = list(set(valid_query_tf_list).intersection(valid_answer_tf_list[answer])) + irrelevant_tf = [tf for tf in tf_list if tf not in relevant_tf] + for tf in relevant_tf: if path_counter > n_paths: break if tf in query_path and tf in answer_path[key]: - combined_path[(query_curie,answer)].append(query_path[tf] + answer_path[key][tf]) + combined_path[(query_curie,answer)].append(query_path[tf][:-1] + answer_path[key][tf][:-1]) path_counter += 1 + for tf in irrelevant_tf: + if path_counter > n_paths: + break + if tf in query_path and tf in answer_path[key]: + combined_path[(query_curie,answer)].append(query_path[tf][:-1] + answer_path[key][tf][:-1]) + path_counter += 1 return combined_path diff --git a/code/ARAX/BiolinkHelper/biolink_helper.py b/code/ARAX/BiolinkHelper/biolink_helper.py index f0b001767..bbb693be5 100644 --- a/code/ARAX/BiolinkHelper/biolink_helper.py +++ b/code/ARAX/BiolinkHelper/biolink_helper.py @@ -125,7 +125,18 @@ def get_canonical_predicates(self, predicates: Union[str, List[str], Set[str]]) for predicate in valid_predicates} canonical_predicates.update(invalid_predicates) # Go ahead and include those we don't have canonical info for return list(canonical_predicates) - + + def get_predicate_depth_map(self)->Dict[str,int]: + response = self._download_biolink_model() + if response.status_code == 200: + biolink_model = yaml.safe_load(response.text) + predicate_dag = self._build_predicate_dag(biolink_model) + + else: + raise RuntimeError(f"ERROR: Request to get Biolink {self.biolink_version} YAML file returned " + f"{response.status_code} response. 
Cannot load BiolinkHelper.") + return self._get_depths_from_root(predicate_dag) + def is_symmetric(self, predicate: str) -> Optional[bool]: if predicate in self.biolink_lookup_map["predicates"]: return self.biolink_lookup_map["predicates"][predicate]["is_symmetric"] @@ -198,7 +209,15 @@ def _load_biolink_lookup_map(self, is_test: bool = False): with open(self.biolink_lookup_map_path, "rb") as biolink_map_file: biolink_lookup_map = pickle.load(biolink_map_file) return biolink_lookup_map - + + def _download_biolink_model(self): + response = requests.get(f"https://raw.githubusercontent.com/biolink/biolink-model/{self.biolink_version}/biolink-model.yaml", + timeout=10) + if response.status_code != 200: # Sometimes Biolink's tags start with 'v', so try that + response = requests.get(f"https://raw.githubusercontent.com/biolink/biolink-model/v{self.biolink_version}/biolink-model.yaml", + timeout=10) + return response + def _create_biolink_lookup_map(self) -> Dict[str, Dict[str, Dict[str, Union[str, List[str], bool]]]]: timestamp = str(datetime.datetime.now().isoformat()) eprint(f"{timestamp}: INFO: Building local Biolink {self.biolink_version} ancestor/descendant lookup map " @@ -206,17 +225,13 @@ def _create_biolink_lookup_map(self) -> Dict[str, Dict[str, Dict[str, Union[str, biolink_lookup_map = {"predicates": dict(), "categories": dict(), "aspects": dict(), "directions": dict()} # Grab the relevant Biolink yaml file - response = requests.get(f"https://raw.githubusercontent.com/biolink/biolink-model/{self.biolink_version}/biolink-model.yaml", - timeout=10) - if response.status_code != 200: # Sometimes Biolink's tags start with 'v', so try that - response = requests.get(f"https://raw.githubusercontent.com/biolink/biolink-model/v{self.biolink_version}/biolink-model.yaml", - timeout=10) + response = self._download_biolink_model() if response.status_code == 200: biolink_model = yaml.safe_load(response.text) # -------------------------------- PREDICATES 
--------------------------------- # predicate_dag = self._build_predicate_dag(biolink_model) # Build our map of predicate ancestors/descendants for easy lookup, first WITH mixins for node_id in list(predicate_dag.nodes): node_info = predicate_dag.nodes[node_id] @@ -382,7 +398,23 @@ def _build_direction_dag(self, biolink_model: dict) -> nx.DiGraph: direction_dag.add_edge(parent_name_trapi, direction_name_trapi) return direction_dag - + + def _get_depths_from_root(self, dag)-> Dict[str,int]: + node_depths = {} + for node in nx.topological_sort(dag): + # Skip if the node is the start node + + # Get all predecessors of the current node + predecessors = list(dag.predecessors(node)) + + # If the node has predecessors, calculate its depth as max(depth of predecessors) + 1 + if predecessors: + node_depths[node] = max(node_depths[pred] for pred in predecessors) + 1 + else: + node_depths[node] = 0 # Handle nodes that have no predecessors (if any) + + return node_depths + @staticmethod def _get_ancestors_nx(nx_graph: nx.DiGraph, node_id: str) -> List[str]: return list(nx.ancestors(nx_graph, node_id).union({node_id}))