Merge branch 'master' of https://github.com/RTXteam/RTX
amykglen committed Sep 9, 2024
2 parents fd5a0a0 + 60a356b commit d5a79cb
Showing 3 changed files with 129 additions and 27 deletions.
24 changes: 23 additions & 1 deletion code/ARAX/ARAXQuery/ARAX_infer.py
@@ -27,6 +27,9 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)
from openapi_server.models.edge import Edge
from openapi_server.models.attribute import Attribute as EdgeAttribute
from openapi_server.models.node import Node
from openapi_server.models.qualifier import Qualifier
from openapi_server.models.qualifier_constraint import QualifierConstraint as QConstraint


sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'NodeSynonymizer']))
from node_synonymizer import NodeSynonymizer
@@ -36,6 +39,7 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)
# from creativeDTD import creativeDTD
from creativeCRG import creativeCRG
from ExplianableDTD_db import ExplainableDTD

# from ExplianableCRG import ExplianableCRG

# sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code']))
@@ -615,7 +619,7 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False):
f"The `n_result_curies` value must be a positive integer. The provided value was {self.parameters['n_result_curies']}.",
error_code="ValueError")
else:
self.parameters['n_result_curies'] = 10
self.parameters['n_result_curies'] = 30

if 'n_paths' in self.parameters:
if isinstance(self.parameters['n_paths'], str):
Expand Down Expand Up @@ -678,9 +682,26 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False):
if not preferred_subject_curie and not preferred_object_curie:
self.response.error(f"Both parameters 'subject_curie' and 'object_curie' are not provided. Please provide the curie for either one of them")
return self.response
qedges = message.query_graph.edges


else:
self.response.error(f"The 'query_graph' is detected. One of 'subject_qnode_id' or 'object_qnode_id' should be specified.")

if self.parameters['regulation_type'] == 'increase':
edge_qualifier_direction = 'increased'
else:
edge_qualifier_direction = 'decreased'
edge_qualifier_list = [
Qualifier(qualifier_type_id='biolink:object_aspect_qualifier', qualifier_value='activity_or_abundance'),
Qualifier(qualifier_type_id='biolink:object_direction_qualifier', qualifier_value=edge_qualifier_direction)]

for qedge in qedges:
edge = message.query_graph.edges[qedge]
edge.knowledge_type = "inferred"
edge.predicates = ["biolink:affects"]
edge.qualifier_constraints = [QConstraint(qualifier_set=edge_qualifier_list)]


else:
if 'subject_curie' in parameters or 'object_curie' in parameters:
@@ -763,6 +784,7 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False):

iu = InferUtilities()
qedge_id = self.parameters.get('qedge_id')

self.response, self.kedge_global_iter, self.qedge_global_iter, self.qnode_global_iter, self.option_global_iter = iu.genrete_regulate_subgraphs(self.response, None, normalized_object_curie, top_predictions, top_paths, qedge_id, self.parameters['regulation_type'], self.kedge_global_iter, self.qedge_global_iter, self.qnode_global_iter, self.option_global_iter)

return self.response
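The new block above maps the user-facing `regulation_type` parameter onto Biolink qualifier constraints for each query edge: the predicate is pinned to `biolink:affects`, the knowledge type to `inferred`, and the object aspect/direction qualifiers encode whether activity or abundance is increased or decreased. A minimal sketch of the resulting TRAPI shape, using plain dicts in place of the generated `openapi_server` model classes (an illustrative assumption, not the ARAX implementation):

# Sketch: build the qualifier constraint attached to each query edge.
# Plain dicts stand in for the generated TRAPI model classes.
def build_qualifier_constraint(regulation_type: str) -> dict:
    direction = "increased" if regulation_type == "increase" else "decreased"
    return {"qualifier_set": [
        {"qualifier_type_id": "biolink:object_aspect_qualifier",
         "qualifier_value": "activity_or_abundance"},
        {"qualifier_type_id": "biolink:object_direction_qualifier",
         "qualifier_value": direction},
    ]}

qedge = {"knowledge_type": "inferred",
         "predicates": ["biolink:affects"],
         "qualifier_constraints": [build_qualifier_constraint("increase")]}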
84 changes: 66 additions & 18 deletions code/ARAX/ARAXQuery/Infer/scripts/creativeCRG.py
@@ -8,7 +8,6 @@
import requests
# import graph_tool.all as gt
from tqdm import tqdm, trange

pathlist = os.getcwd().split(os.path.sep)
RTXindex = pathlist.index("RTX")
sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery']))
@@ -23,7 +22,9 @@
from RTXConfiguration import RTXConfiguration
RTXConfig = RTXConfiguration()
sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery','']))

sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'BiolinkHelper','']))
from biolink_helper import BiolinkHelper
def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)
def call_plover(curies: List, respect_predicate_symmetry: bool=False):
json = {}
plover_url = RTXConfig.plover_url
@@ -164,6 +165,16 @@ def __init__(self, response: ARAXResponse, data_path: str):

## set up parameters
self.response = response
self.bh = BiolinkHelper()
self.predicate_depth_map = self.bh.get_predicate_depth_map()
self.relevant_node_categories = ['biolink:Drug', 'biolink:PathologicalProcess', 'biolink:GeneOrGeneProduct', 'biolink:ChemicalEntity',
'biolink:SmallMolecule', 'biolink:Gene', 'biolink:BiologicalProcess', 'biolink:Pathway', 'biolink:Disease',
'biolink:Transcript', 'biolink:Cell', 'biolink:GeneFamily', 'biolink:GeneProduct', 'biolink:Exon',
'biolink:DiseaseOrPhenotypicFeature', 'biolink:PhenotypicFeature', 'biolink:MolecularActivity', 'biolink:GeneGroupingMixin',
'biolink:CellularComponent', 'biolink:RNAProduct', 'biolink:Protein', 'biolink:BiologicalProcessOrActivity', 'biolink:PhysiologicalProcess',
'biolink:NoncodingRNAProduct', 'biolink:ProteinFamily', 'biolink:ProteinDomain']
self.relevant_node_categories = self.bh.get_descendants(self.relevant_node_categories)

self.data_path = data_path
self.chemical_type = ['biolink:ChemicalEntity', 'biolink:ChemicalMixture','biolink:SmallMolecule']
self.gene_type = ['biolink:Gene','biolink:Protein']
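The constructor now also whitelists `relevant_node_categories` and expands that list to all Biolink descendants via `BiolinkHelper.get_descendants`, so later membership tests match subclasses rather than only exact category strings. A toy sketch of why that expansion matters (`toy_descendants` is a hypothetical stand-in for the real Biolink hierarchy):

# Sketch: exact-match category filters would miss subclasses, so the
# configured list is expanded to all descendants before membership tests.
toy_descendants = {"biolink:ChemicalEntity": ["biolink:ChemicalEntity",
                                              "biolink:SmallMolecule",
                                              "biolink:Drug"]}

def expand(categories):
    return {d for c in categories for d in toy_descendants.get(c, [c])}

relevant = expand(["biolink:ChemicalEntity"])
print("biolink:SmallMolecule" in relevant)  # True only after expansion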
@@ -212,18 +223,19 @@ def get_tf_neighbors(self):
for edge in edges.keys():
c1 = edges[edge][0]
c2 = edges[edge][1]
depth = self.predicate_depth_map[edges[edge][2]]
if 'subclass' in edges[edge][2]:
continue
if c1 == c2:
continue
if c1 in self.tf_list:
curie = c2
tf = c1
answer_tf_neigbor_data.append({"edge_id": edge, "transcription_factor":tf, "neighbour": curie})
answer_tf_neigbor_data.append({"edge_id": edge, "transcription_factor":tf, "neighbour": curie, "depth": depth})
if c2 in self.tf_list:
curie = c1
tf = c2
query_tf_neighbor_data.append({"edge_id": edge, "transcription_factor":tf, "neighbour": curie})
query_tf_neighbor_data.append({"edge_id": edge, "transcription_factor":tf, "neighbour": curie, "depth": depth})
return query_tf_neighbor_data,answer_tf_neigbor_data, edges

def add_node_ids_to_path(self, paths, tf_edges,chemical_edges, gene_edges):
@@ -507,62 +519,90 @@ def _check_params(query_chemical: Optional[str], query_gene: Optional[str], mode
else:

top_paths = dict()

gene_neighbors = call_plover([preferred_query_gene])
answers = res['chemical_id'].tolist()
self.preferred_curies = self.get_preferred_curies(answers)
valid_chemicals = [item for item in self.preferred_curies.values() if item]
chemical_neighbors = call_plover(valid_chemicals)
query_tf_neighbors, answer_tf_neigbors, tf_edges = self.get_tf_neighbors()

paths = self.get_paths(preferred_query_gene, res['chemical_id'].tolist(), gene_neighbors, chemical_neighbors, query_tf_neighbors, answer_tf_neigbors,self.tf_list, M)
final_paths = self.add_node_ids_to_path(paths, tf_edges, chemical_neighbors, gene_neighbors)
return final_paths


def get_paths(self, query_curie, answer_curies, query_neighbors, answer_neighbors, query_tf_neighbors, answer_tf_neighbors, tf_list,n_paths):
query_neighbors_curies = list(query_neighbors['nodes']['n01'].keys())
query_tf_neighbors_dict = {}
answer_tf_neighbors_dict = {}
query_path = {}
answer_path = {}
combined_path = dict()
one_hop_from_query = set(tf_list).intersection(query_neighbors_curies)
valid_query_tf_list = []
valid_answer_tf_list = {}
for answer in answer_curies:
valid_answer_tf_list[answer] = []

for record in query_tf_neighbors:
query_tf_neighbors_dict[record['neighbour']] = query_tf_neighbors_dict.get(record['neighbour'],[]) + [(record['edge_id'],record['transcription_factor'])]
query_tf_neighbors_dict[record['neighbour']] = query_tf_neighbors_dict.get(record['neighbour'],[]) + [(record['edge_id'],record['transcription_factor'], record['depth'])]
for record in answer_tf_neighbors:
answer_tf_neighbors_dict[record['neighbour']] = answer_tf_neighbors_dict.get(record['neighbour'],[]) + [(record['edge_id'],record['transcription_factor'])]
answer_tf_neighbors_dict[record['neighbour']] = answer_tf_neighbors_dict.get(record['neighbour'],[]) + [(record['edge_id'],record['transcription_factor'], record['depth'])]
# one hop from query
for edge_id, edge in query_neighbors['edges']['e00'].items():
if edge[1] in tf_list and edge[1] not in query_path:
query_path[edge[1]] = [edge_id]

valid_query_tf_list.append(edge[1])
query_path[edge[1]] = [edge_id,self.predicate_depth_map[edge[2]]]
elif edge[1] in tf_list and edge[1] in query_path:
if query_path[edge[1]][-1] < self.predicate_depth_map[edge[2]]:
query_path[edge[1]] = [edge_id,self.predicate_depth_map[edge[2]]]

# two hop from query
for edge_id, edge in query_neighbors['edges']['e00'].items():
if edge[0] != query_curie:
continue
relevant_node = False
neighbor = edge[1]
neighbor_category = query_neighbors['nodes']['n01'][neighbor][1]
if neighbor_category in self.relevant_node_categories:
relevant_node = True

for item in query_tf_neighbors_dict.get(neighbor,[]):
if item[1] not in valid_query_tf_list and relevant_node:
valid_query_tf_list.append(item[1])
if item[1] not in query_path:
query_path[item[1]] = [edge_id,item[0]]
query_path[item[1]] = [edge_id,item[0], item[2]]
elif query_path[item[1]][-1] < min(item[2],self.predicate_depth_map[edge[2]]) and ((item[1] not in valid_query_tf_list) or relevant_node) :
query_path[item[1]] = [edge_id,item[0], min(item[2],self.predicate_depth_map[edge[2]])]



for edge_id, edge in answer_neighbors['edges']['e00'].items():
if edge[1] not in self.preferred_curies.values():
continue
relevant_node = False
answer = edge[1]
neighbor = edge[0]
neighbor_category = answer_neighbors['nodes']['n01'][neighbor][1]
if neighbor_category in self.relevant_node_categories:
relevant_node = True
# one hop from answer
if answer not in answer_path:
answer_path[answer] = dict()
if neighbor in tf_list:
answer_path[answer][neighbor] = [edge_id]
if neighbor in tf_list and neighbor not in answer_path[answer]:
valid_answer_tf_list[answer].append(neighbor)
answer_path[answer][neighbor] = [edge_id, self.predicate_depth_map[edge[2]]]
elif neighbor in tf_list and neighbor in answer_path[answer]:
if answer_path[answer][neighbor][-1] < self.predicate_depth_map[edge[2]]:
answer_path[answer][neighbor] = [edge_id, self.predicate_depth_map[edge[2]]]

# two hop from answer
for item in answer_tf_neighbors_dict.get(neighbor,[]):
neighbor_category = answer_neighbors['nodes']['n01'][neighbor][1]
if relevant_node and item[1] not in valid_answer_tf_list[answer]:
valid_answer_tf_list[answer].append(item[1])
if item[1] not in answer_path[answer]:
answer_path[answer][item[1]] = [item[0],edge_id]

answer_path[answer][item[1]] = [item[0],edge_id, item[2]]
elif answer_path[answer][item[1]][-1] < item[2] and ((item[1] not in valid_answer_tf_list[answer]) or relevant_node):
answer_path[answer][item[1]] = [item[0], edge_id, min(item[2],self.predicate_depth_map[edge[2]])]
# joining paths
for answer in answer_curies:
combined_path[(query_curie,answer)] = list()
Expand All @@ -573,13 +613,21 @@ def get_paths(self, query_curie, answer_curies, query_neighbors, answer_neighbor
continue

path_counter = 0
for tf in tf_list:
relevant_tf = list(set(valid_query_tf_list).intersection(valid_answer_tf_list[answer]))
irrelevant_tf = [tf for tf in tf_list if tf not in relevant_tf]
for tf in relevant_tf:
if path_counter > n_paths:
break
if tf in query_path and tf in answer_path[key]:
combined_path[(query_curie,answer)].append(query_path[tf] + answer_path[key][tf])
combined_path[(query_curie,answer)].append(query_path[tf][:-1] + answer_path[key][tf][:-1])
path_counter += 1

for tf in irrelevant_tf:
if path_counter > n_paths:
break
if tf in query_path and tf in answer_path[key]:
combined_path[(query_curie,answer)].append(query_path[tf][:-1] + answer_path[key][tf][:-1])
path_counter += 1
return combined_path
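
Throughout `get_paths`, each stored one- or two-hop path now carries the Biolink depth of its predicate as the last list element; an existing path for a transcription factor is replaced only when a deeper (more specific) predicate is found, the depth is stripped with `[:-1]` before query-side and answer-side paths are joined, and TFs reached through relevant node categories are consumed before the rest when filling the `n_paths` budget. A toy sketch of the keep-the-deepest update rule, with made-up edge IDs and depths:

# Sketch: per transcription factor, keep the edge whose predicate sits
# deepest in the Biolink hierarchy (i.e. is most specific).
paths = {}  # tf_id -> [edge_id, depth]
candidates = [("TF1", "e1", 2), ("TF1", "e2", 5), ("TF2", "e3", 3)]
for tf, edge_id, depth in candidates:
    if tf not in paths or paths[tf][-1] < depth:
        paths[tf] = [edge_id, depth]
joined = {tf: path[:-1] for tf, path in paths.items()}  # strip depth before joining
print(joined)  # {'TF1': ['e2'], 'TF2': ['e3']}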


48 changes: 40 additions & 8 deletions code/ARAX/BiolinkHelper/biolink_helper.py
@@ -125,7 +125,18 @@ def get_canonical_predicates(self, predicates: Union[str, List[str], Set[str]])
for predicate in valid_predicates}
canonical_predicates.update(invalid_predicates) # Go ahead and include those we don't have canonical info for
return list(canonical_predicates)


def get_predicate_depth_map(self) -> Dict[str, int]:
response = self._download_biolink_model()
if response.status_code == 200:
biolink_model = yaml.safe_load(response.text)
predicate_dag = self._build_predicate_dag(biolink_model)

else:
raise RuntimeError(f"ERROR: Request to get Biolink {self.biolink_version} YAML file returned "
f"{response.status_code} response. Cannot load BiolinkHelper.")
return self._get_depths_from_root(predicate_dag)

def is_symmetric(self, predicate: str) -> Optional[bool]:
if predicate in self.biolink_lookup_map["predicates"]:
return self.biolink_lookup_map["predicates"][predicate]["is_symmetric"]
@@ -198,25 +209,30 @@ def _load_biolink_lookup_map(self, is_test: bool = False):
with open(self.biolink_lookup_map_path, "rb") as biolink_map_file:
biolink_lookup_map = pickle.load(biolink_map_file)
return biolink_lookup_map


def _download_biolink_model(self):
response = requests.get(f"https://raw.githubusercontent.com/biolink/biolink-model/{self.biolink_version}/biolink-model.yaml",
timeout=10)
if response.status_code != 200: # Sometimes Biolink's tags start with 'v', so try that
response = requests.get(f"https://raw.githubusercontent.com/biolink/biolink-model/v{self.biolink_version}/biolink-model.yaml",
timeout=10)
return response

def _create_biolink_lookup_map(self) -> Dict[str, Dict[str, Dict[str, Union[str, List[str], bool]]]]:
timestamp = str(datetime.datetime.now().isoformat())
eprint(f"{timestamp}: INFO: Building local Biolink {self.biolink_version} ancestor/descendant lookup map "
f"because one doesn't yet exist")
biolink_lookup_map = {"predicates": dict(), "categories": dict(),
"aspects": dict(), "directions": dict()}
# Grab the relevant Biolink yaml file
response = requests.get(f"https://raw.githubusercontent.com/biolink/biolink-model/{self.biolink_version}/biolink-model.yaml",
timeout=10)
if response.status_code != 200: # Sometimes Biolink's tags start with 'v', so try that
response = requests.get(f"https://raw.githubusercontent.com/biolink/biolink-model/v{self.biolink_version}/biolink-model.yaml",
timeout=10)
response = self._download_biolink_model()

if response.status_code == 200:
biolink_model = yaml.safe_load(response.text)

# -------------------------------- PREDICATES --------------------------------- #
predicate_dag = self._build_predicate_dag(biolink_model)
# Build our map of predicate ancestors/descendants for easy lookup, first WITH mixins
for node_id in list(predicate_dag.nodes):
node_info = predicate_dag.nodes[node_id]
@@ -382,7 +398,23 @@ def _build_direction_dag(self, biolink_model: dict) -> nx.DiGraph:
direction_dag.add_edge(parent_name_trapi, direction_name_trapi)

return direction_dag


def _get_depths_from_root(self, dag: nx.DiGraph) -> Dict[str, int]:
node_depths = {}
for node in nx.topological_sort(dag):
# Get all predecessors of the current node
predecessors = list(dag.predecessors(node))

# If the node has predecessors, calculate its depth as max(depth of predecessors) + 1
if predecessors:
node_depths[node] = max(node_depths[pred] for pred in predecessors) + 1
else:
node_depths[node] = 0 # Handle nodes that have no predecessors (if any)

return node_depths

@staticmethod
def _get_ancestors_nx(nx_graph: nx.DiGraph, node_id: str) -> List[str]:
return list(nx.ancestors(nx_graph, node_id).union({node_id}))
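`_get_depths_from_root` computes, for every node in the predicate DAG, the length of the longest path from a root (a node with no predecessors) in a single pass over a topological ordering; `get_predicate_depth_map` exposes those depths so callers like creativeCRG.py can treat deeper predicates as more specific. A standalone check with a small hypothetical networkx DAG:

import networkx as nx

# Sketch: depth = longest distance from a root, computed in topological order.
dag = nx.DiGraph()
dag.add_edges_from([("related_to", "affects"),
                    ("affects", "regulates"),
                    ("related_to", "interacts_with")])

depths = {}
for node in nx.topological_sort(dag):
    preds = list(dag.predecessors(node))
    depths[node] = max(depths[p] for p in preds) + 1 if preds else 0

print(depths)  # e.g. {'related_to': 0, 'affects': 1, 'interacts_with': 1, 'regulates': 2}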
