forked from memesoo99/regan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_fewshot.py
61 lines (44 loc) · 1.59 KB
/
train_fewshot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch.nn.functional as F
import pickle
import torch
import numpy as np
import time
import random
import copy
import argparse
import cv2
from tqdm import tqdm
from utils.auto import load_yaml
# ---------------------------------------------------------------------------
# Few-shot training script.
#
# Reads a YAML config (via --config_path), loads a pickled dict holding
# 'features' and 'labels', trains a FewShotCNN on the first `n_samples`
# entries for 100 epochs, and saves the whole model object to config['PATH'].
#
# Config keys read here: 'PATH', 'data_dir', 'classes'.
# ---------------------------------------------------------------------------

# Command-line interface: the only option is the path to the config file.
parser = argparse.ArgumentParser(description='Hyperparams')
parser.add_argument('--config_path', help='config file path')
args = parser.parse_args()
config = load_yaml(args.config_path, args)

# Kept at its original position (after the config is loaded).
# NOTE(review): whether this import depends on the config being loaded first
# is not visible from this file -- confirm before hoisting it to the top.
from model.segmentation_model import FewShotCNN

# Only the first `n_samples` entries of the dataset are used for training.
n_samples = 2
PATH = config['PATH']

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# SECURITY: pickle.load can execute arbitrary code while unpickling.
# Only point config['data_dir'] at files from a trusted source.
with open(config['data_dir'], "rb") as fw:
    data = pickle.load(fw)

# 'classes' is a comma-separated string; its length sets the second
# constructor argument of the network.
classes = config['classes'].split(',')

# NOTE(review): presumably data['features'] is (N, C, H, W) so shape[1] is
# the input channel count -- confirm against FewShotCNN's signature.
net = FewShotCNN(data['features'].shape[1], len(classes), size='S')

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.001)
# The scheduler is stepped once per epoch below, so the learning rate is
# multiplied by 0.1 after every 50 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

# cross_entropy expects integer class indices, hence the cast to long.
labels = torch.tensor(data['labels']).long()

net.train().to(device)
start_time = time.time()

for epoch in tqdm(range(1, 100+1)):
    # Visit the training samples in a fresh random order every epoch.
    sample_order = list(range(n_samples))
    random.shuffle(sample_order)

    for idx in sample_order:
        # Add a leading batch dimension of size 1: one sample per step.
        sample = data['features'][idx].unsqueeze(0).to(device)
        label = labels[idx].unsqueeze(0).to(device)

        out = net(sample)
        loss = F.cross_entropy(out, label, reduction='mean')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if epoch % 50 == 0:
        # `loss` here is the loss of the last sample visited in this epoch,
        # not an average over the epoch.
        print(f'{epoch:5}-th epoch | loss: {loss.item():6.4f} | time: {time.time()-start_time:6.1f}sec')

    scheduler.step()

torch.save(net, PATH)  # save the entire model object (not just the state_dict)
print('Done! model saved to ',PATH)