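"""Evaluate a trained query-embedding model stored in mlflow.

Loads the model from the given run, computes ranking metrics (hits@N,
mean rank, MRR and their grouped variants) on one or more qa test files,
optionally filtering out known true answers, and logs the results back
to the run as a JSON artifact.
"""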
import argparse
import ast
import json
import os

from torch_geometric.seed import seed_everything
from mlflow import set_tracking_uri, log_dict, start_run
from mlflow.pytorch import load_model

from metrics import *  # provides Filter and the ranking-metric functions used below
from graph import qa_dataset
from config import DEVICE, URI

# NOTE: the precomputed filter path currently assumes a single dataset directory.

# Parser for all command-line arguments.
parser = argparse.ArgumentParser(description="Testing query embeddings...")
# Required arguments.
parser.add_argument("run_id",
                    type=str, help="Run id of the model to evaluate")
parser.add_argument("test_dict",
                    type=str, help="Artifact file name for the test results")
# Optional arguments.
parser.add_argument("--N",
                    default=10,
                    type=int, help="N of hits@N. Only used when a hits@ metric is computed")
parser.add_argument("--tests",
                    default="['./datasets/FB15k_237/qa/test_qa.txt']",
                    type=str, help="Python list literal of test-data paths")
parser.add_argument("--train_data",
                    default=None,
                    type=str, help="Path to train data (required when --filtering is set)")
parser.add_argument("--val_data",
                    default=None,
                    type=str, help="Path to val data (required when --filtering is set)")
parser.add_argument("--filtering",
                    action=argparse.BooleanOptionalAction,
                    default=False,
                    help="Filter out known true answers that would artificially lower the scores")
parser.add_argument("--big",
                    default=1e6,
                    type=float, help="Mask value used to filter out golden triples")
parser.add_argument("--batch_size",
                    default=128,
                    type=int, help="Test batch size")
parser.add_argument("--seed",
                    default=42,
                    type=int, help="Seed for randomness")
parser.add_argument("--filter_path",
                    action=argparse.BooleanOptionalAction,
                    default=True,
                    help="Load a precomputed, pickled filter from the dataset directory instead of building one")
parser.add_argument("--all_tests",
                    action=argparse.BooleanOptionalAction,
                    help="Also calculate the mean_rank and mrr metrics")
parser.add_argument("--all_scaled",
                    action=argparse.BooleanOptionalAction,
                    help="Calculate only the grouped metrics (mrrGrouped and grouped hits@N)")
# Parse all arguments.
args = parser.parse_args()
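
# Example invocation (run id and qa paths are illustrative, not taken from the repo):
#   python test.py 0123456789abcdef results.json \
#       --tests "['./datasets/FB15k_237/qa/test_qa.txt']" \
#       --filtering \
#       --train_data ./datasets/FB15k_237/qa/train_qa.txt \
#       --val_data ./datasets/FB15k_237/qa/val_qa.txt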

SEED = args.seed
MODEL_URI = "runs:/" + args.run_id + "/model"
filtering = args.filtering
test_data = ast.literal_eval(args.tests)  # list of test-file paths
N = args.N
batch_size = args.batch_size

# Seed every source of randomness for reproducibility.
seed_everything(SEED)
set_tracking_uri(URI)  # point mlflow at the tracking server

# Load the trained model from the mlflow run and move it to the target device.
model = load_model(MODEL_URI)
model.to(DEVICE)
if filtering:
    if not (args.train_data and args.val_data):
        parser.error("--train_data and --val_data are required when --filtering is set")
    # Directory where the qa files and dataset metadata are stored.
    id_dir = os.path.dirname(args.train_data)
    with open(os.path.join(id_dir, "info.json"), "r") as file:
        info = json.load(file)
    num_entities = info["num_entities"]
    filter_path = os.path.join(id_dir, "filter.pkl")
else:
    answer_filter = None
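
# `Filter` (imported from metrics) is assumed to mask known true answers with the
# `big` value so they cannot outrank the queried answer; `load_path` restores a
# pickled, precomputed filter, and `change_test` re-targets it at a new test file.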
logs = {}
for i, test_file in enumerate(test_data):
    print(f"Loading {test_file} ...")
    test = qa_dataset(test_file)  # load the test split
    if filtering:
        if i == 0:
            if args.filter_path:
                # Reuse the precomputed, pickled filter.
                print("Loading filter...")
                answer_filter = Filter(None, None, test_file, num_entities, big=args.big, load_path=filter_path)
                print("Filter loaded!")
            else:
                # Build the filter from the train/val/test triples.
                print("Creating filter...")
                answer_filter = Filter(args.train_data, args.val_data, test_file, num_entities, big=args.big)
                print("Filter created successfully!")
        else:
            # Point the existing filter at the new test file.
            print("Updating filter...")
            answer_filter.change_test(test_file)
            print("Done!")
    logs[test_file] = {}
    if args.all_tests:
        result1 = mean_rank(test, model, batch_size=batch_size, filter=answer_filter, device=DEVICE)
        logs[test_file]["mean_rank"] = {
            "result": result1,
            "N": None,
        }
        result6 = mean_rank_grouped(test, model, batch_size=batch_size, filter=answer_filter, device=DEVICE)
        logs[test_file]["mean_rank_grouped"] = {
            "result": result6,
            "N": None,
        }
    if not args.all_scaled:
        result2 = hits_at_N(test, model, N=N, batch_size=batch_size, filter=answer_filter, device=DEVICE) * 100
        logs[test_file]["hits@"] = {
            "result": result2,
            "N": N,
        }
    result3 = hits_at_N_Grouped(test, model, N=N, batch_size=batch_size, filter=answer_filter, device=DEVICE) * 100
    logs[test_file]["hitsGrouped@"] = {
        "result": result3,
        "N": N,
    }
    if args.all_tests or args.all_scaled:
        if not args.all_scaled:
            result4 = mean_reciprocal_rank(test, model, batch_size=batch_size, filter=answer_filter, device=DEVICE) * 100
            logs[test_file]["mrr"] = {
                "result": result4,
                "N": None,
            }
        result5 = mrr_Grouped(test, model, batch_size=batch_size, filter=answer_filter, device=DEVICE) * 100
        logs[test_file]["mrrGrouped"] = {
            "result": result5,
            "N": None,
        }
    print(f"Finished {test_file}")
logs["utils"] = {}
logs["utils"]["filtering"] = filtering
logs["utils"]["train_data"] = args.train_data if (filtering) else None
logs["utils"]["val_data"] = args.val_data if (filtering) else None
with start_run(run_id=args.run_id):
log_dict(logs, artifact_file="tests/"+args.test_dict)
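
# Reading the results back later (a sketch, assuming a recent mlflow that
# provides mlflow.artifacts.load_dict):
#
#   from mlflow.artifacts import load_dict
#   results = load_dict(f"runs:/{run_id}/tests/{test_dict}")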