-
Notifications
You must be signed in to change notification settings - Fork 17
/
trainer.py
91 lines (72 loc) · 3.36 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# -*- coding: utf-8 -*-
import torch
import numpy as np
from visualize import save_ratemaps
import os
class Trainer(object):
    """Optimization loop for a model trained on simulated trajectories.

    Handles checkpoint restore/save under ``options.save_dir/options.run_ID``
    and accumulates per-step loss / decoded-position-error histories in
    ``self.loss`` and ``self.err``.
    """

    def __init__(self, options, model, trajectory_generator, restore=True):
        """
        Args:
            options: Namespace with at least ``learning_rate``, ``save_dir``
                and ``run_ID`` attributes.
            model: Torch module exposing ``compute_loss(inputs, pc_outputs, pos)``
                returning ``(loss, err)`` tensors.
            trajectory_generator: Object whose ``get_generator()`` yields
                ``(inputs, pc_outputs, pos)`` batches.
            restore: If True and a saved checkpoint exists, resume from it.
        """
        self.options = options
        self.model = model
        self.trajectory_generator = trajectory_generator
        lr = self.options.learning_rate
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        # Running per-training-step histories (floats, one entry per batch).
        self.loss = []
        self.err = []

        # Set up checkpoints.
        self.ckpt_dir = os.path.join(options.save_dir, options.run_ID)
        ckpt_path = os.path.join(self.ckpt_dir, 'most_recent_model.pth')
        # If the checkpoint file exists its directory necessarily exists,
        # so a separate isdir() check is redundant.
        if restore and os.path.isfile(ckpt_path):
            # map_location='cpu' lets a GPU-saved checkpoint restore on a
            # CPU-only host; load_state_dict then copies each tensor onto
            # the model's own device.
            self.model.load_state_dict(
                torch.load(ckpt_path, map_location='cpu'))
            print("Restored trained model from {}".format(ckpt_path))
        else:
            # exist_ok=True makes this safe whether or not the dir exists.
            os.makedirs(self.ckpt_dir, exist_ok=True)
            print("Initializing new model from scratch.")
            print("Saving to: {}".format(self.ckpt_dir))

    def train_step(self, inputs, pc_outputs, pos):
        '''
        Train on one batch of trajectories.

        Args:
            inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
            pc_outputs: Ground truth place cell activations with shape
                [batch_size, sequence_length, Np].
            pos: Ground truth 2d position with shape [batch_size, sequence_length, 2].

        Returns:
            loss: Avg. loss for this training batch.
            err: Avg. decoded position error in cm.
        '''
        # zero_grad on the optimizer clears exactly the parameters being
        # optimized (identical effect to model.zero_grad here, but idiomatic).
        self.optimizer.zero_grad()
        loss, err = self.model.compute_loss(inputs, pc_outputs, pos)
        loss.backward()
        self.optimizer.step()

        return loss.item(), err.item()

    def train(self, n_epochs: int = 1000, n_steps=10, save=True):
        '''
        Train model on simulated trajectories.

        Args:
            n_epochs: Number of training epochs.
            n_steps: Number of training steps (batches) per epoch.
            save: If True, save a checkpoint and a rate-map figure after
                each epoch.
        '''
        # Construct generator once; it is advanced across epochs.
        gen = self.trajectory_generator.get_generator()

        for epoch_idx in range(n_epochs):
            for step_idx in range(n_steps):
                inputs, pc_outputs, pos = next(gen)
                loss, err = self.train_step(inputs, pc_outputs, pos)
                self.loss.append(loss)
                self.err.append(err)

                print('Epoch: {}/{}. Step {}/{}. Loss: {}. Err: {}cm'.format(
                    epoch_idx, n_epochs, step_idx, n_steps,
                    np.round(loss, 2), np.round(100 * err, 2)))

            if save:
                # Save both an epoch-tagged checkpoint and the rolling
                # 'most_recent_model.pth' that __init__ restores from.
                ckpt_path = os.path.join(
                    self.ckpt_dir, 'epoch_{}.pth'.format(epoch_idx))
                torch.save(self.model.state_dict(), ckpt_path)
                torch.save(self.model.state_dict(),
                           os.path.join(self.ckpt_dir,
                                        'most_recent_model.pth'))

                # Save a picture of rate maps.
                save_ratemaps(self.model, self.trajectory_generator,
                              self.options, step=epoch_idx)