-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
153 lines (135 loc) · 5.13 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
import torch.nn as nn
import geffnet
from resnest.torch import resnest101
from pretrainedmodels import se_resnext101_32x4d
# Shared nn.Sigmoid module instance, constructed once when this module is imported.
sigmoid = nn.Sigmoid()
class Swish(torch.autograd.Function):
    """Swish activation, ``x * sigmoid(x)``, with a hand-written backward pass.

    ``forward`` saves only its input; ``backward`` recomputes the sigmoid from
    that input rather than keeping the intermediate result around.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and save ``i`` for the backward pass.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            i: Input tensor.

        Returns:
            Tensor with the same shape as ``i``.
        """
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to the forward input.

        Uses d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))), where ``s`` is the
        sigmoid.

        Args:
            ctx: Autograd context holding the tensor saved in ``forward``.
            grad_output: Gradient of the loss with respect to the output.

        Returns:
            Gradient of the loss with respect to the input.
        """
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is
        # deprecated.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
class Swish_Module(nn.Module):
    """``nn.Module`` wrapper so the ``Swish`` function can sit inside containers."""

    def forward(self, x):
        """Apply ``Swish`` to ``x`` and return the result."""
        activated = Swish.apply(x)
        return activated
class Effnet_Melanoma(nn.Module):
    """Classifier built on a ``geffnet`` backbone with an optional metadata branch.

    The backbone's ``classifier`` is replaced by ``nn.Identity`` so that
    ``self.enet`` returns whatever used to feed that classifier (presumably
    pooled image features -- confirm against geffnet). When
    ``n_meta_features > 0`` a small MLP embeds a metadata vector, and its
    output is concatenated to the backbone output before the final linear
    layer ``myfc``.

    The final layer is evaluated once per entry in ``self.dropouts`` (five
    ``nn.Dropout(0.5)`` modules) and the results are averaged.
    """

    def __init__(self, enet_type, out_dim, n_meta_features=0, n_meta_dim=(512, 128), pretrained=False):
        """Build the backbone, the optional metadata MLP and the linear head.

        Args:
            enet_type: Model name passed to ``geffnet.create_model``.
            out_dim: Output width of the final linear layer.
            n_meta_features: Width of the metadata input; ``0`` disables the
                metadata branch.
            n_meta_dim: The two layer widths of the metadata MLP. Only
                indexed, never mutated; the default is a tuple so no mutable
                object is shared between calls.
            pretrained: Forwarded to ``geffnet.create_model``.
        """
        super().__init__()
        self.n_meta_features = n_meta_features
        self.enet = geffnet.create_model(enet_type, pretrained=pretrained)
        self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])
        in_ch = self.enet.classifier.in_features
        if n_meta_features > 0:
            hidden_dim, embed_dim = n_meta_dim[0], n_meta_dim[1]
            self.meta = nn.Sequential(
                nn.Linear(n_meta_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                Swish_Module(),
                nn.Dropout(p=0.3),
                nn.Linear(hidden_dim, embed_dim),
                nn.BatchNorm1d(embed_dim),
                Swish_Module(),
            )
            # The head consumes backbone output and metadata embedding side by side.
            in_ch += embed_dim
        self.myfc = nn.Linear(in_ch, out_dim)
        # Drop the stock classifier; ``myfc`` takes its place.
        self.enet.classifier = nn.Identity()

    def extract(self, x):
        """Run ``x`` through the backbone and return its output."""
        return self.enet(x)

    def forward(self, x, x_meta=None):
        """Compute the averaged head output for ``x`` (and ``x_meta`` if enabled).

        Args:
            x: Input passed to the backbone.
            x_meta: Metadata input for the metadata MLP. Required when the
                model was built with ``n_meta_features > 0``; ignored otherwise.

        Returns:
            Mean of ``myfc`` applied to each dropout-masked copy of the features.

        Raises:
            TypeError: If the metadata branch is enabled and ``x_meta`` is None.
        """
        use_meta = self.n_meta_features > 0
        # Fail before the expensive backbone pass instead of deep inside nn.Linear.
        if use_meta and x_meta is None:
            raise TypeError(
                "x_meta is required when the model is built with n_meta_features > 0"
            )
        x = self.extract(x).squeeze(-1).squeeze(-1)
        if use_meta:
            x = torch.cat((x, self.meta(x_meta)), dim=1)
        # Each Dropout module draws its own mask in training mode; in eval mode
        # all passes are identical and the mean equals a single pass.
        head_outputs = (self.myfc(dropout(x)) for dropout in self.dropouts)
        return sum(head_outputs) / len(self.dropouts)
class Resnest_Melanoma(nn.Module):
    """Classifier built on ``resnest101`` with an optional metadata branch.

    The backbone's ``fc`` is replaced by ``nn.Identity`` so that ``self.enet``
    returns whatever used to feed that layer (presumably pooled image
    features -- confirm against resnest). When ``n_meta_features > 0`` a small
    MLP embeds a metadata vector, and its output is concatenated to the
    backbone output before the final linear layer ``myfc``.

    The final layer is evaluated once per entry in ``self.dropouts`` (five
    ``nn.Dropout(0.5)`` modules) and the results are averaged.
    """

    def __init__(self, enet_type, out_dim, n_meta_features=0, n_meta_dim=(512, 128), pretrained=False):
        """Build the backbone, the optional metadata MLP and the linear head.

        Args:
            enet_type: Accepted for signature parity with the other model
                classes in this file; not used here (the backbone is always
                ``resnest101``).
            out_dim: Output width of the final linear layer.
            n_meta_features: Width of the metadata input; ``0`` disables the
                metadata branch.
            n_meta_dim: The two layer widths of the metadata MLP. Only
                indexed, never mutated; the default is a tuple so no mutable
                object is shared between calls.
            pretrained: Forwarded to ``resnest101``.
        """
        super().__init__()
        self.n_meta_features = n_meta_features
        self.enet = resnest101(pretrained=pretrained)
        self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])
        in_ch = self.enet.fc.in_features
        if n_meta_features > 0:
            hidden_dim, embed_dim = n_meta_dim[0], n_meta_dim[1]
            self.meta = nn.Sequential(
                nn.Linear(n_meta_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                Swish_Module(),
                nn.Dropout(p=0.3),
                nn.Linear(hidden_dim, embed_dim),
                nn.BatchNorm1d(embed_dim),
                Swish_Module(),
            )
            # The head consumes backbone output and metadata embedding side by side.
            in_ch += embed_dim
        self.myfc = nn.Linear(in_ch, out_dim)
        # Drop the stock classifier; ``myfc`` takes its place.
        self.enet.fc = nn.Identity()

    def extract(self, x):
        """Run ``x`` through the backbone and return its output."""
        return self.enet(x)

    def forward(self, x, x_meta=None):
        """Compute the averaged head output for ``x`` (and ``x_meta`` if enabled).

        Args:
            x: Input passed to the backbone.
            x_meta: Metadata input for the metadata MLP. Required when the
                model was built with ``n_meta_features > 0``; ignored otherwise.

        Returns:
            Mean of ``myfc`` applied to each dropout-masked copy of the features.

        Raises:
            TypeError: If the metadata branch is enabled and ``x_meta`` is None.
        """
        use_meta = self.n_meta_features > 0
        # Fail before the expensive backbone pass instead of deep inside nn.Linear.
        if use_meta and x_meta is None:
            raise TypeError(
                "x_meta is required when the model is built with n_meta_features > 0"
            )
        x = self.extract(x).squeeze(-1).squeeze(-1)
        if use_meta:
            x = torch.cat((x, self.meta(x_meta)), dim=1)
        # Each Dropout module draws its own mask in training mode; in eval mode
        # all passes are identical and the mean equals a single pass.
        head_outputs = (self.myfc(dropout(x)) for dropout in self.dropouts)
        return sum(head_outputs) / len(self.dropouts)
class Seresnext_Melanoma(nn.Module):
    """Classifier built on ``se_resnext101_32x4d`` with an optional metadata branch.

    The backbone's ``avg_pool`` is replaced by ``nn.AdaptiveAvgPool2d((1, 1))``
    and its ``last_linear`` by ``nn.Identity``, so ``self.enet`` returns
    whatever used to feed that last layer (presumably pooled image features --
    confirm against pretrainedmodels). When ``n_meta_features > 0`` a small MLP
    embeds a metadata vector, and its output is concatenated to the backbone
    output before the final linear layer ``myfc``.

    The final layer is evaluated once per entry in ``self.dropouts`` (five
    ``nn.Dropout(0.5)`` modules) and the results are averaged.
    """

    def __init__(self, enet_type, out_dim, n_meta_features=0, n_meta_dim=(512, 128), pretrained=False):
        """Build the backbone, the optional metadata MLP and the linear head.

        Args:
            enet_type: Accepted for signature parity with the other model
                classes in this file; not used here (the backbone is always
                ``se_resnext101_32x4d``).
            out_dim: Output width of the final linear layer.
            n_meta_features: Width of the metadata input; ``0`` disables the
                metadata branch.
            n_meta_dim: The two layer widths of the metadata MLP. Only
                indexed, never mutated; the default is a tuple so no mutable
                object is shared between calls.
            pretrained: If truthy the backbone is created with
                ``pretrained='imagenet'``, otherwise with ``pretrained=None``.
        """
        super().__init__()
        self.n_meta_features = n_meta_features
        # The constructor takes a weights name (or None) rather than a boolean.
        weights = 'imagenet' if pretrained else None
        self.enet = se_resnext101_32x4d(num_classes=1000, pretrained=weights)
        self.enet.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])
        in_ch = self.enet.last_linear.in_features
        if n_meta_features > 0:
            hidden_dim, embed_dim = n_meta_dim[0], n_meta_dim[1]
            self.meta = nn.Sequential(
                nn.Linear(n_meta_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                Swish_Module(),
                nn.Dropout(p=0.3),
                nn.Linear(hidden_dim, embed_dim),
                nn.BatchNorm1d(embed_dim),
                Swish_Module(),
            )
            # The head consumes backbone output and metadata embedding side by side.
            in_ch += embed_dim
        self.myfc = nn.Linear(in_ch, out_dim)
        # Drop the stock classifier; ``myfc`` takes its place.
        self.enet.last_linear = nn.Identity()

    def extract(self, x):
        """Run ``x`` through the backbone and return its output."""
        return self.enet(x)

    def forward(self, x, x_meta=None):
        """Compute the averaged head output for ``x`` (and ``x_meta`` if enabled).

        Args:
            x: Input passed to the backbone.
            x_meta: Metadata input for the metadata MLP. Required when the
                model was built with ``n_meta_features > 0``; ignored otherwise.

        Returns:
            Mean of ``myfc`` applied to each dropout-masked copy of the features.

        Raises:
            TypeError: If the metadata branch is enabled and ``x_meta`` is None.
        """
        use_meta = self.n_meta_features > 0
        # Fail before the expensive backbone pass instead of deep inside nn.Linear.
        if use_meta and x_meta is None:
            raise TypeError(
                "x_meta is required when the model is built with n_meta_features > 0"
            )
        x = self.extract(x).squeeze(-1).squeeze(-1)
        if use_meta:
            x = torch.cat((x, self.meta(x_meta)), dim=1)
        # Each Dropout module draws its own mask in training mode; in eval mode
        # all passes are identical and the mean equals a single pass.
        head_outputs = (self.myfc(dropout(x)) for dropout in self.dropouts)
        return sum(head_outputs) / len(self.dropouts)