Skip to content

Commit

Permalink
udpate code for the new xDTD rebuild #1967
Browse files Browse the repository at this point in the history
  • Loading branch information
chunyuma committed Jul 7, 2023
1 parent a372b67 commit b89b569
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 37 deletions.
26 changes: 13 additions & 13 deletions code/ARAX/ARAXQuery/ARAX_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def __init__(self):
self.xdtd_n_drugs_info = {
"is_required": False,
"examples": [5,15,25],
"default": 25,
"default": 50,
"type": "integer",
"description": "The number of drug nodes to return. If not provided defaults to 25. Considering the response speed, the maximum number of drugs returned is only allowed to be 25."
"description": "The number of drug nodes to return. If not provided defaults to 50. Considering the response speed, the maximum number of drugs returned is only allowed to be 50."
}
self.xdtd_n_paths_info = {
"is_required": False,
Expand Down Expand Up @@ -367,7 +367,7 @@ def __drug_treatment_graph_expansion(self, describe=False):
allowable_parameters = {'action': {'drug_treatment_graph_expansion'},
'node_curie': {'The node to predict drug treatments for.'},
'qedge_id': {'The edge to place the predicted mechanism of action on. If none is provided, the query graph must be empty and a new one will be inserted.'},
'n_drugs': {'The number of drugs to return. Defaults to 25. Maxiumum is only allowable to be 25.'},
'n_drugs': {'The number of drugs to return. Defaults to 50. Maxiumum is only allowable to be 50.'},
'n_paths': {'The number of paths connecting each drug to return. Defaults to 25. Maxiumum is only allowable to be 25.'}
}

Expand Down Expand Up @@ -400,11 +400,11 @@ def __drug_treatment_graph_expansion(self, describe=False):
self.response.error(f"The `n_drugs` value must be a positive integer. The provided value was {self.parameters['n_drugs']}.", error_code="ValueError")
if self.parameters['n_drugs'] <= 0:
self.response.error(f"The `n_drugs` value should be larger than 0. The provided value was {self.parameters['n_drugs']}.", error_code="ValueError")
if self.parameters['n_drugs'] > 25:
self.response.warning(f"The `n_drugs` value was set to {self.parameters['n_drugs']}, but the maximum allowable value is 25. Setting `n_drugs` to 25.")
self.parameters['n_drugs'] = 25
if self.parameters['n_drugs'] > 50:
self.response.warning(f"The `n_drugs` value was set to {self.parameters['n_drugs']}, but the maximum allowable value is 50. Setting `n_drugs` to 50.")
self.parameters['n_drugs'] = 50
else:
self.parameters['n_drugs'] = 25
self.parameters['n_drugs'] = 50

if 'n_paths' in self.parameters:
try:
Expand Down Expand Up @@ -448,7 +448,7 @@ def __drug_treatment_graph_expansion(self, describe=False):
continue

if len(top_drugs) == 0:
self.response.warning(f"Could not get predicted drugs for disease {preferred_curie}. Likely the model was not trained with this disease.")
self.response.warning(f"Could not get predicted drugs for disease {preferred_curie}. Likely the model was not trained with this disease. Or No predicted drugs for this disease with score >= 0.5.")
continue
if len(top_paths) == 0:
self.response.warning(f"Could not get any predicted paths for disease {preferred_curie}. Likely the model considers there is no reasonable path for this disease.")
Expand All @@ -463,7 +463,7 @@ def __drug_treatment_graph_expansion(self, describe=False):
top_drugs = top_drugs.iloc[:self.parameters['n_drugs'],:].reset_index(drop=True)
top_paths = {(row[0], row[2]):top_paths[(row[0], row[2])][:self.parameters['n_paths']] for row in top_drugs.to_numpy() if (row[0], row[2]) in top_paths}

# TRAPI-ifies the results of the model
# # TRAPI-ifies the results of the model
qedge_id = self.parameters.get('qedge_id')
self.response, self.kedge_global_iter, self.qedge_global_iter, self.qnode_global_iter, self.option_global_iter = iu.genrete_treat_subgraphs(self.response, top_drugs, top_paths, qedge_id, self.kedge_global_iter, self.qedge_global_iter, self.qnode_global_iter, self.option_global_iter)

Expand Down Expand Up @@ -513,8 +513,8 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False):
'threshold': {"Threshold to filter the prediction probability. If not provided defaults to 0.5."},
'kp': {"KP to use in path extraction. If not provided defaults to 'infores:rtx-kg2'."},
'path_len': {"The length of paths for prediction. If not provided defaults to 2."},
'n_result_curies': {'The number of top predicted result nodes to return. Defaults to 20.'},
'n_paths': {'The number of paths connecting to each returned node. Defaults to 20.'}
'n_result_curies': {'The number of top predicted result nodes to return. Defaults to 10.'},
'n_paths': {'The number of paths connecting to each returned node. Defaults to 10.'}
}

# A little function to describe what this thing does
Expand Down Expand Up @@ -580,7 +580,7 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False):
f"The `n_result_curies` value must be a positive integer. The provided value was {self.parameters['n_result_curies']}.",
error_code="ValueError")
else:
self.parameters['n_result_curies'] = 25
self.parameters['n_result_curies'] = 10

if 'n_paths' in self.parameters:
if isinstance(self.parameters['n_paths'], str):
Expand All @@ -593,7 +593,7 @@ def __chemical_gene_regulation_graph_expansion(self, describe=False):
f"The `n_paths` value must be a positive integer. The provided value was {self.parameters['n_paths']}.",
error_code="ValueError")
else:
self.parameters['n_paths'] = 25
self.parameters['n_paths'] = 10

if self.response.status != 'OK':
return self.response
Expand Down
4 changes: 2 additions & 2 deletions code/ARAX/ARAXQuery/Infer/scripts/ExplianableDTD_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,11 @@ def main():
return

print("==== Testing for search for top drugs by disease id ====", flush=True)
print(EDTDdb.get_top_drugs_for_disease('MONDO:0008753'))
print(EDTDdb.get_top_drugs_for_disease('MONDO:0005148'))
# print(EDTDdb.get_top_drugs_for_disease(["MONDO:0008753","MONDO:0005148","MONDO:0005155"]))

print("==== Testing for search for top paths by disease id ====", flush=True)
print(EDTDdb.get_top_paths_for_disease('MONDO:0008753'))
print(EDTDdb.get_top_paths_for_disease('MONDO:0005148'))
# print(EDTDdb.get_top_paths_for_disease(["MONDO:0008753","MONDO:0005148","MONDO:0005155"]))

####################################################################################################
Expand Down
11 changes: 6 additions & 5 deletions code/ARAX/ARAXQuery/Infer/scripts/infer_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,13 @@ def genrete_treat_subgraphs(self, response: ARAXResponse, top_drugs: pd.DataFram
# The x[0] is here since each element consists of the string path and a score we are currently ignoring the score
split_paths = [x[0].split("->") for x in paths]
for path in split_paths:
drug_name = path[0]
drug_curie = path[0]
n_elements = len(path)

edges_info = []
flag = False
for i in range(0,n_elements-2,2):
edge_info = xdtdmapping.get_edge_info(triple_name=(path[i],path[i+1],path[i+2]))
edge_info = xdtdmapping.get_edge_info(triple_id=(path[i],path[i+1],path[i+2]))
if len(edge_info) == 0:
flag = True
else:
Expand Down Expand Up @@ -337,7 +337,9 @@ def genrete_treat_subgraphs(self, response: ARAXResponse, top_drugs: pd.DataFram
path_added = True
if path_added:
treat_score = top_drugs.loc[top_drugs['drug_id'] == drug]["tp_score"].iloc[0]
essence_scores[drug_name] = treat_score
drug_node_info = xdtdmapping.get_node_info(node_id=drug_curie)
disease_node_info = xdtdmapping.get_node_info(node_id=disease_curie)
essence_scores[drug_node_info.name] = treat_score
edge_attribute_list = [
Attribute(original_attribute_name="defined_datetime", value=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), attribute_type_id="metatype:Datetime"),
Attribute(original_attribute_name=None, value=True, attribute_type_id="EDAM-DATA:1772", attribute_source="infores:arax", value_type_id="metatype:Boolean", value_url=None, description="This edge is a container for a computed value between two nodes that is not directly attachable to other edges."),
Expand All @@ -350,8 +352,7 @@ def genrete_treat_subgraphs(self, response: ARAXResponse, top_drugs: pd.DataFram
edge_predicate = "biolink:treats"
if hasattr(message.query_graph.edges[qedge_id], 'predicates') and message.query_graph.edges[qedge_id].predicates:
edge_predicate = message.query_graph.edges[qedge_id].predicates[0] # FIXME: better way to handle multiple predicates?
drug_node_info = xdtdmapping.get_node_info(node_name=drug_name)
disease_node_info = xdtdmapping.get_node_info(node_name=disease_name)

fixed_edge = Edge(predicate=edge_predicate, subject=drug_node_info.id, object=disease_node_info.id,
attributes=edge_attribute_list, sources=retrieval_source)
#fixed_edge.qedge_keys = ["treats"]
Expand Down
12 changes: 6 additions & 6 deletions code/ARAX/test/test_ARAX_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def test_xdtd_expand():
query = {
"nodes": {
"disease": {
"ids": ["MONDO:0004975"]
"ids": ["MONDO:0017979"]
},
"chemical": {
"categories": ["biolink:ChemicalEntity"]
Expand Down Expand Up @@ -1125,7 +1125,7 @@ def test_xdtd_different_categories():
query = {
"nodes": {
"disease": {
"ids": ["MONDO:0004975"]
"ids": ["MONDO:0005615"]
},
"chemical": {
"categories": ["biolink:Drug"]
Expand All @@ -1144,7 +1144,7 @@ def test_xdtd_different_categories():
query = {
"nodes": {
"disease": {
"ids": ["MONDO:0004975"],
"ids": ["MONDO:0005615"],
"categories": ["biolink:Disease"]
},
"chemical": {
Expand All @@ -1164,7 +1164,7 @@ def test_xdtd_different_categories():
query = {
"nodes": {
"disease": {
"ids": ["MONDO:0004975"],
"ids": ["MONDO:0005615"],
"categories": ["biolink:DiseaseOrPhenotypicFeature"]
},
"chemical": {
Expand All @@ -1187,7 +1187,7 @@ def test_xdtd_multiple_categories():
query = {
"nodes": {
"disease": {
"ids": ["MONDO:0004975"]
"ids": ["UMLS:C5419466"]
},
"chemical": {
"categories": ["biolink:Drug", "biolink:ChemicalMixture"]
Expand All @@ -1209,7 +1209,7 @@ def test_xdtd_different_predicates():
query = {
"nodes": {
"disease": {
"ids": ["UMLS:C4023597"]
"ids": ["UMLS:C5419466"]
},
"chemical": {
"categories": ["biolink:Drug", "biolink:ChemicalMixture"]
Expand Down
22 changes: 11 additions & 11 deletions code/ARAX/test/test_ARAX_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ def _virtual_tester(message: Message, edge_predicate: str, relation: str, attrib
assert len(values) >= num_different_values


def test_xdtd_infer_alkaptonuria_1():
def test_xdtd_infer_diabetes_1():
query = {"operations": {"actions": [
"create_message",
"infer(action=drug_treatment_graph_expansion,node_curie=MONDO:0008753)",
"infer(action=drug_treatment_graph_expansion,node_curie=MONDO:0005148)",
"return(message=true, store=true)"
]}}
[response, message] = _do_arax_query(query)
Expand All @@ -91,10 +91,10 @@ def test_xdtd_infer_alkaptonuria_1():
assert len(message.query_graph.edges) == 1
assert len(message.results) > 0

def test_xdtd_infer_alkaptonuria_2():
def test_xdtd_infer_diabetes_2():
query = {"operations": {"actions": [
"create_message",
"infer(action=drug_treatment_graph_expansion,node_curie=MONDO:0008753,n_drugs=2,n_paths=15)",
"infer(action=drug_treatment_graph_expansion,node_curie=MONDO:0005148,n_drugs=2,n_paths=15)",
"return(message=true, store=true)"
]}}
[response, message] = _do_arax_query(query)
Expand All @@ -108,7 +108,7 @@ def test_xdtd_with_qg():
"message": {"query_graph": {
"nodes": {
"disease": {
"ids": ["MONDO:0004975"]
"ids": ["MONDO:0003912"]
},
"chemical": {
"categories": ["biolink:ChemicalEntity"]
Expand All @@ -125,7 +125,7 @@ def test_xdtd_with_qg():
}
},
"operations": {"actions": [
"infer(action=drug_treatment_graph_expansion,node_curie=MONDO:0008753,qedge_id=t_edge)",
"infer(action=drug_treatment_graph_expansion,node_curie=test_xdtd_with_qg,qedge_id=t_edge)",
"return(message=true, store=true)"
]}
}
Expand All @@ -141,7 +141,7 @@ def test_xdtd_with_qg2():
"message": {"query_graph": {
"nodes": {
"disease": {
"ids": ["MONDO:0004975"]
"ids": ["MONDO:0003912"]
},
"chemical": {
"categories": ["biolink:ChemicalEntity"]
Expand All @@ -158,7 +158,7 @@ def test_xdtd_with_qg2():
}
},
"operations": {"actions": [
"infer(action=drug_treatment_graph_expansion,node_curie=MONDO:0004975,qedge_id=t_edge)",
"infer(action=drug_treatment_graph_expansion,node_curie=MONDO:0003912,qedge_id=t_edge)",
"return(message=true, store=true)"
]}
}
Expand All @@ -174,7 +174,7 @@ def test_xdtd_with_qg3():
"message": {"query_graph": {
"nodes": {
"disease": {
"ids": ["MONDO:0004975"]
"ids": ["MONDO:0017979"]
},
"chemical": {
"categories": ["biolink:ChemicalEntity"]
Expand All @@ -191,7 +191,7 @@ def test_xdtd_with_qg3():
}
},
"operations": {"actions": [
"infer(action=drug_treatment_graph_expansion,node_curie=MONDO:0004975,qedge_id=t_edge,n_drugs=10,n_paths=10)",
"infer(action=drug_treatment_graph_expansion,node_curie=MONDO:0017979,qedge_id=t_edge,n_drugs=10,n_paths=10)",
"return(message=true, store=true)"
]}
}
Expand All @@ -206,7 +206,7 @@ def test_xdtd_with_only_qg():
"message": {"query_graph": {
"nodes": {
"disease": {
"ids": ["MONDO:0004975"]
"ids": ["MONDO:0003912"]
},
"chemical": {
"categories": ["biolink:ChemicalEntity"]
Expand Down

0 comments on commit b89b569

Please sign in to comment.