# train.py
import torch
import torch.utils
import torch.utils.data
import torch.nn.functional as F
def batch2RNNinput(x_batch, return_order=False):
    """Turn a padded 2-D batch into time-major RNN input sorted by length.

    The last column of ``x_batch`` is used as the per-row sequence length and
    the remaining columns as the sequence values.  Rows are reordered *as a
    unit* by decreasing length, so every sequence stays paired with its own
    length (sorting each column independently would mix values of different
    samples).

    Args:
        x_batch: 2-D tensor of shape ``(batch, steps + 1)``.
        return_order: when true, also return the row permutation that was
            applied, so callers can reorder per-sample targets to match.

    Returns:
        ``(x_data, x_lengths)`` or ``(x_data, x_lengths, order)``:
        ``x_data`` has shape ``(steps, batch)`` with the value columns in
        reversed order, ``x_lengths`` has shape ``(batch,)`` in decreasing
        order, and ``order`` holds the original row index of each output
        column.
    """
    # NOTE(review): descending lengths are presumably needed because the model
    # packs the batch with the default ``enforce_sorted=True`` -- confirm.
    order = torch.argsort(x_batch[:, -1], descending=True)
    rows = x_batch[order]
    x_lengths = rows[:, -1]
    # The value columns are reversed, exactly as the original flip over dim 1
    # did.  NOTE(review): presumably the data is left-padded, so the reversal
    # moves the padding to the end -- confirm against the dataset code.
    x_data = torch.flip(rows[:, :-1], dims=[1]).T
    if return_order:
        return x_data, x_lengths, order
    return x_data, x_lengths


def _forward_batch(model, x, y, device, mode):
    """Run one batch through ``model``; return ``(out, y)`` or ``None`` if skipped."""
    x = x.long().to(device)
    y = y.long().to(device)
    if mode != 'rnn':
        return model(x), y
    x_data, x_len, order = batch2RNNinput(x, return_order=True)
    # A batch containing a zero-length sequence is skipped as a whole.
    if (x_len == 0).any():
        return None
    # The batch was reordered by length, so the targets must follow.
    return model(x_data, x_len), y[order]


def _mean_metrics(tot_loss, tot_correct, n_seen):
    """Average the summed loss / correct count over the samples actually used."""
    if n_seen == 0:
        # Nothing was processed: the averages are undefined.
        return float('nan'), float('nan')
    return tot_loss / n_seen, tot_correct / n_seen


def evaluate(model, eval_set, device, mode):
    """Compute mean cross-entropy loss and accuracy of ``model`` on ``eval_set``.

    Args:
        model: module called as ``model(x)``, or as ``model(x_data, x_len)``
            when ``mode == 'rnn'``.
        eval_set: iterable yielding ``(x, y)`` batches.
        device: device the batches are moved to.
        mode: ``'rnn'`` selects the length-sorted RNN input path.

    Returns:
        ``(loss, accuracy)`` averaged over the samples that were actually
        evaluated; skipped batches do not count.  Both are ``nan`` when no
        sample was evaluated.
    """
    model.eval()
    tot_loss = 0.0
    tot_correct = 0
    n_seen = 0
    with torch.no_grad():
        for x, y in eval_set:
            result = _forward_batch(model, x, y, device, mode)
            if result is None:
                continue
            out, y = result
            tot_loss += F.cross_entropy(out, y, reduction='sum').item()
            tot_correct += (out.argmax(dim=1) == y).sum().item()
            n_seen += y.size(0)
    return _mean_metrics(tot_loss, tot_correct, n_seen)


def train(model, train_set, optimizer, device, mode):
    """Train ``model`` for one pass over ``train_set``.

    Args:
        model: module called as ``model(x)``, or as ``model(x_data, x_len)``
            when ``mode == 'rnn'``.
        train_set: iterable yielding ``(x, y)`` batches.
        optimizer: optimizer stepped once per processed batch.
        device: device the batches are moved to.
        mode: ``'rnn'`` selects the length-sorted RNN input path.

    Returns:
        ``(loss, accuracy)`` averaged over the samples that were actually
        trained on; skipped batches do not count.  Both are ``nan`` when no
        sample was processed.
    """
    model.train()
    tot_loss = 0.0
    tot_correct = 0
    n_seen = 0
    for x, y in train_set:
        model.zero_grad()
        result = _forward_batch(model, x, y, device, mode)
        if result is None:
            continue
        out, y = result
        # The summed (not averaged) loss is back-propagated, as before.
        loss = F.cross_entropy(out, y, reduction='sum')
        tot_loss += loss.item()
        tot_correct += (out.argmax(dim=1) == y).sum().item()
        n_seen += y.size(0)
        loss.backward()
        optimizer.step()
    return _mean_metrics(tot_loss, tot_correct, n_seen)