"""
This is used for building up attention layers and resnet layers.
From this source.
! https://github.com/qubvel/residual_attention_network/blob/master/models/blocks.py
"""
# this are network blocks
from keras.layers import BatchNormalization
from keras.layers import Conv2D
from keras.layers import UpSampling2D
from keras.layers import Activation
from keras.layers import MaxPool2D
from keras.layers import Add
from keras.layers import Multiply
from keras.layers import Lambda

def residual_block(input, input_channels=None, output_channels=None, kernel_size=(3, 3), stride=1):
    """
    Full pre-activation residual block
    https://arxiv.org/pdf/1603.05027.pdf
    """
    if output_channels is None:
        output_channels = input.get_shape()[-1]
    if input_channels is None:
        input_channels = output_channels // 4

    strides = (stride, stride)

    # Bottleneck: BN -> ReLU -> 1x1 conv reduces the width to input_channels
    x = BatchNormalization()(input)
    x = Activation('relu')(x)
    x = Conv2D(input_channels, (1, 1))(x)

    # Main convolution with the requested kernel size, optionally strided
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(input_channels, kernel_size, padding='same', strides=strides)(x)

    # Expansion: 1x1 conv restores output_channels
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(output_channels, (1, 1), padding='same')(x)

    # Projection shortcut when the identity branch does not match in shape
    if input_channels != output_channels or stride != 1:
        input = Conv2D(output_channels, (1, 1), padding='same', strides=strides)(input)

    x = Add()([x, input])
    return x

def attention_block(input, input_channels=None, output_channels=None, encoder_depth=1):
    """
    Attention block
    https://arxiv.org/abs/1704.06904
    """
    # p: residual blocks before/after the attention module,
    # t: residual blocks in the trunk branch,
    # r: residual blocks between adjacent pooling/upsampling steps in the mask branch
    p = 1
    t = 2
    r = 1

    if input_channels is None:
        input_channels = input.get_shape()[-1]
    if output_channels is None:
        output_channels = input_channels

    # First residual block(s)
    for i in range(p):
        input = residual_block(input)

    # Trunk branch: plain residual blocks carrying the features to be reweighted
    output_trunk = input
    for i in range(t):
        output_trunk = residual_block(output_trunk)

    # Soft mask branch
    ## encoder
    ### first down-sampling
    output_soft_mask = MaxPool2D(padding='same')(input)
    for i in range(r):
        output_soft_mask = residual_block(output_soft_mask)

    skip_connections = []
    for i in range(encoder_depth - 1):
        ## skip connection
        output_skip_connection = residual_block(output_soft_mask)
        skip_connections.append(output_skip_connection)
        # print('skip shape:', output_skip_connection.get_shape())

        ## down-sampling
        output_soft_mask = MaxPool2D(padding='same')(output_soft_mask)
        for _ in range(r):
            output_soft_mask = residual_block(output_soft_mask)

    ## decoder
    skip_connections = list(reversed(skip_connections))
    for i in range(encoder_depth - 1):
        ## up-sampling
        for _ in range(r):
            output_soft_mask = residual_block(output_soft_mask)
        output_soft_mask = UpSampling2D()(output_soft_mask)
        ## skip connection
        output_soft_mask = Add()([output_soft_mask, skip_connections[i]])

    ### last up-sampling
    for i in range(r):
        output_soft_mask = residual_block(output_soft_mask)
    output_soft_mask = UpSampling2D()(output_soft_mask)

    ## output: two 1x1 convolutions and a sigmoid give a soft mask in (0, 1)
    output_soft_mask = Conv2D(input_channels, (1, 1))(output_soft_mask)
    output_soft_mask = Conv2D(input_channels, (1, 1))(output_soft_mask)
    output_soft_mask = Activation('sigmoid')(output_soft_mask)

    # Attention: (1 + output_soft_mask) * output_trunk
    output = Lambda(lambda x: x + 1)(output_soft_mask)
    output = Multiply()([output, output_trunk])

    # Last residual block(s)
    for i in range(p):
        output = residual_block(output)

    return output
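

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original blocks.py): builds a tiny
# model from the two blocks above to show how they compose. The input shape
# (32, 32, 256), the encoder_depth value, and the use of keras.layers.Input /
# keras.models.Model are assumptions for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from keras.layers import Input
    from keras.models import Model

    inputs = Input(shape=(32, 32, 256))                    # assumed feature-map shape
    x = residual_block(inputs, output_channels=256)
    x = attention_block(x, encoder_depth=2)                 # mask branch downsamples twice
    x = residual_block(x, output_channels=512, stride=2)    # downsample and widen
    model = Model(inputs=inputs, outputs=x)
    model.summary()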