engine_pretrain.py
# ----------------------------------------------------------------------
# HAP: Structure-Aware Masked Image Modeling for Human-Centric Perception
# Written by Junkun Yuan (yuanjk0921@outlook.com)
# ----------------------------------------------------------------------
# Training file for pre-training
# ----------------------------------------------------------------------
# References:
# MALE: https://github.com/YanzuoLu/MALE
# ----------------------------------------------------------------------
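# Typical driver loop (illustrative sketch; the entry script that builds the model,
# loader, optimizer, and scaler is not part of this file, so the calling pattern
# below is an assumption rather than the project's confirmed code):
#
#   for epoch in range(start_epoch, cfg.TRAIN.EPOCHS):
#       if cfg.DIST.WORLD_SIZE > 1:
#           train_loader.sampler.set_epoch(epoch)  # reshuffle per epoch under DDP
#       train_one_epoch(cfg, train_loader, model, optimizer, scaler, epoch,
#                       device, summary_writer, logger, **mask_pose)
# ----------------------------------------------------------------------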
import os
import sys
import time
import math
import json
import datetime
import torch
import torch.distributed as dist
from timm.utils import AverageMeter
from utils import adjust_learning_rate
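
# Configuration fields read by train_one_epoch (names taken from the code below;
# the config object is assumed to be a yacs-style CfgNode, which is an assumption
# and not confirmed by this file):
#   cfg.MODEL.NAME, cfg.MODEL.MASK_RATIO, cfg.MODEL.ALIGN
#   cfg.DATA.NAME, cfg.DATA.ACCUM_ITER
#   cfg.TRAIN.EPOCHS, cfg.TRAIN.WARMUP_EPOCHS, cfg.TRAIN.LR, cfg.TRAIN.PRINT_FREQ
#   cfg.DIST.WORLD_SIZE, cfg.PRINT, cfg.OUTPUT_DIR
# Each batch from train_loader is expected to unpack as
# (samples, keypoints, num_kps, keypoints_all).
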
def train_one_epoch(cfg, train_loader, model, optimizer, scaler, epoch, device, summary_writer, logger, **mask_pose):
    """Run one epoch of masked-image-modeling pre-training."""
    batch_time = AverageMeter()
    losses = AverageMeter()
    if cfg.MODEL.NAME.startswith('pose_mae'):
        # Track the alignment and reconstruction terms separately for pose-aware models.
        losses_ali = AverageMeter()
        losses_rec = AverageMeter()
    accum_iter = cfg.DATA.ACCUM_ITER
    GB = 1024. ** 3  # bytes per gigabyte, used for GPU-memory reporting

    model.train()
    num_steps = len(train_loader)

    start = time.time()
    end = time.time()
    for data_iter_step, (samples, keypoints, num_kps, keypoints_all) in enumerate(train_loader):
        # Adjust the learning rate per iteration (not per epoch), once per accumulation cycle.
        if data_iter_step % accum_iter == 0:
            lr_decayed = adjust_learning_rate(optimizer, data_iter_step / num_steps + epoch, cfg.TRAIN.EPOCHS, cfg.TRAIN.WARMUP_EPOCHS, cfg.TRAIN.LR)

        samples_train = samples.to(device, non_blocking=True)

        # Forward pass under automatic mixed precision.
        with torch.cuda.amp.autocast():
            if cfg.MODEL.NAME.startswith('pose_mae'):
                align = cfg.MODEL.ALIGN > 0
                loss_ali, loss_rec, _, _ = model(samples_train, cfg.MODEL.MASK_RATIO, align=align, keypoints=keypoints, num_kps=num_kps, keypoints_all=keypoints_all, **mask_pose)
                loss = loss_ali * cfg.MODEL.ALIGN + loss_rec
            else:
                if cfg.DATA.NAME == 'LUPerson':
                    loss, _, _ = model(samples_train, cfg.MODEL.MASK_RATIO)
                elif cfg.DATA.NAME == 'LUPersonPose':
                    loss, _, _ = model(samples_train, cfg.MODEL.MASK_RATIO, keypoints=keypoints, num_kps=num_kps, keypoints_all=keypoints_all, **mask_pose)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            logger.warning(f'Loss is {loss_value}, stopping training')
            sys.exit(1)

        # Scale the loss for gradient accumulation; the scaler only steps on update iterations.
        loss /= accum_iter
        scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()
        torch.cuda.synchronize()

        losses.update(loss_value)
        if cfg.MODEL.NAME.startswith('pose_mae'):
            losses_ali.update(loss_ali.item())
            losses_rec.update(loss_rec.item())
        batch_time.update(time.time() - end)
        end = time.time()

        if data_iter_step % cfg.TRAIN.PRINT_FREQ == 0:
            etas = batch_time.avg * (num_steps - data_iter_step)  # estimated time remaining (used by the commented-out 'eta' line)
            if cfg.MODEL.NAME.startswith('pose_mae'):
                logger.info(
                    f'Epoch: [{epoch}/{cfg.TRAIN.EPOCHS}] ({data_iter_step}/{num_steps}) '
                    f'loss: {losses.val:.4f} ({losses.avg:.4f}) '
                    f'loss_ali: {losses_ali.val:.4f} ({losses_ali.avg:.4f}) '
                    f'loss_rec: {losses_rec.val:.4f} ({losses_rec.avg:.4f}) '
                    f'lr: {lr_decayed:.4e} '
                    # f'time: {batch_time.val:.4f} ({batch_time.avg:.4f}) '
                    # f'eta: {datetime.timedelta(seconds=int(etas))} '
                )
            else:
                logger.info(
                    f'Epoch: [{epoch}/{cfg.TRAIN.EPOCHS}] ({data_iter_step}/{num_steps}) '
                    f'loss: {losses.val:.4f} ({losses.avg:.4f}) '
                    f'lr: {lr_decayed:.4e} '
                    # f'time: {batch_time.val:.4f} ({batch_time.avg:.4f}) '
                    # f'eta: {datetime.timedelta(seconds=int(etas))} '
                )

        # Average the loss across processes before logging to TensorBoard.
        if cfg.DIST.WORLD_SIZE > 1:
            loss_value_reduce = torch.tensor(loss_value).cuda()
            dist.all_reduce(loss_value_reduce)
            loss_value = (loss_value_reduce / cfg.DIST.WORLD_SIZE).item()

        if summary_writer:
            # Use an x-axis in units of 1/1000 of an epoch so runs with different batch sizes stay comparable.
            epoch_1000x = int((data_iter_step / num_steps + epoch) * 1000)
            summary_writer.add_scalar('loss', loss_value, epoch_1000x)
            summary_writer.add_scalar('lr', lr_decayed, epoch_1000x)
    epoch_time = time.time() - start
    epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
    max_mem = torch.cuda.max_memory_allocated() / GB
    logger.info(f'EPOCH TIME: {epoch_time_str} GPU MEMORY: {max_mem:.2f} GB')

    # Flush TensorBoard and append per-epoch results when logging is enabled for this process.
    if cfg.PRINT:
        if summary_writer is not None:
            summary_writer.flush()
        with open(os.path.join(cfg.OUTPUT_DIR, 'result.txt'), mode='a', encoding='utf-8') as f:
            if cfg.MODEL.NAME.startswith('pose_mae'):
                f.write(json.dumps({'epoch': epoch, 'loss': losses.avg, 'loss_ali': losses_ali.avg, 'loss_rec': losses_rec.avg, 'lr': lr_decayed}) + '\n')
            else:
                f.write(json.dumps({'epoch': epoch, 'loss': losses.avg, 'lr': lr_decayed}) + '\n')
            f.flush()
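

# The `scaler` argument above is invoked as
#   scaler(loss, optimizer, parameters=..., update_grad=...).
# A minimal sketch of a compatible AMP loss scaler is given below for reference only.
# It mirrors the NativeScaler-style wrappers used by MAE/timm, but the scaler actually
# used with this file is constructed elsewhere, so treat the class name and the
# optional `clip_grad` argument as illustrative assumptions.
class _NativeScalerSketch:
    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, update_grad=True):
        # Scale the loss so small fp16 gradients do not underflow, then accumulate them.
        self._scaler.scale(loss).backward()
        if update_grad:
            # Bring gradients back to their true scale before (optional) clipping and stepping.
            self._scaler.unscale_(optimizer)
            if clip_grad is not None and parameters is not None:
                torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            self._scaler.step(optimizer)
            self._scaler.update()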