"""
Main script for performing model fitting on synthetic data
Copyright 2023, Jiahe Lin, Huitian Lei and George Michailidis
All Rights Reserved
Lin, Lei and Michailidis assert copyright ownership of this code base and its derivative
works. This copyright statement should not be removed or edited.
-----do not edit anything above this line---
"""
import os
import sys

print(f'python version={".".join(map(str, sys.version_info[:3]))}')
print(f'current working dir={os.getcwd()}')

import yaml
import importlib
import argparse
import pickle
import json

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error

from utils import Evaluator
################################################################################
## modify the list below to select which prior-percentage settings to run;
## when it is left empty, only the no-prior case is run
_PRIORKEYs = [] # e.g., [0.10, 0.20, 0.50]
################################################################################
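
## command-line arguments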
parser = argparse.ArgumentParser(description='')
parser.add_argument('--ds_str', type=str, help='dataset to run', default='ds1')
parser.add_argument('--replica_id', type=int, help='replica id', default=0)
parser.add_argument('--train_size', type=int, help='sample size used for model training', default=200)
parser.add_argument('--standardize', help='whether to standardize the data', action='store_true')
parser.add_argument('--report', help='whether to report metrics', action='store_true')

def main():
    global args
    args = parser.parse_args()
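
    ## locate the config file and build the lookup key for this dataset / train-size / standardization combination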
    _CONFIG = os.path.join('configs', f'{args.ds_str}.yaml')
    config_key = f'{args.ds_str}-{args.train_size}' + ('' if not args.standardize else '-standardize')

    print('===========================')
    print(f'* ds_str={args.ds_str}; train_size={args.train_size}; standardize={args.standardize}; config_file={_CONFIG}, config_key={config_key}')
    print('===========================')
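
    ## load the meta config; the dataset/size-specific section overrides the defaults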
    with open(_CONFIG) as f:
        meta_config = yaml.safe_load(f)
    assert config_key in meta_config, f'{config_key} missing from config file {_CONFIG}'
    args.configs = meta_config['default'].copy()
    args.configs.update(meta_config[config_key])
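
    ## dynamically import the model class named in the config from the src package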
    models = importlib.import_module('src')
    svarClass = getattr(models, args.configs['model_class'])
    svar = svarClass(tau=args.configs['tau'],
                     rho=args.configs['rho'],
                     max_admm_iter=args.configs['max_admm_iter'],
                     admm_tol=args.configs['admm_tol'],
                     verbose=0,
                     tol=args.configs['tol'],
                     max_epoch=args.configs['max_epoch'],
                     threshold_A=args.configs['threshold_A'],
                     threshold_B=args.configs['threshold_B'],
                     SILENCE=False)
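
    ## load the ground-truth graph information and the simulated data replicas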
    with open(f'data/sim/{args.ds_str}/graph_info.pickle', 'rb') as handle:
        graph_info = pickle.load(handle)
    with open(f'data/sim/{args.ds_str}/data.pickle', 'rb') as handle:
        data = pickle.load(handle)
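
    ## train on the first train_size observations of the selected replica, standardizing if requested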
    xdata = data[args.replica_id][:args.train_size]
    if args.standardize:
        scaler = StandardScaler()
        xdata = scaler.fit_transform(xdata)
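
    ## fit once without priors (prior_key == 0.0), then once per requested prior percentage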
    for prior_key in [0.0] + _PRIORKEYs:
        print('################')
        if prior_key == 0.0:
            print('## no priors')
            A_NZ = None
        else:
            print(f'## {prior_key*100:.0f}% priors')
            A_NZ = graph_info['prior_clean'][prior_key]
        print('################')
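
        ## fit the SVAR with penalty parameters mu_A / mu_B; A_NZ, when given, encodes prior
        ## support information for the entries of A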
        out = svar.fitSVAR(xdata,
                           q=args.configs['nlags'],
                           mu_A=args.configs['mu_A'],
                           mu_B=args.configs['mu_B'],
                           mu_B_refit=args.configs.get('mu_B_refit', None),
                           A_NZ=A_NZ)
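
        ## one-step-ahead forecast; map back to the original scale if the data were standardized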
        x_forecast = svar.forecast(xdata, out['A'], out['B'], horizon=1)
        if args.standardize:
            x_forecast = scaler.inverse_transform(x_forecast)
        x_forecast_actual = np.expand_dims(data[args.replica_id][args.train_size], axis=0)
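
        ## report skeleton-recovery metrics for A and B and the forecast error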
        if args.report:
            reports_skeleton = get_skeleton_report(graph_info, out)
            print(reports_skeleton)
            reports_x_forecast = get_x_report_forecast(x_forecast_actual, x_forecast)
            print(reports_x_forecast)

    return 0

def get_skeleton_report(graph_info, out):
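    """Evaluate the estimated skeletons of A and each lag slice of B against the ground truth."""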
    evaluator = Evaluator()
    reports_skeleton = []
    ## A
    report = evaluator.report(graph_info['A'], out['A'])
    report['key'] = 'A'
    reports_skeleton.append(report)
    ## B
    for lag_id in range(graph_info['B'].shape[-1]):
        report = evaluator.report(graph_info['B'][:, :, lag_id], out['B'][:, :, lag_id])
        report['key'] = f'B_{lag_id+1}'
        reports_skeleton.append(report)
    return reports_skeleton

def get_x_report_forecast(x_forecast_actual, x_forecast):
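    ## relative forecast error: RMSE of the one-step forecast, scaled by the l2 norm of the
    ## actual observation; note that `squared=False` was removed from mean_squared_error in
    ## scikit-learn 1.6, where sklearn.metrics.root_mean_squared_error should be used instead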
    assert x_forecast_actual.shape[0] == 1
    rmse = mean_squared_error(x_forecast_actual.reshape(-1, 1),
                              x_forecast.reshape(-1, 1),
                              squared=False) / np.linalg.norm(x_forecast_actual.reshape(-1, 1))
    return {'forecast_l2': round(rmse, 3)}

if __name__ == "__main__":
main()