import os
import sys
import time

import tqdm

sys.path.append('..')  # add the parent directory of this file to the module search path
# Plain (non-relative) imports so the file can be run directly as a script.
from preprocess.preprocessing import load_pkl_data, SelfStandardScaler
from utils.cfg import get_default_args
from utils.utils import *
from training_funcs import *


def main(args):
    # Load data
    model_dir, ret_path = load_path(args)
    data_dict, data_num = load_pkl_data(args.data_path)
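    # Each entry below is a per-timestamp list: index i holds one gauge-network
    # graph snapshot (features, labels, train/test split, adjacency matrices).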
    feature_list = data_dict["features"]
    label_list = data_dict["labels"]
    train_idx_list = data_dict["train_idx"]
    test_idx_list = data_dict["test_idx"]
    adj_mat1_list = data_dict["adj_mat1"]
    adj_mat2_list = data_dict["adj_mat2"]
    # for generating test results
    gauge_list = data_dict["gauges"]
    timestamp_list = data_dict["timestamps"]
    if args.partial:
        idx_s = args.start_idx
        idx_e = min(args.end_idx, data_num)
    else:
        idx_s = 0
        idx_e = data_num
    i_iter = tqdm.tqdm(range(idx_s, idx_e),
                       desc="Processing: ",
                       total=idx_e - idx_s,
                       bar_format="{l_bar}{r_bar}")
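    # test_preds_list[0]: round-1 predictions (before residual correction);
    # test_preds_list[1]: final predictions (after residual correction).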
    test_labels_list, test_preds_list = [], [[], []]
    for i in i_iter:
        features, labels, idx_train, idx_test = feature_list[i], label_list[i], train_idx_list[i], test_idx_list[i]
        adj_mat1, adj_mat2 = adj_mat1_list[i], adj_mat2_list[i]
        timestamp = timestamp_list[i][0]
        rain_scaler = SelfStandardScaler(mean=labels[idx_train].mean(),
                                         std=labels[idx_train].std())  # use stats of training nodes only
        features = features[:, 0]  # keep only the rain values
        nom_features = rain_scaler.transform(features)  # standardize features
        # nom_features[idx_test] = 0  # no longer needed: test-node features are not reset to 0
        nom_labels = rain_scaler.transform(labels)  # standardize labels
        # Round 1: reload the pre-trained models directly and predict the rain field.
        _model_dir = args.reload_path + "/model"
        train_mse, med_rain_field, preds = run_one_graph(args, timestamp, adj_mat1, adj_mat2, nom_features, nom_labels,
                                                         idx_train, idx_test, _model_dir, round_num=None, reload=True)
        preds = rain_scaler.inverse_transform(preds)
        test_preds_list[0].append(preds[idx_test])  # predictions before correction
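        # Round 2 (residual correction): train a second model on the round-1
        # errors, which are observable only at training nodes.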
        error_arr = labels - preds
        error_labels = error_arr.copy()
        error_arr[idx_test] = 0  # errors are only known at training nodes
        err_scaler = SelfStandardScaler(mean=error_labels[idx_train].mean(),
                                        std=error_labels[idx_train].std())  # use stats of training nodes only
        nom_error_arr = err_scaler.transform(error_arr)  # standardized error field as input features
        # nom_error_arr[idx_test] = 0  # no longer needed: test entries are not reset to 0
        nom_error_labels = err_scaler.transform(error_labels)  # standardized error labels
        train_mse_e, med_rain_field_e, preds_e = run_one_graph(args, timestamp, adj_mat1, adj_mat2, nom_error_arr,
                                                               nom_error_labels, idx_train, idx_test, model_dir,
                                                               round_num=2)
        preds_e = err_scaler.inverse_transform(preds_e)
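        # Final prediction = round-1 prediction + predicted residual correction.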
        preds = preds + preds_e
        test_preds = preds[idx_test]
        test_preds_list[1].append(test_preds)
        test_labels_list.append(labels[idx_test])
    test_gauge_list = gauge_list[idx_s:idx_e]
    test_timestamp_list = timestamp_list[idx_s:idx_e]
    save_csv_results(ret_path, test_timestamp_list, test_gauge_list, test_labels_list, test_preds_list, multi_preds=True)


if __name__ == '__main__':
    parser = get_default_args()
    args = parser.parse_args()
    args.out_dir = "./output/GSI-RC-G"
    if args.dataset.lower() == "hk":
        args.paras_num = 1
        prefix = "HK_data"
        args.adj_type = "idw_power2_50th"
        args.data_dir = f"{args.data_dir}/HK_123_Data/pkl_data"
        args.reload_path = f"./output/GSI/HK_data/{args.adj_type}"
    elif args.dataset.lower() == "bw":
        args.paras_num = 2
        prefix = "BW_data"
        args.adj_type = "idw_power2_75th"
        args.data_dir = f"{args.data_dir}/BW_132_Data/pkl_data"
        args.reload_path = f"./output/GSI/BW_data/{args.adj_type}"
    else:
        raise NotImplementedError
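    # Tuned, dataset-specific hyper-parameters (presumably the output of a
    # hyper-parameter search), selected via args.paras_num; kept verbatim.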
    if args.paras_num == 1:
        # hyper-parameters for the HK dataset
        args.lr = 0.01242280373341682
        args.weight_decay = 3.0189717208257073e-06
        args.dropout = 0.3871241027778284
        args.hidden = 8
        args.nb_heads = 16
    elif args.paras_num == 2:
        # hyper-parameters for the BW dataset
        args.lr = 0.0030759392298867283
        args.weight_decay = 4.540839696209309e-05
        args.dropout = 0.3514742622380771
        args.hidden = 4
        args.nb_heads = 4
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    init_seeds(args)
    args.data_path = "{}/{}_{}.pkl".format(args.data_dir, prefix, args.adj_type)
    args.out_dir = f"{args.out_dir}/{prefix}/" + args.reload_path.split("/")[-1]
    os.makedirs(args.out_dir, exist_ok=True)
    save_args(args.__dict__, args.out_dir)

    start_time = time.time()
    main(args)
    run_time = round((time.time() - start_time) / 3600, 2)  # wall-clock time in hours
    save_running_time(args.out_dir, run_time)