This repository has been archived by the owner on Mar 19, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 118
/
main_predict.py
80 lines (59 loc) · 2.49 KB
/
main_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import json
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.externals import joblib
from modules import Encoder, Decoder
from utils import numpy_to_tvar
import utils
from custom_types import TrainData
from constants import device
def preprocess_data(dat, col_names, scale) -> TrainData:
    """Scale a raw frame and split it into feature and target column sets.

    Args:
        dat: Raw input frame. ``dat.columns`` gives the column order used to
            locate each name in ``col_names``.
        col_names: Names of the columns to treat as targets. Every other
            column of the transformed array becomes a feature.
        scale: Transformer whose ``transform`` is applied to ``dat`` before
            splitting (presumably a fitted scikit-learn scaler, as loaded from
            ``scaler.pkl`` by the script below -- confirm).

    Returns:
        ``TrainData(feats, targs)`` where ``feats`` holds the transformed
        non-target columns and ``targs`` the transformed target columns.

    Raises:
        ValueError: if a name in ``col_names`` is not among ``dat.columns``
            (raised by ``list.index``).
    """
    scaled = scale.transform(dat)
    ordered_names = list(dat.columns)
    # Flag the position of every target column; everything unflagged is a feature.
    is_target = np.zeros(scaled.shape[1], dtype=bool)
    for name in col_names:
        is_target[ordered_names.index(name)] = True
    return TrainData(scaled[:, ~is_target], scaled[:, is_target])
def predict(encoder, decoder, t_dat, batch_size: int, T: int) -> np.ndarray:
    """Run the encoder/decoder over every sliding window of ``t_dat``.

    Prediction row ``i`` is computed from the ``T - 1`` consecutive time
    steps ``i .. i + T - 2`` of both the features and the target history.

    Args:
        encoder: Callable taking the batched feature tensor and returning a
            pair whose second element is passed on to ``decoder``.
        decoder: Callable taking ``(input_encoded, y_history)`` and returning
            a tensor with one row per window in the batch.
        t_dat: Object exposing ``feats`` and ``targs`` arrays laid out as
            (time steps, columns), e.g. the result of ``preprocess_data``.
        batch_size: Number of windows evaluated per forward pass.
        T: Window parameter; each prediction uses ``T - 1`` past steps.

    Returns:
        Array of shape ``(n_steps - T + 1, n_target_columns)`` holding the
        decoder output for each window.

    Raises:
        ValueError: if ``batch_size`` is not positive, or ``t_dat`` has fewer
            than ``T - 1`` time steps.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    n_steps = t_dat.feats.shape[0]
    n_feats = t_dat.feats.shape[1]
    # targs is (time steps, target columns): the output width is the number of
    # target columns (axis 1). Using axis 0 would size y_pred and y_history by
    # the number of time steps instead.
    n_targs = t_dat.targs.shape[1]
    window = T - 1  # past steps fed to the model for each prediction
    n_windows = n_steps - window
    if n_windows < 0:
        raise ValueError(
            f"need at least T - 1 = {window} time steps, got {n_steps}"
        )
    y_pred = np.zeros((n_windows, n_targs))

    # Inference only: avoid building an autograd graph for every batch.
    with torch.no_grad():
        for start in range(0, n_windows, batch_size):
            stop = min(start + batch_size, n_windows)
            b_len = stop - start
            X = np.zeros((b_len, window, n_feats))
            y_history = np.zeros((b_len, window, n_targs))
            for b_i, b_idx in enumerate(range(start, stop)):
                X[b_i] = t_dat.feats[b_idx:b_idx + window]
                y_history[b_i] = t_dat.targs[b_idx:b_idx + window]
            _, input_encoded = encoder(numpy_to_tvar(X))
            out = decoder(input_encoded, numpy_to_tvar(y_history))
            y_pred[start:stop] = out.detach().cpu().numpy()
    return y_pred
# Script section: reload the trained encoder/decoder and scaler from ``data/``,
# run them over the NASDAQ-100 CSV, and plot predictions against the true
# target series.
debug = False  # when True, only the first 100 CSV rows are read
save_plots = False  # forwarded to utils.save_or_show_plot

# NOTE(review): ``sklearn.externals.joblib`` (imported at the top of the file)
# was removed in scikit-learn 0.23; on newer versions that import may need to
# become ``import joblib``.

with open(os.path.join("data", "enc_kwargs.json"), "r", encoding="utf-8") as fi:
    enc_kwargs = json.load(fi)
# map_location only controls where the checkpoint tensors land; load_state_dict
# copies them into the module's existing parameters, so the module itself must
# be moved to ``device`` (presumably where numpy_to_tvar places inputs -- confirm).
enc = Encoder(**enc_kwargs).to(device)
enc.load_state_dict(torch.load(os.path.join("data", "encoder.torch"), map_location=device))
enc.eval()  # inference mode

with open(os.path.join("data", "dec_kwargs.json"), "r", encoding="utf-8") as fi:
    dec_kwargs = json.load(fi)
dec = Decoder(**dec_kwargs).to(device)
dec.load_state_dict(torch.load(os.path.join("data", "decoder.torch"), map_location=device))
dec.eval()  # inference mode

scaler = joblib.load(os.path.join("data", "scaler.pkl"))
raw_data = pd.read_csv(os.path.join("data", "nasdaq100_padding.csv"), nrows=100 if debug else None)
targ_cols = ("NDX",)
data = preprocess_data(raw_data, targ_cols, scaler)

with open(os.path.join("data", "da_rnn_kwargs.json"), "r", encoding="utf-8") as fi:
    da_rnn_kwargs = json.load(fi)
final_y_pred = predict(enc, dec, data, **da_rnn_kwargs)

plt.figure()
plt.plot(final_y_pred, label='Predicted')
# Prediction row i uses steps i .. i+T-2, so it is paired with target step
# i+T-1; slicing from T-1 aligns the two series.
plt.plot(data.targs[(da_rnn_kwargs["T"] - 1):], label="True")
plt.legend(loc='upper left')
utils.save_or_show_plot("final_predicted_reloaded.png", save_plots)