-
Notifications
You must be signed in to change notification settings - Fork 166
/
learning_curve.py
73 lines (56 loc) · 1.98 KB
/
learning_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import seaborn as sns
from tst import Transformer
from tst.loss import OZELoss
from src.dataset import OzeDataset
from src.utils import visual_sample, compute_loss
from src.utils import compute_loss, fit, Logger, kfold, leargnin_curve
# Search parameters
PARTS = 8  # Number of parts passed to leargnin_curve as n_part (one training run per part)
VALIDATION_SPLIT = 0.3  # Fraction forwarded to leargnin_curve as validation_split
# Training parameters
DATASET_PATH = 'datasets/dataset_random.npz'  # File loaded by OzeDataset
BATCH_SIZE = 8
NUM_WORKERS = 4  # Data-loading worker processes, forwarded to leargnin_curve
LR = 2e-4  # Adam learning rate
EPOCHS = 5  # Training epochs per part, passed to fit()
# Model parameters (all forwarded to tst.Transformer)
d_model = 32  # Latent dimension
q = 8  # Query size
v = 8  # Value size
h = 2  # Number of attention heads
N = 2  # Number of encoder and decoder layers to stack
attention_size = 24  # Attention window size
dropout = 0.2  # Dropout rate
pe = None  # Positional encoding; NOTE(review): None presumably disables it — confirm in tst.Transformer
chunk_mode = None  # Attention chunking mode; NOTE(review): None presumably means no chunking — confirm in tst.Transformer
# Input/output feature counts are hard-coded here, not read from the dataset:
# they must be kept in sync with the arrays stored in DATASET_PATH by hand.
d_input = 38  # Number of input features per time step (TODO: verify against dataset)
d_output = 8  # Number of predicted output features (TODO: verify against dataset)
# Config
def main():
    """Run the learning-curve experiment and log one entry per data split.

    For every ``(train, validation)`` loader pair yielded by
    ``leargnin_curve``, a freshly initialised Transformer and Adam optimizer
    are built. The pair is trained with ``fit`` for ``EPOCHS`` epochs, and the
    value ``fit`` returns is logged to ``learningcurve_log.csv`` through
    ``Logger``.
    """
    sns.set()  # apply seaborn's default plotting theme
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device {device}")

    # Load dataset
    oze_dataset = OzeDataset(DATASET_PATH)

    # The loss function and logger are shared by every split. The network and
    # optimizer are rebuilt per split (inside the loop below).
    loss_function = OZELoss(alpha=0.3)
    logger = Logger('learningcurve_log.csv')

    # NOTE(review): presumably yields PARTS (train, val) DataLoader pairs over
    # increasingly large subsets of the dataset — confirm in src.utils.
    splits = leargnin_curve(oze_dataset, n_part=PARTS, validation_split=VALIDATION_SPLIT,
                            batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    # The progress-bar total assumes fit() advances pbar once per epoch and that
    # leargnin_curve yields exactly PARTS splits — TODO confirm.
    with tqdm(total=PARTS * EPOCHS) as pbar:
        for dataloader_train, dataloader_val in splits:
            # Fresh Transformer + Adam optimizer per split, so every point of
            # the curve is trained from scratch with the OZE loss above.
            net = Transformer(d_input, d_model, d_output, q, v, h, N, attention_size=attention_size,
                              dropout=dropout, chunk_mode=chunk_mode, pe=pe).to(device)
            optimizer = optim.Adam(net.parameters(), lr=LR)

            # Fit model
            loss = fit(net, optimizer, loss_function, dataloader_train,
                       dataloader_val, epochs=EPOCHS, pbar=pbar, device=device)

            # Log
            logger.log(loss=loss)


# Guard the entry point: with NUM_WORKERS > 0, data-loading worker processes
# started with the 'spawn' method (default on Windows and macOS) re-import this
# module. Without the guard, each worker would re-run the whole experiment.
if __name__ == "__main__":
    main()