classifier_train.py
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from data import Hand
from utils import Trainer
from network import resnet34, resnet101
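# Hand (the dataset), Trainer, and the resnet constructors come from this
# repository's local modules (data.py, utils.py, network.py).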
# Hyper-params
data_root = './data/'
model_path = './models/'
batch_size = 60  # batch size per GPU when running in GPU mode; use 120 for resnet34
num_workers = 2
init_lr = 0.01
lr_decay = 0.8
momentum = 0.9
weight_decay = 0.0
nesterov = True
# Set Training parameters
params = Trainer.TrainParams()
params.max_epoch = 1000
params.criterion = nn.CrossEntropyLoss()
params.gpus = [0]  # set params.gpus = [] to run in CPU mode
params.save_dir = model_path
params.ckpt = None
params.save_freq_epoch = 100
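# ckpt = None trains from scratch; save_freq_epoch presumably tells Trainer how
# often (in epochs) to write checkpoints into save_dir.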
# load data
print("Loading dataset...")
train_data = Hand(data_root, train=True)
val_data = Hand(data_root, train=False)
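# In multi-GPU mode, scale the total batch size by the number of GPUs so each
# GPU still processes `batch_size` samples; CPU mode keeps the value unchanged.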
batch_size = batch_size if len(params.gpus) == 0 else batch_size*len(params.gpus)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
print('train dataset len: {}'.format(len(train_dataloader.dataset)))
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
print('val dataset len: {}'.format(len(val_dataloader.dataset)))
# models
# model = resnet34(pretrained=False, modelpath=model_path, num_classes=1000)  # batch_size=120, single-GPU memory < 7000 MB
# model.fc = nn.Linear(512, 6)
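# resnet101 is built from Bottleneck blocks (expansion 4), so its final layer
# takes 512*4 = 2048 input features; replace it to output the 6 dataset classes.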
model = resnet101(pretrained=False, modelpath=model_path, num_classes=1000)  # batch_size=60, single-GPU memory > 9000 MB
model.fc = nn.Linear(512*4, 6)
# optimizer
trainable_vars = [param for param in model.parameters() if param.requires_grad]
print("Training with sgd")
params.optimizer = torch.optim.SGD(trainable_vars, lr=init_lr,
                                   momentum=momentum,
                                   weight_decay=weight_decay,
                                   nesterov=nesterov)
params.lr_scheduler = ReduceLROnPlateau(params.optimizer, 'min', factor=lr_decay,
                                        patience=10, cooldown=10, verbose=True)
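# ReduceLROnPlateau multiplies the learning rate by `factor` (here lr_decay=0.8)
# whenever the monitored quantity ('min' mode, presumably the validation loss
# reported by Trainer) fails to improve for `patience` epochs, then waits
# `cooldown` epochs before monitoring resumes.

# Train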
trainer = Trainer(model, params, train_dataloader, val_dataloader)
trainer.train()
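# Run directly, e.g. `python classifier_train.py`, with the dataset laid out
# under ./data/ in whatever structure the Hand dataset class expects.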