-
Notifications
You must be signed in to change notification settings - Fork 3
/
encode_image.py
147 lines (120 loc) · 4.94 KB
/
encode_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from data_utils import *
dtype = torch.cuda.FloatTensor
path = '/flikr8k/'
output_path = '/output/'
def preprocess(img, augment=False):
    """Convert a PIL image into a normalized (1, 3, 224, 224) float tensor.

    Normalization uses the standard ImageNet channel statistics. When
    ``augment`` is True, random color jitter, a random resized crop and a
    random horizontal flip are applied before tensor conversion.
    """
    steps = [T.Resize((224, 224))]
    if augment:
        steps += [
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            T.RandomResizedCrop(224, scale=(0.75, 1.0)),
            T.RandomHorizontalFlip(),
        ]
    steps += [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
        T.Lambda(lambda x: x[None]),  # prepend a batch dimension
    ]
    return T.Compose(steps)(img)
def batchify(files):
    """Load the given Flickr8k image ids and stack them into one batch.

    Parameters
    ----------
    files : sequence of image filenames relative to ``Flicker8k_Dataset/``.

    Returns
    -------
    Tensor of shape (len(files), 3, 224, 224) produced by ``preprocess``.
    """
    xs = []
    for im_id in files:
        # Use a context manager so the file handle is closed (the original
        # leaked one handle per image).  convert('RGB') guards against
        # grayscale/RGBA files, which would otherwise break the 3-channel
        # T.Normalize inside preprocess().
        with Image.open(path + 'Flicker8k_Dataset/' + im_id) as im:
            xs.append(preprocess(im.convert('RGB')))
    return torch.cat(xs)
def feed_forward_net(xs, net, pre_net, spatial=False):
    """Run a batch through a pretrained CNN, stopping before its classifier.

    The layer slicing below is tied to the exact child-module ordering of
    each torchvision model, so the statement order must not be changed.

    Parameters
    ----------
    xs : input image batch tensor (expected (N, 3, 224, 224) from
        preprocess — TODO confirm for inception_v3, which normally
        expects 299x299 input).
    net : the pretrained torchvision model instance.
    pre_net : one of 'inception_v3', 'densenet161', 'resnet101', 'vgg16'.
    spatial : if True, keep the spatial feature map (resnet101/vgg16 only);
        otherwise return pooled/flattened feature vectors.

    Returns
    -------
    Feature tensor; squeezed to (N, C) for the non-spatial paths.
    NOTE(review): if the result is not a known pre_net, xs is returned
    unchanged (no error is raised).
    """
    if pre_net == 'inception_v3':
        # Replays inception's forward pass by iterating its child modules,
        # inserting the max-pools that inception applies between stages
        # (they are not registered as child modules).
        for module_name in list(net._modules.keys())[:3]:
            xs = net._modules[module_name](xs)
        xs = F.max_pool2d(xs, kernel_size=3, stride=2)
        for module_name in list(net._modules.keys())[3:5]:
            xs = net._modules[module_name](xs)
        xs = F.max_pool2d(xs, kernel_size=3, stride=2)
        for module_name in list(net._modules.keys())[5:-5]:
            xs = net._modules[module_name](xs)
        # [-4:-1] skips the AuxLogits branch and the final fc layer.
        for module_name in list(net._modules.keys())[-4:-1]:
            xs = net._modules[module_name](xs)
        xs = F.avg_pool2d(xs, kernel_size=8)
        return xs.squeeze()
    if pre_net == 'densenet161':
        # densenet exposes everything before the classifier as .features.
        xs = net.features(xs)
        xs = F.relu(xs, inplace=True)
        xs = F.avg_pool2d(xs, kernel_size=7, stride=1).view(xs.size(0), -1)
        return xs.squeeze()
    if pre_net == 'resnet101':
        # Drop the final fc layer; also drop the avgpool when spatial
        # feature maps are requested.
        l = 2 if spatial else 1
        net_modules = list(net._modules.keys())[:-l]
        for module_name in net_modules:
            xs = net._modules[module_name](xs)
        return xs if spatial else xs.squeeze()
    if pre_net == 'vgg16':
        xs = net.features(xs)
        if spatial:
            return xs
        else:
            # Flatten and run the fully-connected classifier head.
            xs = xs.view(xs.size(0), -1)
            xs = net.classifier(xs)
            return xs
def encode_image(files, batch_size, spatial, pre_net):
    """Encode images into CNN feature vectors/maps, batch by batch.

    Parameters
    ----------
    files : sequence of image filenames (keys of the returned dict).
    batch_size : number of images processed per forward pass.
    spatial : whether to keep spatial feature maps (see feed_forward_net).
    pre_net : one of 'inception_v3', 'densenet161', 'resnet101', 'vgg16'.

    Returns
    -------
    dict mapping each filename to its numpy feature array.

    Raises
    ------
    ValueError : if ``pre_net`` is not a supported architecture name.
    """
    if pre_net == 'inception_v3':
        net = models.inception_v3(pretrained=True)
    elif pre_net == 'densenet161':
        net = models.densenet161(pretrained=True)
    elif pre_net == 'resnet101':
        net = models.resnet101(pretrained=True)
    elif pre_net == 'vgg16':
        net = models.vgg16_bn(pretrained=True)
    else:
        # The original fell through to a NameError on `net`; fail fast
        # with a clear message instead.
        raise ValueError('unknown pre_net: %r' % pre_net)
    net.type(dtype)
    net.eval()

    encoded_images = {}
    num_batches = (len(files) + batch_size - 1) // batch_size  # ceil division
    # Inference only: no_grad avoids building autograd graphs, which the
    # original did on every batch (wasting memory).  Outputs are unchanged.
    with torch.no_grad():
        for i in range(num_batches):
            batch_files = files[i * batch_size:(i + 1) * batch_size]
            xs = batchify(batch_files).type(dtype)
            xs = feed_forward_net(xs, net, pre_net, spatial)
            xs = xs.cpu().detach().numpy()
            # NOTE(review): a final batch of size 1 would be squeezed to a
            # 1-D array by feed_forward_net, misaligning this zip — present
            # in the original as well; confirm batch sizes before relying on it.
            for im_id, feat in zip(batch_files, xs):
                encoded_images[im_id] = feat
    return encoded_images
def main(args):
    """Encode train/test Flickr8k images and pickle the feature dicts."""
    trn_files, test_files = split_image_files(path)
    train_features = encode_image(trn_files, args.batch_size, args.spatial, args.pre_net)
    test_features = encode_image(test_files, args.batch_size, args.spatial, args.pre_net)

    # Report sample counts and a representative feature shape per split.
    for split, feats in (('train', train_features), ('test', test_features)):
        print('number of %s samples:' % split, len(feats))
        print('%s features shape:' % split, list(feats.values())[0].shape)

    with open(output_path + args.train_feat_path + '.pkl', 'wb') as f:
        pickle.dump({'train_features': train_features}, f)
    with open(output_path + args.test_feat_path + '.pkl', 'wb') as f:
        pickle.dump({'test_features': test_features}, f)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_feat_path', type=str,
                        default='train_feat_arrays',
                        help='path for train image features file')
    parser.add_argument('--test_feat_path', type=str,
                        default='test_feat_array',
                        help='path for test image features file')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='batch size for transforming images')
    # BUG FIX: type=bool is an argparse trap — bool('0') and bool('False')
    # are both True, so `--spatial 0` silently ENABLED spatial features.
    # Parsing as int (0/1) keeps the value-taking CLI syntax and the
    # truthy checks downstream, while making 0 actually mean off.
    parser.add_argument('--spatial', type=int, default=0,
                        help='whether to retain spatial info in features (0/1)')
    # argparse converts the dash, so this is read as args.pre_net.
    parser.add_argument('--pre-net', type=str, default='resnet101',
                        help='which pretrained cnn to use')
    args = parser.parse_args()
    main(args)