-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
61 lines (48 loc) · 1.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import numpy as np
import hydra
from omegaconf import DictConfig, OmegaConf
from elastic_warping_vis.utils import create_directory, load_data
from elastic_warping_vis.draw_functions import draw_elastic, draw_elastic_gif
def _members_of(X, y, label):
    """Return the rows of ``X`` whose entry in ``y`` equals ``label``.

    Raises:
        ValueError: if no row carries ``label``, so the failure names the
            missing label instead of surfacing later as an opaque error from
            the random draw.
    """
    members = X[y == label]
    if len(members) == 0:
        raise ValueError(f"No samples with label {label!r} found in the loaded split.")
    return members


def _draw_pair(pool_x, pool_y, distinct):
    """Draw one random series from ``pool_x`` and one from ``pool_y``.

    When ``distinct`` is true the two pools hold the same series, so two
    different indices are drawn (if more than one series exists) to avoid
    comparing a series with itself.
    """
    if distinct and len(pool_x) > 1:
        i, j = np.random.choice(len(pool_x), size=2, replace=False)
    else:
        i = np.random.randint(len(pool_x))
        j = np.random.randint(len(pool_y))
    return pool_x[i], pool_y[j]


@hydra.main(config_name="config_hydra.yaml", config_path="config")
def main(args: DictConfig):
    """Draw elastic-alignment visualisations for two randomly chosen series.

    Steps:
    1. Save the resolved config.
    2. Create ``<output_dir>/<dataset>``.
    3. Load the requested split.
    4. Pick one series per configured class, or two series from the whole
       split when the data is not a classification task.
    5. Pass the pair to ``draw_elastic`` and ``draw_elastic_gif``.

    Args:
        args: Hydra config. The keys read are output_dir, dataset, split,
            znormalize, class_x, class_y, figsize, metric, metric_params and
            show_warping.

    Raises:
        ValueError: if a requested class, or the whole split, has no samples.
    """
    # Keep a copy of the resolved config. The path is relative, so it lands in
    # the current working directory (presumably Hydra's run directory when
    # Hydra changes directory; confirm for the Hydra version in use).
    with open("config_hydra.yaml", "w", encoding="utf-8") as f:
        OmegaConf.save(args, f)

    output_dir = args.output_dir
    create_directory(output_dir)

    dataset = args.dataset
    output_dir_dataset = os.path.join(output_dir, dataset)
    create_directory(output_dir_dataset)

    X, y, is_classif = load_data(
        dataset_name=dataset, split=args.split, znormalize=args.znormalize
    )

    if is_classif:
        # Compute each class mask once, and fail loudly on empty classes.
        pool_x = _members_of(X, y, args.class_x)
        pool_y = _members_of(X, y, args.class_y)
        same_pool = args.class_x == args.class_y
    else:
        if len(X) == 0:
            raise ValueError(
                f"Dataset {dataset!r} has no samples in split {args.split!r}."
            )
        pool_x = pool_y = X
        same_pool = True

    ts1, ts2 = _draw_pair(pool_x, pool_y, distinct=same_pool)

    draw_elastic(
        x=ts1,
        y=ts2,
        output_dir=output_dir_dataset,
        figsize=args.figsize,
        metric=args.metric,
        metric_params=args.metric_params,
        show_warping_connections=args.show_warping,
    )
    draw_elastic_gif(
        output_dir=output_dir_dataset,
        x=ts1,
        y=ts2,
        figsize=args.figsize,
        fontsize=10,
        metric_params=args.metric_params,
        metric=args.metric,
    )


if __name__ == "__main__":
    # Hydra's decorator builds `args` from the config file and the command line.
    main()