model_train.py
# Preparation and Preprocessing
# Import libraries
from final_model_config import *
from dataloader import *  # expected to provide the Dataset class
from utils import *       # expected to provide get_training_augmentation, get_preprocessing, save_plots
import torch
from torch.utils.data import DataLoader
from segmentation_models_pytorch import utils as smp_utils  # aliased to avoid shadowing the local utils module
def model_train():
    # Paths to folders containing training/validation images and masks
    x_train_dir = Final_Config.INPUT_IMG_DIR + '/train'
    y_train_dir = Final_Config.INPUT_MASK_DIR + '/train'
    x_val_dir = Final_Config.INPUT_IMG_DIR + '/val'
    y_val_dir = Final_Config.INPUT_MASK_DIR + '/val'
    # Helper functions for transfer learning: freeze the encoder weights so
    # that only the decoder is trained at first
    def freeze_encoder(model):
        for child in model.encoder.children():
            for param in child.parameters():
                param.requires_grad = False

    def unfreeze(model):
        for child in model.children():
            for param in child.parameters():
                param.requires_grad = True
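    # Note: unfreeze() is defined but never called below; presumably it could
    # be invoked after some number of epochs to fine-tune the full network,
    # encoder included.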
    model = Final_Config.MODEL
    freeze_encoder(model)
    # Create training and validation datasets with augmentations and proper
    # preprocessing. To train without augmentations, set augmentation=None.
    train_dataset = Dataset(
        x_train_dir,
        y_train_dir,
        augmentation=get_training_augmentation(),
        preprocessing=get_preprocessing(Final_Config.PREPROCESS)
    )
    val_dataset = Dataset(
        x_val_dir,
        y_val_dir,
        preprocessing=get_preprocessing(Final_Config.PREPROCESS)
    )
    train_loader = DataLoader(train_dataset, batch_size=Final_Config.TRAIN_BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=Final_Config.VAL_BATCH_SIZE, shuffle=False, num_workers=0)
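    # num_workers=0 loads batches in the main process; raising it enables
    # parallel data loading where the platform supports it.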
    # Create epoch runners for iterating over the dataloaders' samples
    train_epoch = smp_utils.train.TrainEpoch(
        model,
        loss=Final_Config.LOSS,
        metrics=Final_Config.METRICS,
        optimizer=Final_Config.OPTIMIZER,
        device=Final_Config.DEVICE,
        verbose=True,
    )
    valid_epoch = smp_utils.train.ValidEpoch(
        model,
        loss=Final_Config.LOSS,
        metrics=Final_Config.METRICS,
        device=Final_Config.DEVICE,
        verbose=True,
    )
    # Train the model and save the best weights
    max_score = 0
    # Lists to track F1 scores and losses per epoch
    train_acc = []
    train_loss = []
    val_acc = []
    val_loss = []
    for i in range(0, Final_Config.EPOCHS):
        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        val_logs = valid_epoch.run(val_loader)
        # Print and log the F1-score (logs are keyed by the metric's name, 'fscore')
        print(train_logs['fscore'])
        print(val_logs['fscore'])
        train_acc.append(train_logs['fscore'])
        val_acc.append(val_logs['fscore'])
        # Print and log the loss (keyed by the configured loss's name, 'CE_Dice')
        print(train_logs['CE_Dice'])
        print(val_logs['CE_Dice'])
        train_loss.append(train_logs['CE_Dice'])
        val_loss.append(val_logs['CE_Dice'])
        # Save the model whenever the validation F1-score improves
        if max_score < val_logs['fscore']:
            max_score = val_logs['fscore']
            torch.save(model, Final_Config.WEIGHT_DIR + '.pth')
            print('Model saved!')
        # Optional learning-rate decay: after epoch 35, drop the learning rate to 1e-5
        if i == 35:
            Final_Config.OPTIMIZER.param_groups[0]['lr'] = 1e-5
            print('Decrease decoder learning rate to 1e-5!')
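        # A similar decay could also be expressed with a scheduler from
        # torch.optim.lr_scheduler, e.g. MultiStepLR(Final_Config.OPTIMIZER,
        # milestones=[35]) stepped once per epoch; the manual assignment above
        # is kept as-is.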
    # Save the loss and accuracy plots
    save_plots(
        train_acc, val_acc, train_loss, val_loss,
        Final_Config.PLOT_DIR + '_accuracy.png',
        Final_Config.PLOT_DIR + '_loss.png',
    )
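

# Minimal entry point, assuming this script is meant to be run directly;
# model_train() takes no arguments and reads everything from Final_Config.
if __name__ == '__main__':
    model_train()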