-
Notifications
You must be signed in to change notification settings - Fork 40
/
finetune.py
182 lines (156 loc) · 5.82 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import argparse
import os
import torch
import clip
import os
from tqdm import tqdm
import time
from timm.data.transforms_factory import transforms_imagenet_train
from datasets.imagenet import ImageNet98p, ImageNet
from utils import ModelWrapper, maybe_dictionarize_batch, cosine_lr
from zeroshot import zeroshot_classifier
from openai_imagenet_template import openai_imagenet_template
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data-location",
type=str,
default=os.path.expanduser('~/data'),
help="The root directory for the datasets.",
)
parser.add_argument(
"--model-location",
type=str,
default=os.path.expanduser('~/ssd/checkpoints/soups'),
help="Where to download the models.",
)
parser.add_argument(
"--batch-size",
type=int,
default=256,
)
parser.add_argument(
"--custom-template", action="store_true", default=False,
)
parser.add_argument(
"--workers",
type=int,
default=8,
)
parser.add_argument(
"--epochs",
type=int,
default=8,
)
parser.add_argument(
"--warmup-length",
type=int,
default=500,
)
parser.add_argument(
"--lr",
type=float,
default=2e-5,
)
parser.add_argument(
"--wd",
type=float,
default=0.1,
)
parser.add_argument(
"--model",
default='ViT-B/32',
help='Model to use -- you can try another like ViT-L/14'
)
parser.add_argument(
"--name",
default='finetune_cp',
help='Filename for the checkpoints.'
)
parser.add_argument(
"--timm-aug", action="store_true", default=False,
)
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
DEVICE = 'cuda'
if args.custom_template:
template = [lambda x : f"a photo of a {x}."]
else:
template = openai_imagenet_template
base_model, preprocess = clip.load(args.model, 'cuda', jit=False)
# 98p is the 98% of ImageNet train set that we train on -- the other 2% is hodl-out val.
if args.timm_aug:
train_preprocess = transforms_imagenet_train(
img_size=base_model.visual.input_resolution,
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)
)
else:
train_preprocess = preprocess
train_dset = ImageNet98p(train_preprocess, location=args.data_location, batch_size=args.batch_size, num_workers=args.workers)
test_dset = ImageNet(preprocess, location=args.data_location, batch_size=args.batch_size, num_workers=args.workers)
clf = zeroshot_classifier(base_model, train_dset.classnames, template, DEVICE)
NUM_CLASSES = len(train_dset.classnames)
feature_dim = base_model.visual.output_dim
model = ModelWrapper(base_model, feature_dim, NUM_CLASSES, normalize=True, initial_weights=clf)
for p in model.parameters():
p.data = p.data.float()
model = model.cuda()
devices = [x for x in range(torch.cuda.device_count())]
model = torch.nn.DataParallel(model, device_ids=devices)
model_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(model_parameters, lr=args.lr, weight_decay=args.wd)
num_batches = len(train_dset.train_loader)
scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches)
loss_fn = torch.nn.CrossEntropyLoss()
model_path = os.path.join(args.model_location, f'{args.name}_0.pt')
print('Saving model to', model_path)
torch.save(model.module.state_dict(), model_path)
for epoch in range(args.epochs):
# Train
model.train()
end = time.time()
for i, batch in enumerate(train_dset.train_loader):
step = i + epoch * num_batches
scheduler(step)
optimizer.zero_grad()
batch = maybe_dictionarize_batch(batch)
inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)
data_time = time.time() - end
logits = model(inputs)
loss = loss_fn(logits, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
batch_time = time.time() - end
end = time.time()
if i % 20 == 0:
percent_complete = 100.0 * i / len(train_dset.train_loader)
print(
f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(train_dset.train_loader)}]\t"
f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True
)
# #Evaluate
test_loader = test_dset.test_loader
model.eval()
with torch.no_grad():
print('*'*80)
print('Starting eval')
correct, count = 0.0, 0.0
pbar = tqdm(test_loader)
for batch in pbar:
batch = maybe_dictionarize_batch(batch)
inputs, labels = batch['images'].to(DEVICE), batch['labels'].to(DEVICE)
logits = model(inputs)
loss = loss_fn(logits, labels)
pred = logits.argmax(dim=1, keepdim=True)
correct += pred.eq(labels.view_as(pred)).sum().item()
count += len(logits)
pbar.set_description(
f"Val loss: {loss.item():.4f} Acc: {100*correct/count:.2f}")
top1 = correct / count
print(f'Val acc at epoch {epoch}: {100*top1:.2f}')
model_path = os.path.join(args.model_location, f'{args.name}_{epoch + 1}.pt')
print('Saving model to', model_path)
torch.save(model.module.state_dict(), model_path)