-
Notifications
You must be signed in to change notification settings - Fork 1
/
models.py
102 lines (75 loc) · 2.91 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.nn import Parameter
from breg_next import BReGNeXt
from gcn import GraphConvolution
from utils import gen_A, gen_adj
class Emotion_GCN(nn.Module):
    """DenseNet-121 image encoder combined with a two-layer GCN label head.

    Based on the code of ML-GCN https://github.com/Megvii-Nanjing/ML-GCN

    The GCN turns per-node input embeddings into 1024-d node classifiers that
    are multiplied with the pooled image features, giving one score per graph
    node. The first ``num_categorical`` node scores form the first output and
    the remaining node scores form the second output.

    Args:
        adj_file: Passed to ``gen_A`` to build the initial adjacency matrix
            (its format is defined by ``utils.gen_A``).
        in_channel: Width of each node's input embedding.
        input_size: Image side length; ``227`` selects a 7x7 max-pool, any
            other value a 3x3 max-pool.
        num_categorical: Number of leading node scores returned as the first
            output. Defaults to 7, the split previously hard-coded.
    """

    # Class-level fallback so instances pickled before this attribute existed
    # still split at 7.
    num_categorical = 7

    def __init__(self, adj_file=None, in_channel=300, input_size=227,
                 num_categorical=7):
        super().__init__()
        self.num_categorical = num_categorical
        self.features = models.densenet121(pretrained=True).features
        # NOTE(review): each window is assumed to equal the backbone's final
        # feature-map size for that input (7x7 for 227, 3x3 otherwise) so the
        # pool collapses it to 1x1; a larger map would flatten to more than
        # 1024 values and the matmul in forward() would fail — TODO confirm
        # for any new input size.
        if input_size == 227:
            self.pooling = nn.MaxPool2d(7, 7)
        else:
            self.pooling = nn.MaxPool2d(3, 3)
        self.gc1 = GraphConvolution(in_channel, 512)
        # 1024 outputs so node classifiers match the 1024-d pooled DenseNet
        # features they are multiplied with in forward().
        self.gc2 = GraphConvolution(512, 1024)
        self.relu = nn.LeakyReLU(0.2)
        _adj = gen_A(adj_file)
        # Registered as a Parameter, but forward() only uses a detached,
        # normalized copy, so A receives no gradient through forward().
        self.A = Parameter(torch.from_numpy(_adj).float())
        print(self.A)

    def forward(self, feature, inp):
        """Score every graph node for each image in the batch.

        Args:
            feature: Image batch fed to the DenseNet feature extractor.
            inp: Node embeddings; only ``inp[0]`` is used (presumably the same
                embedding matrix repeated per batch element — confirm with
                the data loader).

        Returns:
            Tuple ``(first, rest)``: the node-score matrix split at column
            ``num_categorical``.
        """
        feature = self.features(feature)
        feature = self.pooling(feature)
        # Keep the batch axis, flatten everything else per sample.
        feature = torch.flatten(feature, 1)
        inp = inp[0]
        adj = gen_adj(self.A).detach()
        x = self.gc1(inp, adj)
        x = self.relu(x)
        x = self.gc2(x, adj)
        # gc2 output is expected as (num_nodes, 1024); transpose so each
        # column is one node's classifier for the matmul below.
        x = x.transpose(0, 1)
        x = torch.matmul(feature, x)
        split = self.num_categorical
        return x[:, :split], x[:, split:]
class multi_densenet(nn.Module):
    """DenseNet-121 feature extractor with separate categorical and continuous heads.

    Args:
        pretrained: Forwarded to ``torchvision.models.densenet121``.
        num_categorical: Output width of the categorical head.
        num_continuous: Output width of the continuous head. Defaults to 2,
            the value previously hard-coded.
    """

    def __init__(self, pretrained=True, num_categorical=7, num_continuous=2):
        super().__init__()
        self.model_base = models.densenet121(pretrained=pretrained).features
        self.num_categorical = num_categorical
        self.num_continuous = num_continuous
        # Both heads consume the 1024-d pooled output of densenet121's
        # feature extractor.
        self.lin_cat = nn.Linear(1024, self.num_categorical)
        self.lin_cont = nn.Linear(1024, self.num_continuous)

    def forward(self, x):
        """Run the backbone and both heads.

        Args:
            x: Image batch fed to the DenseNet feature extractor.

        Returns:
            Tuple ``(out_cat, out_cont)`` of shapes
            ``(batch, num_categorical)`` and ``(batch, num_continuous)``.
        """
        feat = self.model_base(x)
        # In-place ReLU overwrites the backbone output, which is not reused.
        out = F.relu(feat, inplace=True)
        # Global average pool to 1x1, then flatten to (batch, channels).
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out_cont = self.lin_cont(out)
        out_cat = self.lin_cat(out)
        return out_cat, out_cont
class BReGNeXt_GCN(nn.Module):
    """BReGNeXt image encoder combined with a two-layer GCN label head.

    Same scheme as ``Emotion_GCN``, but the backbone is expected to yield 128
    values per sample, so the GCN produces 128-d node classifiers. The first
    ``num_categorical`` node scores form the first output and the remaining
    node scores form the second output.

    Args:
        adj_file: Passed to ``gen_A`` to build the initial adjacency matrix
            (its format is defined by ``utils.gen_A``).
        in_channel: Width of each node's input embedding.
        num_categorical: Number of leading node scores returned as the first
            output. Defaults to 7, the split previously hard-coded.
    """

    # Class-level fallback so instances pickled before this attribute existed
    # still split at 7.
    num_categorical = 7

    def __init__(self, adj_file=None, in_channel=300, num_categorical=7):
        super().__init__()
        self.num_categorical = num_categorical
        # NOTE(review): n_classes presumably only sizes BReGNeXt's own head,
        # which forward() does not call — confirm before changing it.
        self.model = BReGNeXt(n_classes=7)
        self.gc1 = GraphConvolution(in_channel, 512)
        # 128 outputs to match the per-sample backbone feature width.
        self.gc2 = GraphConvolution(512, 128)
        self.relu = nn.LeakyReLU(0.2)
        _adj = gen_A(adj_file)
        # Registered as a Parameter, but forward() only uses a detached,
        # normalized copy, so A receives no gradient through forward().
        self.A = Parameter(torch.from_numpy(_adj).float())
        print(self.A)

    def forward(self, feature, inp):
        """Score every graph node for each image in the batch.

        Args:
            feature: Image batch; assumed (N, C, H, W) — confirm.
            inp: Node embeddings; only ``inp[0]`` is used (presumably the same
                embedding matrix repeated per batch element — confirm with
                the data loader).

        Returns:
            Tuple ``(first, rest)``: the node-score matrix split at column
            ``num_categorical``.
        """
        # Pad the last two dims (presumably H and W) by one on each side;
        # the third-from-last dim is left unpadded.
        feature = F.pad(feature, (1, 1, 1, 1, 0, 0))
        # Only the backbone's _conv0 and _model stages are run here.
        feature = self.model._conv0(feature)
        feature = self.model._model(feature)
        # Flatten per sample while keeping the batch axis, so a backbone
        # output that is not 128 values per sample fails in the matmul below
        # instead of being spread across extra rows.
        feature = torch.flatten(feature, 1)
        inp = inp[0]
        adj = gen_adj(self.A).detach()
        x = self.gc1(inp, adj)
        x = self.relu(x)
        x = self.gc2(x, adj)
        # gc2 output is expected as (num_nodes, 128); transpose so each
        # column is one node's classifier for the matmul below.
        x = x.transpose(0, 1)
        x = torch.matmul(feature, x)
        split = self.num_categorical
        return x[:, :split], x[:, split:]