-
Notifications
You must be signed in to change notification settings - Fork 158
/
config.py
133 lines (112 loc) · 4.68 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import pandas as pd
import torch
from transforms import transforms
from utils.autoaugment import ImageNetPolicy
# Filesystem locations of pretrained checkpoint weights, keyed by model name
# (currently only 'resnet50').
pretrained_model = dict(
    resnet50='./models/pretrained/resnet50-19c8e357.pth',
)
def load_data_transformers(resize_reso=512, crop_reso=448, swap_num=(7, 7)):
    """Build the named image-transform pipelines used by the data loaders.

    Args:
        resize_reso: side length images are resized to before cropping in the
            ``'common_aug'`` and ``'test_totensor'`` pipelines.
        crop_reso: crop side length; also the final resize size in the
            ``'train_totensor'`` and ``'val_totensor'`` pipelines.
        swap_num: pair passed to ``transforms.Randomswap``; only indexes
            ``[0]`` and ``[1]`` are read, so any indexable pair works.
            Presumably the swap grid size (rows, cols) — TODO confirm.

    Returns:
        dict mapping a pipeline name (``'swap'``, ``'common_aug'``,
        ``'train_totensor'``, ``'val_totensor'``, ``'test_totensor'``) to a
        ``transforms.Compose``; the key ``'None'`` maps to ``None``.
    """
    # The usual ImageNet mean/std, shared by all three tensor pipelines.
    # NOTE(review): assumes Normalize is stateless (as in torchvision), so one
    # instance can safely be reused — confirm against the local transforms module.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        'swap': transforms.Compose([
            transforms.Randomswap((swap_num[0], swap_num[1])),
        ]),
        # Geometric augmentation applied before conversion to tensor.
        'common_aug': transforms.Compose([
            transforms.Resize((resize_reso, resize_reso)),
            transforms.RandomRotation(degrees=15),
            transforms.RandomCrop((crop_reso, crop_reso)),
            transforms.RandomHorizontalFlip(),
        ]),
        'train_totensor': transforms.Compose([
            transforms.Resize((crop_reso, crop_reso)),
            # ImageNetPolicy(),  # optional AutoAugment step, disabled
            transforms.ToTensor(),
            normalize,
        ]),
        'val_totensor': transforms.Compose([
            transforms.Resize((crop_reso, crop_reso)),
            transforms.ToTensor(),
            normalize,
        ]),
        # Deterministic resize + center crop for evaluation.
        'test_totensor': transforms.Compose([
            transforms.Resize((resize_reso, resize_reso)),
            transforms.CenterCrop((crop_reso, crop_reso)),
            transforms.ToTensor(),
            normalize,
        ]),
        'None': None,
    }
    return data_transforms
class LoadConfig(object):
    """Collect dataset paths, annotation tables and fixed training switches.

    Args:
        args: parsed command-line namespace; this class reads
            ``args.dataset``, ``args.swap_num`` and ``args.backbone``.
        version: ``'train'``, ``'val'`` or ``'test'``. Selects which
            annotation splits are loaded: ``'train'`` loads the train and val
            splits, the others load only their own split. Each loaded split
            is stored as ``self.<split>_anno``.

    Raises:
        ValueError: if ``version`` or ``args.dataset`` is not recognised.

    Side effects:
        Creates ``./net_model`` and ``./logs`` if they do not exist.
    """

    ###############################
    #### add dataset info here ####
    ###############################
    # name -> (image data root, annotation root, number of classes).
    # Put image data in $PATH/data and annotation txt files in $PATH/anno.
    DATASETS = {
        'product': ('./../FGVC_product/data', './../FGVC_product/anno', 2019),
        'CUB': ('./dataset/CUB_200_2011/data', './dataset/CUB_200_2011/anno', 200),
        'STCAR': ('./dataset/st_car/data', './dataset/st_car/anno', 196),
        'AIR': ('./dataset/aircraft/data', './dataset/aircraft/anno', 100),
    }

    # Annotation splits loaded for each run mode, in load order.
    _SPLITS = {
        'train': ('train', 'val'),
        'val': ('val',),
        'test': ('test',),
    }

    def __init__(self, args, version):
        if version not in self._SPLITS:
            raise ValueError(
                "version must be one of 'train', 'val' or 'test', got {!r}".format(version))
        get_list = self._SPLITS[version]

        if args.dataset not in self.DATASETS:
            raise ValueError('dataset not defined: {!r}'.format(args.dataset))
        self.dataset = args.dataset
        self.rawdata_root, self.anno_root, self.numcls = self.DATASETS[args.dataset]

        # Sets e.g. self.train_anno / self.val_anno / self.test_anno.
        for split in get_list:
            setattr(self, split + '_anno', self._read_anno(self.anno_root, split))

        self.swap_num = args.swap_num
        self.save_dir = './net_model'
        # exist_ok avoids the exists()/mkdir() race of a check-then-create.
        os.makedirs(self.save_dir, exist_ok=True)

        self.backbone = args.backbone
        # Fixed switches; none of these are read from args.
        self.use_dcl = True
        self.use_backbone = not self.use_dcl  # always the opposite of use_dcl
        self.use_Asoftmax = False
        self.use_focal_loss = False
        self.use_fpn = False
        self.use_hier = False
        self.weighted_sample = False
        self.cls_2 = True
        self.cls_2xmul = False

        self.log_folder = './logs'
        os.makedirs(self.log_folder, exist_ok=True)

    @staticmethod
    def _read_anno(anno_root, split):
        """Read ``<anno_root>/ct_<split>.txt`` into a DataFrame.

        The file is space-separated with no header, one
        ``path/image_name cls_num`` pair per line. The columns are named
        ``ImageName`` and ``label``.
        """
        return pd.read_csv(os.path.join(anno_root, 'ct_{}.txt'.format(split)),
                           sep=" ",
                           header=None,
                           names=['ImageName', 'label'])