-
Notifications
You must be signed in to change notification settings - Fork 1
/
visualize.py
45 lines (36 loc) · 1.27 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from argparse import ArgumentParser
from pathlib import Path

import torch
from omegaconf import OmegaConf

from fraud_detection import GAT, EllipticDataset, Trainer
def parse_args(argv=None):
    """Parse command-line options for the prediction-visualization script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config`` (str), ``step`` (int) and
        ``weights_file`` (str or None).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--config",
        # Deliberately not ``required``: argparse never applies the default of
        # a required option, which previously made this default dead code.
        default="configs/elliptic_gat.yaml",
        help="Path to training config",
    )
    parser.add_argument(
        "--step",
        required=True,
        # NOTE(review): assumes Trainer.visualize expects an integer time step
        # (without ``type`` argparse passes a str) — confirm against Trainer.
        type=int,
        help="The timestamp step to visualize predictions",
    )
    parser.add_argument(
        "--weights_file",
        default=None,
        help="Path to PyTorch weights file. Grabs from config if not provided.",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Load a trained GAT and save a plot of its predictions for one time step.

    The figure is written to ``visualizations/<config.name>/<step>.png``.

    Args:
        argv: Argument list forwarded to :func:`parse_args`; ``None`` means
            ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    time_step = args.step

    config = OmegaConf.load(args.config)
    dataset = EllipticDataset(config.dataset)
    # The model's input width is taken from the dataset's node-feature count.
    config.model.input_dim = dataset.pyg_dataset().num_node_features
    model = GAT(config.model)

    weights_file = args.weights_file
    if weights_file is None:
        weights_file = f"weights/{config.name}.pt"
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # machines; the model is moved to the configured device just below.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    model.load_state_dict(torch.load(weights_file, map_location="cpu"))

    trainer = Trainer(config)
    trainer.model = model.double().to(config.train.device)

    # Ensure the output directory exists before the figure is saved
    # (harmless if Trainer.visualize also creates it).
    save_path = Path("visualizations") / str(config.name) / f"{time_step}.png"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    trainer.visualize(dataset, time_step=time_step, save_to=str(save_path))


if __name__ == "__main__":
    main()