model.py (forked from TropComplique/single-shot-detector)
import tensorflow as tf

from detector import SSD
from detector.anchor_generator import AnchorGenerator
from detector.box_predictor import RetinaNetBoxPredictor
from detector.feature_extractor import RetinaNetFeatureExtractor
from detector.backbones import mobilenet_v1, shufflenet_v2, resnet, hrnet
from metrics import Evaluator


MOVING_AVERAGE_DECAY = 0.993
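# an exponential moving average with decay d averages over roughly 1 / (1 - d)
# steps, so 0.993 corresponds to a window of about 140 training steps
# (`num_updates` below additionally shrinks the effective decay early in training)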


def model_fn(features, labels, mode, params):
    """
    Creates the computational graph of the detector.
    The signature and the return value follow the format
    required by `tf.estimator`.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    # the base network
    def backbone(images, is_training):
        if params['backbone'] == 'mobilenet':
            return mobilenet_v1(
                images, is_training,
                depth_multiplier=params['depth_multiplier']
            )
        elif params['backbone'] == 'shufflenet':
            return shufflenet_v2(
                images, is_training,
                depth_multiplier=str(params['depth_multiplier'])
            )
        elif params['backbone'] == 'resnet':
            return resnet(
                images, is_training,
                block_sizes=params['block_sizes'],
                enableBN=params['enableBN']
            )
        # elif params['backbone'] == 'hrnet':
        #     return hrnet(
        #         images, is_training,
        #         width=params['width'],
        #     )
        else:
            raise NotImplementedError('unknown backbone: %s' % params['backbone'])
    # add additional layers to the base network
    feature_extractor = RetinaNetFeatureExtractor(is_training, backbone)

    # ssd anchor maker
    anchor_generator = AnchorGenerator(
        strides=[8, 16, 32, 64, 128],
        scales=[32, 64, 128, 256, 512],
        scale_multipliers=[1.0, 1.4142],
        aspect_ratios=[1.0, 2.0, 0.5]
    )
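    # with 2 scale multipliers and 3 aspect ratios, each feature map cell
    # presumably gets 2 * 3 = 6 anchors per location (an inference from
    # the config above, not something this file asserts)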
    num_anchors_per_location = anchor_generator.num_anchors_per_location

    # add layers that predict boxes and labels
    box_predictor = RetinaNetBoxPredictor(is_training, params['num_classes'], num_anchors_per_location)

    # collect everything in one place
    ssd = SSD(
        features['images'], feature_extractor,
        anchor_generator, box_predictor,
        params['num_classes']
    )
    # add nms to the graph
    if not is_training:
        predictions = ssd.get_predictions(
            score_threshold=params['score_threshold'],
            iou_threshold=params['iou_threshold'],
            max_boxes_per_class=params['max_boxes_per_class']
        )

    if mode == tf.estimator.ModeKeys.PREDICT:

        # because images are resized before
        # feeding them to the network
        box_scaler = features['box_scaler']
        predictions['boxes'] /= box_scaler
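
        # `PredictOutput` defines the serving signature used when the model is
        # exported; wrapping each tensor in `tf.identity` gives the outputs
        # stable, readable names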
        export_outputs = tf.estimator.export.PredictOutput({
            name: tf.identity(tensor, name)
            for name, tensor in predictions.items()
        })
        return tf.estimator.EstimatorSpec(
            mode, predictions=predictions,
            export_outputs={'outputs': export_outputs}
        )
    # add l2 regularization
    with tf.name_scope('weight_decay'):
        add_weight_decay(params['weight_decay'])
        regularization_loss = tf.losses.get_regularization_loss()

    # create localization and classification losses
    losses = ssd.loss(labels, params)
    tf.losses.add_loss(params['localization_loss_weight'] * losses['localization_loss'])
    tf.losses.add_loss(params['classification_loss_weight'] * losses['classification_loss'])
    tf.summary.scalar('regularization_loss', regularization_loss)
    tf.summary.scalar('localization_loss', losses['localization_loss'])
    tf.summary.scalar('classification_loss', losses['classification_loss'])
    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)
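    # at this point
    #   total_loss = w_loc * localization_loss + w_cls * classification_loss + regularization_loss,
    # since `get_total_loss` sums the losses added above with the regularization term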
    if mode == tf.estimator.ModeKeys.EVAL:

        batch_size = features['images'].shape[0].value
        assert batch_size == 1, 'evaluation expects batch size 1'

        evaluator = Evaluator(num_classes=params['num_classes'])
        eval_metric_ops = evaluator.get_metric_ops(labels, predictions)

        return tf.estimator.EstimatorSpec(
            mode, loss=total_loss,
            eval_metric_ops=eval_metric_ops
        )
    assert mode == tf.estimator.ModeKeys.TRAIN
    with tf.variable_scope('learning_rate'):
        global_step = tf.train.get_global_step()
        learning_rate = tf.train.cosine_decay(
            params['initial_learning_rate'],
            global_step, decay_steps=params['num_steps']
        )
        tf.summary.scalar('learning_rate', learning_rate)
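    # with the default alpha=0, `tf.train.cosine_decay` follows
    #   lr(step) = initial_learning_rate * 0.5 * (1 + cos(pi * step / num_steps)),
    # decaying smoothly from the initial value to zero over `num_steps`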
    # TODO: SyncBN support
    # note: `update_ops` must exist even when batch norm is disabled,
    # otherwise the `control_dependencies` call below would fail
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) if params['enableBN'] else []

    with tf.control_dependencies(update_ops), tf.variable_scope('optimizer'):
        var_list = tf.trainable_variables()
        if 'freeze_at' in params:
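            # remove frozen variables from var_list: a minimal sketch, assuming
            # `params['freeze_at']` names a variable scope prefix (e.g. a backbone
            # scope) whose weights should stay fixed; the original left this as a TODO
            freeze_prefix = str(params['freeze_at'])
            var_list = [v for v in var_list if not v.name.startswith(freeze_prefix)]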

        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
        grads_and_vars = optimizer.compute_gradients(total_loss, var_list)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    for g, v in grads_and_vars:
        # `v.name` ends with ':0', strip it to get a clean summary name
        tf.summary.histogram(v.name[:-2] + '_hist', v)
        tf.summary.histogram(v.name[:-2] + '_grad_hist', g)
    # TODO: check if ema helps.
    with tf.control_dependencies([train_op]), tf.name_scope('ema'):
        # chaining on `train_op` makes the ema update run after each
        # optimizer step, so the returned op does both
        ema = tf.train.ExponentialMovingAverage(decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
        train_op = ema.apply(tf.trainable_variables())

    return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op)


def add_weight_decay(weight_decay):
    """Add L2 regularization to all (or some) trainable kernel weights."""
    weight_decay = tf.constant(
        weight_decay, tf.float32,
        [], 'weight_decay'
    )
    trainable_vars = tf.trainable_variables()
    # depthwise kernels are excluded, following the common practice
    # of using little or no weight decay on depthwise filters
    kernels = [
        v for v in trainable_vars
        if ('weights' in v.name or 'kernel' in v.name) and 'depthwise_weights' not in v.name
    ]
    for kernel in kernels:
        x = tf.multiply(weight_decay, tf.nn.l2_loss(kernel))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, x)
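    # note: `tf.nn.l2_loss(w)` computes sum(w ** 2) / 2, so each kernel
    # contributes weight_decay * ||w||^2 / 2 to the collection above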


class RestoreMovingAverageHook(tf.train.SessionRunHook):
    """Replaces variables with their moving averages before evaluation."""

    def __init__(self, model_dir):
        super(RestoreMovingAverageHook, self).__init__()
        self.model_dir = model_dir

    def begin(self):
        # `variables_to_restore` maps each variable to the name
        # of its moving average in the checkpoint
        ema = tf.train.ExponentialMovingAverage(decay=MOVING_AVERAGE_DECAY)
        variables_to_restore = ema.variables_to_restore()
        self.load_ema = tf.contrib.framework.assign_from_checkpoint_fn(
            tf.train.latest_checkpoint(self.model_dir), variables_to_restore
        )

    def after_create_session(self, sess, coord):
        tf.logging.info('Loading EMA weights...')
        self.load_ema(sess)
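

# A usage sketch (illustrative only: the params keys mirror what `model_fn`
# reads above, but the concrete values and the input_fn names are assumptions):
#
#   params = {
#       'backbone': 'shufflenet', 'depth_multiplier': 1.0, 'num_classes': 80,
#       'enableBN': True, 'weight_decay': 1e-4,
#       'localization_loss_weight': 1.0, 'classification_loss_weight': 1.0,
#       'score_threshold': 0.05, 'iou_threshold': 0.5, 'max_boxes_per_class': 20,
#       'initial_learning_rate': 0.01, 'num_steps': 150000,
#   }
#   estimator = tf.estimator.Estimator(model_fn, model_dir='models/run00', params=params)
#   estimator.train(train_input_fn, max_steps=params['num_steps'])
#   estimator.evaluate(
#       eval_input_fn,
#       hooks=[RestoreMovingAverageHook('models/run00')]
#   )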