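"""Training entry point for CRAFT.

Parses the command-line options, builds SynthText (or custom) data loaders,
wraps the CRAFT model in a PyTorch Lightning task with OHEM loss, SGD, and a
step LR scheduler, then runs training with checkpointing and TensorBoard logging.
"""
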
import os
import sys
# append the working directory so the local craft package can be imported
sys.path.append(os.getcwd())
import argparse
from pathlib import Path

import torch
import torch.optim as optim
import pytorch_lightning as pl

from craft.datasets import loader
from craft import nn as nnc
from craft.trainer.task import TaskCRAFT
from craft.trainer import helper as trainer_helper
from craft.models.craft import CRAFT

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train CRAFT on the SynthText dataset')
    parser.add_argument('--resume', default=None, type=str,
                        help='path to a .ckpt or .pth file to resume training from')
    parser.add_argument('--max_epoch', required=True, type=int,
                        help='number of epochs to train for')
    parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float,
                        help='learning rate for the optimizer, default is 0.0001')
    parser.add_argument('--bsize', '--batch_size', default=8, type=int,
                        help='batch size for the data loaders, default is 8')
    # argparse's type=bool treats any non-empty string as True, so parse the flag explicitly
    parser.add_argument('--shuffle', default=True,
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        help='whether to shuffle the training data, default is True')
    parser.add_argument('--num_workers', default=8, type=int,
                        help='number of worker processes for the data loaders')
    parser.add_argument('--dataset_path', required=True, type=str,
                        help='path to the dataset root, e.g. /data/synthtext')
    parser.add_argument('--dataset_type', required=True, type=str,
                        help='dataset type, either "synthtext" or "custom"')
    parser.add_argument('--image_size', default='224x224', type=str,
                        help='width and height of the input image, default is 224x224')
    parser.add_argument('--num_gpus', default=1, type=int,
                        help='number of GPUs to use; 0 runs on CPU, 2 or more enables multi-GPU')
    parser.add_argument('--log_freq', default=10, type=int,
                        help='log every N steps, default is 10')
    parser.add_argument('--checkpoint_dir', default='checkpoints/', type=str,
                        help='directory for saving training checkpoints')
    parser.add_argument('--logs_dir', default='logs/', type=str,
                        help='directory for TensorBoard logs')
    args = parser.parse_args()
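
    # Illustrative invocation (flag values here are placeholders, adjust to your setup):
    #   python train.py --max_epoch 10 --dataset_path /data/synthtext \
    #       --dataset_type synthtext --bsize 8 --num_gpus 1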

    w, h = args.image_size.split('x')
    w, h = int(w), int(h)

    # hyperparameters
    ROOT_PATH = args.dataset_path
    DATA_TYPE = args.dataset_type
    IMSIZE = (w, h)
    BSIZE = args.bsize
    SHUFFLE = args.shuffle
    NWORKERS = args.num_workers
    LRATE = args.lr
    WDECAY = 0.002
    MOMENTUM = 0.9
    SCH_STEP_SIZE = 3
    SCH_GAMMA = 0.1
    MAX_EPOCHS = args.max_epoch
    NUM_GPUS = args.num_gpus
    LOG_FREQ = args.log_freq
    SAVED_CHECKPOINT_PATH = args.checkpoint_dir
    SAVED_LOGS_PATH = args.logs_dir

    CHECKPOINT_RESUME = False
    CHECKPOINT_PATH = None
    WEIGHT_RESUME = False
    WEIGHT_PATH = None

    if args.resume:
        fpath = Path(args.resume)
        if fpath.is_file():
            if fpath.suffix == '.ckpt':
                # a PyTorch Lightning checkpoint: resume the full training state
                CHECKPOINT_RESUME = True
                CHECKPOINT_PATH = str(fpath)
            elif fpath.suffix == '.pth':
                # a plain PyTorch weight file from the original model
                WEIGHT_RESUME = True
                WEIGHT_PATH = str(fpath)
            else:
                raise NotImplementedError(
                    f'Files with the {fpath.suffix} extension are not supported; '
                    f'pass a valid file with a .ckpt or .pth extension.')
        else:
            raise IOError(f'{args.resume} is not a valid PyTorch or PyTorch Lightning file path.')

    if DATA_TYPE == 'synthtext':
        # train and validation loaders
        trainloader = loader.synthtext_trainloader(path=ROOT_PATH, batch_size=BSIZE,
                                                   shuffle=SHUFFLE, nworkers=NWORKERS)
        validloader = loader.synthtext_validloader(path=ROOT_PATH, batch_size=BSIZE,
                                                   shuffle=False, nworkers=NWORKERS)
    elif DATA_TYPE == 'custom':
        trainloader = loader.custom_trainloader(path=ROOT_PATH, batch_size=BSIZE,
                                                shuffle=SHUFFLE, nworkers=NWORKERS)
        # note: the custom train loader is reused (unshuffled) for validation
        validloader = loader.custom_trainloader(path=ROOT_PATH, batch_size=BSIZE,
                                                shuffle=False, nworkers=NWORKERS)
    else:
        raise NotImplementedError('Only the synthtext and custom dataset types are supported; '
                                  'other types are not implemented yet.')

    # model preparation
    if WEIGHT_RESUME:
        print(f'Log:\tPretraining CRAFT with weights from {WEIGHT_PATH}')
        model = CRAFT(pretrained=True)
        weights = torch.load(WEIGHT_PATH, map_location=torch.device('cpu'))
        weights = trainer_helper.copy_state_dict(weights)
        model.load_state_dict(weights)
        # freeze the backbone and fine-tune only the classification head
        trainer_helper.freeze_network(model)
        trainer_helper.unfreeze_conv_cls_module(model)
    else:
        model = CRAFT(pretrained=True)

    criterion = nnc.OHEMLoss()
    optimizer = optim.SGD(model.parameters(), lr=LRATE, weight_decay=WDECAY, momentum=MOMENTUM)
    # decay the learning rate by SCH_GAMMA every SCH_STEP_SIZE epochs
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=SCH_STEP_SIZE, gamma=SCH_GAMMA)
    task = TaskCRAFT(model, criterion, optimizer, scheduler)

    # keep only the single best checkpoint, selected by validation loss
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        dirpath=SAVED_CHECKPOINT_PATH,
        filename='{epoch}-{val_loss:.2f}',
        save_top_k=1,
        verbose=True,
        monitor='val_loss',
        mode='min',
    )
    tensorboard_logger = pl.loggers.TensorBoardLogger(SAVED_LOGS_PATH)
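    # Training curves can be inspected with, e.g.: tensorboard --logdir logs/
    # (substitute the value passed to --logs_dir)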

    if CHECKPOINT_RESUME:
        trainer = pl.Trainer(max_epochs=MAX_EPOCHS, gpus=NUM_GPUS,
                             logger=tensorboard_logger,
                             callbacks=[model_checkpoint],
                             log_every_n_steps=LOG_FREQ,
                             num_sanity_val_steps=0,
                             resume_from_checkpoint=CHECKPOINT_PATH)
    else:
        trainer = pl.Trainer(max_epochs=MAX_EPOCHS, gpus=NUM_GPUS,
                             logger=tensorboard_logger,
                             callbacks=[model_checkpoint],
                             log_every_n_steps=LOG_FREQ,
                             num_sanity_val_steps=0)

    # start training the model
    trainer.fit(task, trainloader, validloader)