# Spec.py
import torch
import torch.nn as nn

class ResnetBlock(nn.Module):
    """Define a Resnet block with reflection padding and instance normalization."""

    def __init__(self, dim, padding_type='reflect', norm_layer=nn.InstanceNorm2d, use_dropout=False):
        """Initialize the Resnet block.

        Parameters:
            dim (int)          -- the number of channels in the conv layer.
            padding_type (str) -- the name of the padding layer in conv layers: reflect | replicate | zero.
            norm_layer         -- normalization layer.
            use_dropout (bool) -- whether to use dropout layers.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
        """Construct a convolutional block: pad -> conv -> norm -> ReLU (-> dropout) -> pad -> conv -> norm.

        Parameters:
            dim (int)          -- the number of channels in the conv layer.
            padding_type (str) -- the name of the padding layer: reflect | replicate | zero.
            norm_layer         -- normalization layer.
            use_dropout (bool) -- whether to use dropout layers.

        Returns:
            conv_block (nn.Sequential) -- the sequential convolutional block.
        """
        conv_block = []
        p = 0  # explicit conv padding; stays 0 when a separate padding layer is used
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        else:  # zero padding inside the conv itself
            p = 1
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=True),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        # Second convolution (note: no ReLU after the second norm)
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        else:  # zero padding
            p = 1
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=True),
                       norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with a skip connection)."""
        return x + self.conv_block(x)  # add the skip connection
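
# A minimal shape check (a sketch, not in the original file): ResnetBlock maps
# (N, C, H, W) -> (N, C, H, W), so the residual addition in forward() is always
# shape-compatible. For example:
#
#   block = ResnetBlock(dim=256)
#   feat = torch.randn(1, 256, 64, 64)
#   assert block(feat).shape == feat.shape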

class Spec(nn.Module):
    """Resnet-based generator with reflection padding and instance normalization."""

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):
        """Construct a Resnet-based generator.

        Parameters:
            input_nc (int)  -- the number of channels in input images.
            output_nc (int) -- the number of channels in output images.
            ngf (int)       -- the number of filters in the last conv layer.
            n_blocks (int)  -- the number of ResNet blocks.
        """
        super(Spec, self).__init__()
        assert n_blocks >= 0
        norm_layer = nn.InstanceNorm2d

        # Initial convolution layers
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Downsampling layers: each halves H and W and doubles the channel count
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # ResNet blocks at the bottleneck resolution
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type='reflect',
                                  norm_layer=norm_layer, use_dropout=False)]

        # Upsampling layers: each doubles H and W and halves the channel count
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1, bias=True),
                      norm_layer(ngf * mult // 2),
                      nn.ReLU(True)]

        # Output layer: map back to output_nc channels, squashed to [-1, 1]
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                  nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward function."""
        return self.model(input)
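
# Architecture summary (read off the layers above; the naming follows the
# CycleGAN/Johnson-et-al. convention, which this file does not use itself):
#   c7s1-64, d128, d256, 9 x R256, u128, u64, c7s1-3
# i.e. a 7x7 stem, two stride-2 downsampling convs, nine residual blocks at
# 1/4 resolution, two stride-2 transposed convs back up, and a 7x7 Tanh head.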

if __name__ == '__main__':
    model = Spec()
    input = torch.randn(1, 3, 256, 256)
    output = model(input)
    print(output.size())
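    # Extra sanity checks (a sketch, not in the original script): the generator
    # is fully convolutional, so any input whose height and width are divisible
    # by 4 (two stride-2 stages) comes back at the same resolution, and the
    # final Tanh keeps values in [-1, 1].
    assert output.shape == input.shape
    assert output.min().item() >= -1.0 and output.max().item() <= 1.0
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'trainable parameters: {n_params:,}')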