-
Notifications
You must be signed in to change notification settings - Fork 3
/
run_vgg_vanilla.py
45 lines (35 loc) · 1.21 KB
/
run_vgg_vanilla.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import numpy as np
import torch
import os
from src.dataset import load_data
from src.utils import img_preprocess, setup_seed, predict, eval_metric, feature_map_size
from src.utils import train
from src.models import VGG
__data_set__ = 'cifar10'
__save_ckpt__ = './checkpoints/{}/vgg_vanilla.pt'.format(__data_set__)
if not os.path.exists('./checkpoints/{}'.format(__data_set__)):
os.makedirs('./checkpoints/{}'.format(__data_set__))
# set random seed
setup_seed(2020)
# load data & preprocess
x_tr, y_tr, x_va, y_va, x_te, y_te = load_data(__data_set__)
all_tr_idx = np.arange(len(x_tr))
num_class = np.unique(y_va).shape[0]
x_tr, y_tr = img_preprocess(x_tr, y_tr,)
x_va, y_va = img_preprocess(x_va, y_va,)
x_te, y_te = img_preprocess(x_te, y_te,)
# load model
model = VGG(num_classes=num_class, dropout_rate=0.0, last_feature_map_size=feature_map_size(__data_set__))
model.cuda()
# start training model
train(model, all_tr_idx, x_tr, y_tr, x_va, y_va,
num_epoch=20,
batch_size=32, # 32
lr=1e-4,
weight_decay=0,
early_stop_ckpt_path=__save_ckpt__,
early_stop_tolerance=3)
# evaluate test acc
pred_te = predict(model, x_te)
acc_te = eval_metric(pred_te, y_te, num_class)
print("test acc:", acc_te)