-
Notifications
You must be signed in to change notification settings - Fork 0
/
deform_conv.py
273 lines (211 loc) · 9.1 KB
/
deform_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
from __future__ import absolute_import, division
import torch
from torch.autograd import Variable
import numpy as np
from scipy.ndimage.interpolation import map_coordinates as sp_map_coordinates
def th_flatten(a):
    """Return ``a`` as a 1-D tensor holding all of its elements in row-major order.

    ``a`` is made contiguous first, so the result is a view of that
    contiguous tensor (and of ``a`` itself when ``a`` was already contiguous).
    """
    dense = a.contiguous()
    return dense.view(dense.nelement())
def th_repeat(a, repeats, axis=0):
    """Torch counterpart of ``np.repeat`` for a 1-D tensor.

    Each element of ``a`` is repeated ``repeats`` times consecutively, e.g.
    ``[1, 2]`` with ``repeats=3`` gives ``[1, 1, 1, 2, 2, 2]``. ``axis`` is
    accepted for signature parity with ``np.repeat`` but is ignored.
    """
    assert len(a.size()) == 1
    # (repeats, n) stacked copies -> transpose to (n, repeats) -> read row-major.
    columns = a.repeat(repeats, 1).t()
    return columns.contiguous().view(columns.nelement())
def np_repeat_2d(a, repeats):
    """Stack ``repeats`` copies of the 2-D NumPy array ``a`` along a new leading axis.

    Returns an array of shape ``(repeats, a.shape[0], a.shape[1])``. Despite
    the name this is ``np.tile`` over a new axis, not element-wise ``np.repeat``.
    """
    assert len(a.shape) == 2
    return np.tile(np.expand_dims(a, 0), (repeats, 1, 1))
def th_gather_2d(input, coords):
    """Gather ``input[coords[k, 0], coords[k, 1]]`` for every row k of ``coords``.

    ``input`` is a 2-D tensor and ``coords`` an (n, 2) tensor of integer
    (row, col) indices. Returns a 1-D tensor of length n.
    """
    rows, cols = coords[:, 0], coords[:, 1]
    # Row-major linear index into the flattened map.
    linear = rows * input.size(1) + cols
    flat = input.contiguous().view(input.nelement())
    return flat.index_select(0, linear).view(coords.size(0))
def th_map_coordinates(input, coords, order=1):
    """PyTorch version of scipy.ndimage.map_coordinates (bilinear, 2-D only).

    Note that coords is transposed relative to SciPy: one (row, col) pair per
    row instead of one axis per row.

    Parameters
    ----------
    input : torch.Tensor. shape = (h, w)
    coords : torch.Tensor. shape = (n_points, 2)
        Floating-point (row, col) sample positions. Rows are clamped to
        [0, h - 1] and columns to [0, w - 1], so out-of-range points take the
        border value.
    order : int
        Interpolation order; only 1 (bilinear) is supported.

    Returns
    -------
    torch.Tensor. shape = (n_points,)
    """
    assert order == 1
    input_height = input.size(0)
    input_width = input.size(1)
    # Clamp each axis against its own extent. Clamping both axes to
    # size(0) - 1 is only valid for square maps: on tall maps column indices
    # spill into the next row (or past the end of the buffer), and on wide
    # maps the right-hand columns can never be sampled.
    coords = torch.stack([
        torch.clamp(coords[:, 0], 0, input_height - 1),
        torch.clamp(coords[:, 1], 0, input_width - 1),
    ], 1)
    # Integer corners around each point, as (row, col):
    # lt = (floor, floor), rb = (ceil, ceil), lb = (floor, ceil), rt = (ceil, floor).
    coords_lt = coords.floor().long()
    coords_rb = coords.ceil().long()
    coords_lb = torch.stack([coords_lt[:, 0], coords_rb[:, 1]], 1)
    coords_rt = torch.stack([coords_rb[:, 0], coords_lt[:, 1]], 1)
    vals_lt = th_gather_2d(input, coords_lt.detach())
    vals_rb = th_gather_2d(input, coords_rb.detach())
    vals_lb = th_gather_2d(input, coords_lb.detach())
    vals_rt = th_gather_2d(input, coords_rt.detach())
    # Fractional part of each coordinate (same dtype as coords).
    coords_offset_lt = coords - coords_lt.type(coords.data.type())
    # Interpolate along rows at the floor and ceil columns, then along columns.
    vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[:, 0]
    vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[:, 0]
    mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[:, 1]
    return mapped_vals
def sp_batch_map_coordinates(inputs, coords):
    """Reference (SciPy) implementation for th_batch_map_coordinates.

    Parameters
    ----------
    inputs : numpy.ndarray. shape = (b, h, w)
    coords : numpy.ndarray. shape = (b, n_points, 2)
        (row, col) sample positions. Rows are clipped to [0, h - 1] and
        columns to [0, w - 1], matching the clamping in
        th_batch_map_coordinates.

    Returns
    -------
    numpy.ndarray. shape = (b, n_points)
        Linearly interpolated (order=1) values.
    """
    assert (coords.shape[2] == 2)
    height = coords[:, :, 0].clip(0, inputs.shape[1] - 1)
    width = coords[:, :, 1].clip(0, inputs.shape[2] - 1)
    # Rebind coords so the clipping actually takes effect; the clipped array
    # must be what is handed to map_coordinates.
    coords = np.stack((height, width), axis=2)
    mapped_vals = np.array([
        # map_coordinates expects one axis per row, hence the transpose.
        sp_map_coordinates(plane, plane_coords.T, mode='nearest', order=1)
        for plane, plane_coords in zip(inputs, coords)
    ])
    return mapped_vals
def th_batch_map_coordinates(input, coords, order=1):
    """Batch version of th_map_coordinates (bilinear sampling of 2-D maps).

    Only 2-D feature maps are supported. ``order`` is accepted for API
    symmetry with th_map_coordinates but is not used; sampling is always
    bilinear.

    Parameters
    ----------
    input : torch.Tensor. shape = (b, h, w)
    coords : torch.Tensor. shape = (b, n_points, 2)
        (row, col) sample positions per batch element. Rows are clamped to
        [0, h - 1] and columns to [0, w - 1].

    Returns
    -------
    torch.Tensor. shape = (b, n_points)
    """
    batch_size = input.size(0)
    input_height = input.size(1)
    input_width = input.size(2)
    n_coords = coords.size(1)

    # Clamp rows and columns independently so non-square maps are handled.
    coords = torch.cat((torch.clamp(coords.narrow(2, 0, 1), 0, input_height - 1),
                        torch.clamp(coords.narrow(2, 1, 1), 0, input_width - 1)), 2)

    # Integer corners around each point, as (row, col):
    # lt = (floor, floor), rb = (ceil, ceil), lb = (floor, ceil), rt = (ceil, floor).
    coords_lt = coords.floor().long()
    coords_rb = coords.ceil().long()
    coords_lb = torch.stack([coords_lt[..., 0], coords_rb[..., 1]], 2)
    coords_rt = torch.stack([coords_rb[..., 0], coords_lt[..., 1]], 2)

    # Batch index of every point: 0 repeated n_coords times, then 1, ...
    idx = th_repeat(torch.arange(0, batch_size), n_coords).long()
    idx = Variable(idx, requires_grad=False)
    if input.is_cuda:
        # Place the index on the same GPU as ``input``. A bare ``.cuda()`` uses
        # the current default device and fails when ``input`` lives elsewhere.
        idx = idx.cuda(input.get_device())

    def _get_vals_by_coords(input, coords):
        # Gather input[b, row, col] for every point via a flat linear index.
        indices = torch.stack([
            idx, th_flatten(coords[..., 0]), th_flatten(coords[..., 1])
        ], 1)
        inds = indices[:, 0] * input.size(1) * input.size(2) + indices[:, 1] * input.size(2) + indices[:, 2]
        vals = th_flatten(input).index_select(0, inds)
        vals = vals.view(batch_size, n_coords)
        return vals

    vals_lt = _get_vals_by_coords(input, coords_lt.detach())
    vals_rb = _get_vals_by_coords(input, coords_rb.detach())
    vals_lb = _get_vals_by_coords(input, coords_lb.detach())
    vals_rt = _get_vals_by_coords(input, coords_rt.detach())

    # Fractional part of each coordinate (same dtype as coords).
    coords_offset_lt = coords - coords_lt.type(coords.data.type())
    # Interpolate along rows at the floor and ceil columns, then along columns.
    vals_t = coords_offset_lt[..., 0] * (vals_rt - vals_lt) + vals_lt
    vals_b = coords_offset_lt[..., 0] * (vals_rb - vals_lb) + vals_lb
    mapped_vals = coords_offset_lt[..., 1] * (vals_b - vals_t) + vals_t
    return mapped_vals
def sp_batch_map_offsets(input, offsets):
    """Reference (SciPy) implementation of th_batch_map_offsets.

    Adds per-pixel offsets to the regular (row, col) sampling grid and samples
    ``input`` at the resulting positions with sp_batch_map_coordinates.

    Parameters
    ----------
    input : numpy.ndarray. shape = (b, h, w)
    offsets : numpy.ndarray
        Any array whose size lets it be reshaped to (b, h * w, 2).

    Returns
    -------
    numpy.ndarray. shape = (b, h * w)
    """
    n_batch, n_rows, n_cols = input.shape[0], input.shape[1], input.shape[2]
    deltas = offsets.reshape(n_batch, -1, 2)
    # (h * w, 2) array of (row, col) positions in row-major order.
    base = np.stack(np.mgrid[:n_rows, :n_cols], -1).reshape(-1, 2)
    positions = deltas + np.repeat(base[np.newaxis], n_batch, axis=0)
    return sp_batch_map_coordinates(input, positions)
def th_generate_grid(batch_size, input_height, input_width, dtype, cuda):
    """Build the regular (row, col) sampling grid for a batch of 2-D maps.

    Returns a Variable with ``requires_grad=False`` of shape
    ``(batch_size, input_height * input_width, 2)`` whose entry
    ``[k, i * input_width + j]`` is ``(i, j)`` for every batch element k.
    ``dtype`` is a tensor type string such as ``'torch.FloatTensor'``. The grid
    is additionally moved with ``.cuda()`` when ``cuda`` is true.
    """
    rows, cols = np.meshgrid(range(input_height), range(input_width), indexing='ij')
    cells = np.stack((rows, cols), axis=-1).reshape(-1, 2)
    tiled = np.tile(np.expand_dims(cells, 0), (batch_size, 1, 1))
    grid = torch.from_numpy(tiled).type(dtype)
    if cuda:
        grid = grid.cuda()
    return Variable(grid, requires_grad=False)
def th_batch_map_offsets(input, offsets, grid=None, order=1):
    """Sample ``input`` at the regular grid displaced by ``offsets``.

    Parameters
    ----------
    input : torch.Tensor. shape = (b, h, w)
    offsets : torch.Tensor. shape = (b, h, w, 2)
        Per-pixel (row, col) displacements; viewed as (b, h * w, 2).
    grid : torch.Tensor, optional
        Precomputed grid of shape (b, h * w, 2). It is built with
        th_generate_grid when omitted.
    order : int
        Not used; th_batch_map_coordinates is called with its default order.

    Returns
    -------
    torch.Tensor. shape = (b, h * w)
    """
    n_batch = input.size(0)
    flat_offsets = offsets.view(n_batch, -1, 2)
    if grid is None:
        grid = th_generate_grid(n_batch, input.size(1), input.size(2),
                                flat_offsets.data.type(), flat_offsets.data.is_cuda)
    return th_batch_map_coordinates(input, flat_offsets + grid)
# Here is the ConvOffset2D
import torch.nn as nn
class ConvOffset2D(nn.Conv2d):
    """ConvOffset2D

    Convolutional layer that learns 2-D sampling offsets and outputs the
    deformed feature map using bilinear interpolation.

    A bias-free 3x3 convolution (padding 1) maps the ``filters`` input
    channels to ``2 * filters`` offset channels. Channels (2i, 2i+1) hold the
    (row, col) offset of input channel i at every pixel. Each input channel is
    then resampled at its regular grid displaced by those offsets.

    Note that this layer does not perform convolution on the deformed feature
    map. See get_deform_cnn in cnn.py for usage.
    """

    def __init__(self, filters, init_normal_stddev=0.01, **kwargs):
        """Init

        Parameters
        ----------
        filters : int
            Number of channels of the input feature map
        init_normal_stddev : float
            Standard deviation of the zero-mean normal distribution used to
            initialise the offset-convolution kernel
        **kwargs:
            Passed to nn.Conv2d (kernel size, padding and bias are fixed here)
        """
        super(ConvOffset2D, self).__init__(filters, filters * 2, 3, padding=1, bias=False, **kwargs)
        self.filters = filters
        # Cache for the sampling grid; see _get_grid.
        self._grid_param = None
        self._grid = None
        self.weight.data.copy_(self._init_weights(self.weight, init_normal_stddev))

    def forward(self, x):
        """Return the deformed feature map; same (b, c, h, w) shape as ``x``."""
        x_shape = x.size()
        offsets = super(ConvOffset2D, self).forward(x)
        # offsets: (b*c, h, w, 2)
        offsets = self._to_bc_h_w_2(offsets, x_shape)
        # x: (b*c, h, w)
        x = self._to_bc_h_w(x, x_shape)
        # x_offset: (b*c, h*w)
        x_offset = th_batch_map_offsets(x, offsets, grid=self._get_grid(x))
        # x_offset: (b, c, h, w)
        x_offset = self._to_b_c_h_w(x_offset, x_shape)
        return x_offset

    def _get_grid(self, x):
        """Return the sampling grid for ``x`` of shape (b*c, h, w).

        The grid is rebuilt only when the batch size, spatial size, tensor
        type or CUDA flag differ from the previous call.
        """
        batch_size, input_height, input_width = x.size(0), x.size(1), x.size(2)
        dtype, cuda = x.data.type(), x.data.is_cuda
        grid_param = (batch_size, input_height, input_width, dtype, cuda)
        if self._grid_param == grid_param:
            return self._grid
        self._grid_param = grid_param
        self._grid = th_generate_grid(batch_size, input_height, input_width, dtype, cuda)
        return self._grid

    @staticmethod
    def _init_weights(weights, std):
        """Draw a kernel from N(0, std) with NumPy's global RNG.

        Returns a float64 tensor shaped like ``weights``; the caller copies it
        into the parameter, which casts it to the parameter's dtype.
        """
        fan_out = weights.size(0)
        fan_in = weights.size(1) * weights.size(2) * weights.size(3)
        w = np.random.normal(0.0, std, (fan_out, fan_in))
        return torch.from_numpy(w.reshape(weights.size()))

    @staticmethod
    def _to_bc_h_w_2(x, x_shape):
        """(b, 2c, h, w) -> (b*c, h, w, 2)

        Channels (2i, 2i+1) become the (row, col) offset pair of input channel
        i at each pixel. A plain ``view`` to (b*c, h, w, 2) would instead fill
        each pair from consecutive memory positions of a single channel. That
        means neighbouring pixels' conv outputs, not the two offset channels
        at this pixel. So the pair axis is split out and moved last.
        NOTE: weights trained with the old (view-only) layout will not
        reproduce their previous outputs.
        """
        h, w = int(x_shape[2]), int(x_shape[3])
        x = x.contiguous().view(-1, 2, h, w)
        x = x.permute(0, 2, 3, 1).contiguous()
        return x

    @staticmethod
    def _to_bc_h_w(x, x_shape):
        """(b, c, h, w) -> (b*c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))
        return x

    @staticmethod
    def _to_b_c_h_w(x, x_shape):
        """(b*c, h, w) or (b*c, h*w) -> (b, c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3]))
        return x