Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
mohsenht committed Sep 23, 2024
2 parents cea2ff5 + 00dc4b7 commit 5fcd1a2
Show file tree
Hide file tree
Showing 25 changed files with 210 additions and 120 deletions.
39 changes: 5 additions & 34 deletions code/ARAX/ARAXQuery/ARAX_expander.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def __init__(self):
"aggregator_knowledge_source": {"==": "*"}}
self.supported_qedge_qualifier_constraints = {"biolink:qualified_predicate", "biolink:object_direction_qualifier",
"biolink:object_aspect_qualifier"}
self.higher_level_treats_predicates = {"biolink:treats_or_applied_or_studied_to_treat",
"biolink:applied_to_treat",
"biolink:studied_to_treat"}
self.treats_like_predicates = set(self.bh.get_descendants("biolink:treats_or_applied_or_studied_to_treat")).difference({"biolink:treats"})

def describe_me(self):
"""
Expand Down Expand Up @@ -353,7 +351,8 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
if inferred_qedge_keys and len(query_graph.edges) == 1:
for edge in query_sub_graph.edges.keys():
query_sub_graph.edges[edge].knowledge_type = 'lookup'
# Expand the query graph edge-by-edge

# Expand the query graph edge-by-edge (in regular 'lookup' fashion)
for qedge_key in ordered_qedge_keys_to_expand:
log.debug(f"Expanding qedge {qedge_key}")
response.update_query_plan(qedge_key, 'edge_properties', 'status', 'Expanding')
Expand Down Expand Up @@ -511,7 +510,7 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
# Remove KG2 SemMedDB treats_or_applied-type edges if this is an inferred treats query
if alter_kg2_treats_edges:
edge_keys_to_remove = {edge_key for edge_key, edge in overarching_kg.edges_by_qg_id[qedge_key].items()
if edge.predicate in self.higher_level_treats_predicates and
if edge.predicate in self.treats_like_predicates and
any(source.resource_id == "infores:rtx-kg2" for source in edge.sources) and
any(source.resource_id == "infores:semmeddb" for source in edge.sources)}
log.debug(f"Removing {len(edge_keys_to_remove)} KG2 semmeddb treats_or_applied-type edges "
Expand Down Expand Up @@ -583,9 +582,6 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
decorator = ARAXDecorator()
decorator.decorate_nodes(response)
decorator.decorate_edges(response, kind="RTX-KG2")

# Override node types to only include descendants of what was asked for in the QG (where applicable) #1360
self._override_node_categories(message.knowledge_graph, message.query_graph, log)
elif mode == "RTXKG2":
decorator = ARAXDecorator()
decorator.decorate_edges(response, kind="SEMMEDDB")
Expand All @@ -595,7 +591,7 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
num_edges_altered = 0
for edge in message.knowledge_graph.edges.values():
is_kg2_edge = any(source.resource_id == "infores:rtx-kg2" for source in edge.sources)
if is_kg2_edge and edge.predicate in self.higher_level_treats_predicates:
if is_kg2_edge and edge.predicate in self.treats_like_predicates:
# Record the original KG2 predicate in an attribute
edge.attributes.append(Attribute(attribute_type_id="biolink:original_predicate",
value=edge.predicate,
Expand Down Expand Up @@ -1400,31 +1396,6 @@ def _load_fda_approved_drug_ids() -> Set[str]:
fda_approved_drug_ids = pickle.load(fda_pickle)
return fda_approved_drug_ids

def _override_node_categories(self, kg: KnowledgeGraph, qg: QueryGraph, log: ARAXResponse):
# Clean up what we list as the TRAPI node.categories; list descendants of what was asked for in the QG
log.debug(f"Overriding node categories to better align with what's in the QG")
qnode_descendant_categories_map = {qnode_key: set(self.bh.get_descendants(qnode.categories))
for qnode_key, qnode in qg.nodes.items() if qnode.categories}
for node_key, node in kg.nodes.items():
final_categories = set()
for qnode_key in node.qnode_keys:
# If qnode has categories specified, use node's all_categories that are descendants of qnode categories
if qnode_key in qnode_descendant_categories_map:
all_categories_attributes = [attribute for attribute in eu.convert_to_list(node.attributes)
if attribute.attribute_type_id == "biolink:category"]
node_categories = all_categories_attributes[0].value if all_categories_attributes else node.categories
relevant_categories = set(node_categories).intersection(qnode_descendant_categories_map[qnode_key])
# Otherwise just use what's already in the node's categories (for KG2 this is the 'preferred' category)
else:
relevant_categories = set(node.categories)
final_categories = final_categories.union(relevant_categories)
if final_categories:
node.categories = list(final_categories)
else:
# Leave categories as they are but issue a warning
log.warning(f"None of the categories KPs gave node {node_key} ({node.categories}) are descendants of "
f"those asked for in the QG (for qnode {node.qnode_keys})")

@staticmethod
def _map_back_to_input_curies(kg: KnowledgeGraph, qg: QueryGraph, log: ARAXResponse):
"""
Expand Down
24 changes: 23 additions & 1 deletion code/ARAX/ARAXQuery/ARAX_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)
from openapi_server.models.edge import Edge
from openapi_server.models.attribute import Attribute as EdgeAttribute
from openapi_server.models.node import Node
from openapi_server.models.qualifier import Qualifier
from openapi_server.models.qualifier_constraint import QualifierConstraint as QConstraint


sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'NodeSynonymizer']))
from node_synonymizer import NodeSynonymizer
Expand All @@ -36,6 +39,7 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)
# from creativeDTD import creativeDTD
from creativeCRG import creativeCRG
from ExplianableDTD_db import ExplainableDTD

# from ExplianableCRG import ExplianableCRG

# sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code']))
Expand Down Expand Up @@ -615,7 +619,7 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False):
f"The `n_result_curies` value must be a positive integer. The provided value was {self.parameters['n_result_curies']}.",
error_code="ValueError")
else:
self.parameters['n_result_curies'] = 10
self.parameters['n_result_curies'] = 30

if 'n_paths' in self.parameters:
if isinstance(self.parameters['n_paths'], str):
Expand Down Expand Up @@ -678,9 +682,26 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False):
if not preferred_subject_curie and not preferred_object_curie:
self.response.error(f"Both parameters 'subject_curie' and 'object_curie' are not provided. Please provide the curie for either one of them")
return self.response
qedges = message.query_graph.edges


else:
self.response.error(f"The 'query_graph' is detected. One of 'subject_qnode_id' or 'object_qnode_id' should be specified.")

if self.parameters['regulation_type'] == 'increase':
edge_qualifier_direction = 'increased'
else:
edge_qualifier_direction = 'decreased'
edge_qualifier_list = [
Qualifier(qualifier_type_id='biolink:object_aspect_qualifier', qualifier_value='activity_or_abundance'),
Qualifier(qualifier_type_id='biolink:object_direction_qualifier', qualifier_value=edge_qualifier_direction)]

for qedge in qedges:
edge = message.query_graph.edges[qedge]
edge.knowledge_type = "inferred"
edge.predicates = ["biolink:affects"]
edge.qualifier_constraints = [QConstraint(qualifier_set=edge_qualifier_list)]


else:
if 'subject_curie' in parameters or 'object_curie' in parameters:
Expand Down Expand Up @@ -763,6 +784,7 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False):

iu = InferUtilities()
qedge_id = self.parameters.get('qedge_id')

self.response, self.kedge_global_iter, self.qedge_global_iter, self.qnode_global_iter, self.option_global_iter = iu.genrete_regulate_subgraphs(self.response, None, normalized_object_curie, top_predictions, top_paths, qedge_id, self.parameters['regulation_type'], self.kedge_global_iter, self.qedge_global_iter, self.qnode_global_iter, self.option_global_iter)

return self.response
Expand Down
3 changes: 2 additions & 1 deletion code/ARAX/ARAXQuery/ARAX_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def track_query_finish(self):
if hasattr(self.response, 'job_id'):
query_tracker.update_tracker_entry(self.response.job_id, attributes)
else:
eprint("*******ERROR: self.response has no job_id attr! E275")
# Sometimes we finish without a job_id having been created, and that's okay
pass



Expand Down
19 changes: 10 additions & 9 deletions code/ARAX/ARAXQuery/ARAX_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from openapi_server.models.edge import Edge
from openapi_server.models.attribute import Attribute

edge_confidence_manual_agent = 0.999
edge_confidence_manual_agent = 0.99

def _get_nx_edges_by_attr(G: Union[nx.MultiDiGraph, nx.MultiGraph], key: str, val: str) -> Set[tuple]:
res_set = set()
Expand All @@ -33,8 +33,8 @@ def _get_nx_edges_by_attr(G: Union[nx.MultiDiGraph, nx.MultiGraph], key: str, va

def _get_query_graph_networkx_from_query_graph(query_graph: QueryGraph) -> nx.MultiDiGraph:
query_graph_nx = nx.MultiDiGraph()
query_graph_nx.add_nodes_from([key for key,node in query_graph.nodes.items()])
edge_list = [[edge.subject, edge.object, key, {'weight': 0.0}] for key,edge in query_graph.edges.items()]
query_graph_nx.add_nodes_from([key for key, node in query_graph.nodes.items() if 'creative_DTD_qnode' not in key and 'creative_CRG_qnode' not in key])
edge_list = [[edge.subject, edge.object, key, {'weight': 0.0}] for key,edge in query_graph.edges.items() if 'creative_DTD_qedge' not in key and 'creative_CRG_qedge' not in key]
query_graph_nx.add_edges_from(edge_list)
return query_graph_nx

Expand Down Expand Up @@ -124,8 +124,9 @@ def _get_weighted_graph_networkx_from_result_graph(kg_edge_id_to_edge: Dict[str,
qg_edge_key_to_edge_tuple = {edge_tuple[2]: edge_tuple for edge_tuple in qg_edge_tuples}
for analysis in result.analyses: # For now we only ever have one Analysis per Result
for qedge_key, edge_binding_list in analysis.edge_bindings.items():
qedge_tuple = qg_edge_key_to_edge_tuple[qedge_key]
res_graph[qedge_tuple[0]][qedge_tuple[1]][qedge_tuple[2]]['weight'] = _calculate_final_result_score(kg_edge_id_to_edge, edge_binding_list)
if 'creative_DTD_qedge' not in qedge_key and 'creative_CRG_qedge' not in qedge_key:
qedge_tuple = qg_edge_key_to_edge_tuple[qedge_key]
res_graph[qedge_tuple[0]][qedge_tuple[1]][qedge_tuple[2]]['weight'] = _calculate_final_result_score(kg_edge_id_to_edge, edge_binding_list)

return res_graph

Expand Down Expand Up @@ -187,7 +188,7 @@ def _score_networkx_graphs_by_max_flow(result_graphs_nx: List[Union[nx.MultiDiGr
capacity="weight"))
max_flow_value = 0.0
if len(max_flow_values_for_node_pairs) > 0:
max_flow_value = sum(max_flow_values_for_node_pairs)/float(len(max_flow_values_for_node_pairs))
max_flow_value = _calculate_final_individual_edge_confidence(0, max_flow_values_for_node_pairs)
else:
max_flow_value = 1.0
max_flow_values.append(max_flow_value)
Expand All @@ -209,7 +210,7 @@ def _score_networkx_graphs_by_longest_path(result_graphs_nx: List[Union[nx.Multi
adj_matrix_power = np.linalg.matrix_power(adj_matrix, max_path_len)/math.factorial(max_path_len)
score_list = [adj_matrix_power[map_node_name_to_index[node_i],
map_node_name_to_index[node_j]] for node_i, node_j in pairs_with_max_path_len]
result_score = np.mean(score_list)
result_score = _calculate_final_individual_edge_confidence(0, score_list)
result_scores.append(result_score)
return result_scores

Expand Down Expand Up @@ -365,7 +366,7 @@ def edge_attribute_score_combiner(self, edge_key, edge):
elif 'infores' in edge_key.split('--')[-1]: # default score for other data sources
base = edge_default_base
else: # virtual edges or inferred edges
base = 0 # no base score for these edges. Its core is based on
base = 0 # no base score for these edges. Its score is based on its attribute scores.

if edge.attributes is not None:
for edge_attribute in edge.attributes:
Expand Down Expand Up @@ -454,7 +455,7 @@ def edge_attribute_publication_normalizer(self, attribute_type_id: str, edge_att
pub_value = np.log(n_publications)
max_value = 1.0
curve_steepness = 3.16993
logistic_midpoint = 1.38629
logistic_midpoint = 1.60943 # log(5) = 1.60943 meaning having 5 publications is a mid point
normalized_value = max_value / float(1 + np.exp(-curve_steepness * (pub_value - logistic_midpoint)))
return normalized_value

Expand Down
5 changes: 1 addition & 4 deletions code/ARAX/ARAXQuery/Expand/kp_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ def _load_cached_kp_info(self) -> tuple:
self.log.error(f"Failed to load KP info caches due to {e}", error_code="LoadKPCachesFailed")
return None, None, None, None

# Record None URLs for our local KPs
allowed_kp_urls = smart_api_info["allowed_kp_urls"]

return (meta_map, allowed_kp_urls, smart_api_info["kps_excluded_by_version"],
return (meta_map, smart_api_info["allowed_kp_urls"], smart_api_info["kps_excluded_by_version"],
smart_api_info["kps_excluded_by_maturity"])

def get_kps_for_single_hop_qg(self, qg: QueryGraph) -> Optional[Set[str]]:
Expand Down
25 changes: 19 additions & 6 deletions code/ARAX/ARAXQuery/Expand/trapi_querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,8 @@ async def answer_one_hop_query_async(self, query_graph: QueryGraph,
# Patch to address lack of answers from KG2 for treats queries after treats refactor #2328
if alter_kg2_treats_edges and self.kp_infores_curie == "infores:rtx-kg2":
for qedge in qg_copy.edges.values(): # Note there's only ever one qedge per QG here
qedge.predicates = list(set(qedge.predicates).union({"biolink:treats_or_applied_or_studied_to_treat",
"biolink:applied_to_treat",
"biolink:studied_to_treat"}))
log.info(f"For querying infores:rtx-kg2, edited {qedge_key} to use higher treats-type predicates: "
qedge.predicates = list(set(qedge.predicates).union({"biolink:treats_or_applied_or_studied_to_treat"}))
log.info(f"For querying infores:rtx-kg2, edited {qedge_key} to use higher treats-type predicate: "
f"{qedge.predicates}")

# Answer the query using the KP and load its answers into our object model
Expand Down Expand Up @@ -368,9 +366,20 @@ def _load_kp_json_response(self, json_response: dict, qg: QueryGraph) -> QGOrgan
# Build a map that indicates which qnodes/qedges a given node/edge fulfills
kg_to_qg_mappings, query_curie_mappings = self._get_kg_to_qg_mappings_from_results(kp_message.results, qg)

# Populate our final KG with the returned nodes and edges
# Populate our final KG with the returned edges
returned_edge_keys_missing_qg_bindings = set()
nodes_dict = kp_message.knowledge_graph.nodes
for returned_edge_key, returned_edge in kp_message.knowledge_graph.edges.items():
# Catch invalid subject/object
if not returned_edge.subject or not returned_edge.object:
self.log.warning(f"{self.kp_infores_curie}: Edge has empty subject/object, skipping. "
f"subject: '{returned_edge.subject}', object: '{returned_edge.object}'")
continue
if returned_edge.subject not in nodes_dict or returned_edge.object not in nodes_dict:
self.log.warning(f"{self.kp_infores_curie}: Edge is an orphan, skipping. "
f"subject: '{returned_edge.subject}', object: '{returned_edge.object}'")
continue

arax_edge_key = self._get_arax_edge_key(returned_edge) # Convert to an ID that's unique for us

# Put in a placeholder for missing required attribute fields to try to keep our answer TRAPI-compliant
Expand Down Expand Up @@ -399,9 +408,13 @@ def _load_kp_json_response(self, json_response: dict, qg: QueryGraph) -> QGOrgan
self.log.warning(f"{self.kp_infores_curie}: {len(returned_edge_keys_missing_qg_bindings)} edges in the KP's answer "
f"KG have no bindings to the QG: {returned_edge_keys_missing_qg_bindings}")

# Populate our final KG with the returned nodes
returned_node_keys_missing_qg_bindings = set()
for returned_node_key, returned_node in kp_message.knowledge_graph.nodes.items():
if returned_node_key not in kg_to_qg_mappings['nodes']:
if not returned_node_key:
self.log.warning(f"{self.kp_infores_curie}: Node has empty ID, skipping. Node key is: "
f"'{returned_node_key}'")
elif returned_node_key not in kg_to_qg_mappings['nodes']:
returned_node_keys_missing_qg_bindings.add(returned_node_key)
else:
for qnode_key in kg_to_qg_mappings['nodes'][returned_node_key]:
Expand Down
Loading

0 comments on commit 5fcd1a2

Please sign in to comment.