-
Notifications
You must be signed in to change notification settings - Fork 562
/
self_tuning.py
189 lines (156 loc) · 7.37 KB
/
self_tuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
Adapted from https://github.com/thuml/Self-Tuning/tree/master
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch
import torch.nn as nn
from torch.nn.functional import normalize
from tllib.modules.classifier import Classifier as ClassifierBase
class Classifier(ClassifierBase):
    """Classifier class for Self-Tuning.

    Backbone features are pooled, passed through a bottleneck
    (Linear -> BatchNorm1d -> ReLU -> Dropout(0.5)) and then fed to two linear heads:
    a classification head and a projector whose output is L2-normalized along dim 1.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data. Must expose ``out_features``.
        num_classes (int): Number of classes.
        projection_dim (int, optional): Dimension of the projector head. Default: 1024
        bottleneck_dim (int, optional): Output dimension of the bottleneck; input width of both heads. Default: 1024
        finetune (bool): Whether finetune the classifier or train from scratch. Default: True
        pool_layer (torch.nn.Module, optional): Pooling layer applied to the backbone output. Forwarded to
            ``ClassifierBase``, which presumably installs a default when ``None`` (verify against the base class).

    Inputs:
        - x (tensor): input data fed to `backbone`

    Outputs:
        In the training mode,
            - h: projections (L2-normalized along dim 1)
            - y: classifier's predictions
        In the eval mode,
            - y: classifier's predictions

    Shape:
        - Inputs: (minibatch, *) where * means, any number of additional dimensions
        - y: (minibatch, `num_classes`)
        - h: (minibatch, `projection_dim`)
    """

    def __init__(self, backbone: nn.Module, num_classes: int, projection_dim=1024, bottleneck_dim=1024, finetune=True,
                 pool_layer=None):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        # Small-variance weights and a positive bias for the freshly added bottleneck layer.
        bottleneck[0].weight.data.normal_(0, 0.005)
        bottleneck[0].bias.data.fill_(0.1)
        # Both heads consume the bottleneck output, so their input width must be bottleneck_dim.
        head = nn.Linear(bottleneck_dim, num_classes)
        super(Classifier, self).__init__(backbone, num_classes=num_classes, head=head, finetune=finetune,
                                         pool_layer=pool_layer, bottleneck=bottleneck, bottleneck_dim=bottleneck_dim)
        self.projector = nn.Linear(bottleneck_dim, projection_dim)
        self.projection_dim = projection_dim

    def forward(self, x: torch.Tensor):
        """Return ``(h, predictions)`` in training mode and ``predictions`` in eval mode."""
        f = self.bottleneck(self.pool_layer(self.backbone(x)))
        predictions = self.head(f)
        if not self.training:
            # Projections are only consumed by the contrastive loss during training.
            return predictions
        h = normalize(self.projector(f), dim=1)
        return h, predictions

    def get_parameters(self, base_lr=1.0):
        """Return optimizer parameter groups with per-module learning rates.

        The pretrained backbone uses ``0.1 * base_lr`` when fine-tuning. The newly initialized
        bottleneck, classification head and projector all use ``base_lr``.
        """
        params = [
            {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
            {"params": self.projector.parameters(), "lr": 1.0 * base_lr},
        ]
        return params
class SelfTuning(nn.Module):
    r"""Self-Tuning module in `Self-Tuning for Data-Efficient Deep Learning (self-tuning, ICML 2021)
    <http://ise.thss.tsinghua.edu.cn/~mlong/doc/Self-Tuning-for-Data-Efficient-Deep-Learning-icml21.pdf>`_.

    Keeps a per-class queue of key projections. The queue is stored column-wise with shape
    (`projection_dim`, `num_classes` :math:`\times` `K`); columns ``[c * K, (c + 1) * K)`` belong to class ``c``.

    Args:
        encoder_q (Classifier): Query encoder. Must expose ``projection_dim`` and return ``(h, y)`` in training mode.
        encoder_k (Classifier): Key encoder. Its parameters are overwritten with ``encoder_q``'s and frozen.
        num_classes (int): Number of classes
        K (int): Queue size per class. Default: 32
        m (float): Momentum coefficient. Default: 0.999
        T (float): Temperature. Default: 0.07

    Inputs:
        - im_q (tensor): input data fed to `encoder_q`
        - im_k (tensor): input data fed to `encoder_k`
        - labels (tensor): classification labels of input data

    Outputs: pgc_logits, pgc_labels, y_q
        - pgc_logits: log-softmax of the temperature-scaled similarities, ordered as
          [own key, K same-class queued keys, K x (num_classes - 1) other-class queued keys]
        - pgc_labels: contrastive target distribution (uniform over the first K + 1 columns, zero elsewhere)
        - y_q: query classifier's predictions

    Shape:
        - im_q, im_k: (minibatch, *) where * means, any number of additional dimensions
        - labels: (minibatch, )
        - y_q: (minibatch, `num_classes`)
        - pgc_logits: (minibatch, 1 + `num_classes` :math:`\times` `K`)
        - pgc_labels: (minibatch, 1 + `num_classes` :math:`\times` `K`)
    """

    def __init__(self, encoder_q, encoder_k, num_classes, K=32, m=0.999, T=0.07):
        super(SelfTuning, self).__init__()
        self.K = K
        self.m = m
        self.T = T
        self.num_classes = num_classes

        # create the encoders; the key encoder starts as a frozen copy of the query encoder
        self.encoder_q = encoder_q
        self.encoder_k = encoder_k
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # create the queue: random unit-norm columns, grouped by class
        self.register_buffer("queue_list",
                             normalize(torch.randn(encoder_q.projection_dim, K * self.num_classes), dim=0))
        # one write pointer per class, each in [0, K)
        self.register_buffer("queue_ptr", torch.zeros(self.num_classes, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum update of the key encoder: k <- m * k + (1 - m) * q (in place)."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.mul_(self.m).add_(param_q.data, alpha=1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, h, label):
        """Write the rows of ``h`` (batch x projection_dim) into class ``label``'s queue segment.

        Writes start at that class's pointer and wrap around within its K slots, so other classes'
        segments are never overwritten. ``label`` may be a Python int or a 0-d integer tensor.
        """
        label = int(label)
        batch_size = h.shape[0]
        ptr = int(self.queue_ptr[label])
        offsets = torch.arange(batch_size, device=self.queue_list.device)
        columns = (ptr + offsets) % self.K + label * self.K
        # replace the oldest keys of this class (dequeue and enqueue)
        self.queue_list[:, columns] = h.T
        self.queue_ptr[label] = (ptr + batch_size) % self.K

    def forward(self, im_q, im_k, labels):
        batch_size = im_q.size(0)
        device = im_q.device

        # compute query features
        h_q, y_q = self.encoder_q(im_q)  # queries: h_q (N x projection_dim)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            h_k, _ = self.encoder_k(im_k)  # keys: h_k (N x projection_dim)

        # positive logits against each sample's own key: N x 1
        logits_pos = torch.einsum('nl,nl->n', [h_q, h_k]).unsqueeze(-1)

        # Snapshot the queue so keys enqueued below do not affect this step's logits.
        cur_queue_list = self.queue_list.clone().detach()
        # Similarity to every queued key, grouped per class: N x num_classes x K
        all_logits = torch.mm(h_q, cur_queue_list).view(batch_size, self.num_classes, self.K)
        is_own_class = labels.to(device).view(-1, 1) == torch.arange(self.num_classes, device=device)
        # Boolean-mask indexing is row-major, so each row keeps classes in ascending order.
        logits_pos_list = all_logits[is_own_class]  # N x K
        logits_neg_list = all_logits[~is_own_class].reshape(batch_size, (self.num_classes - 1) * self.K)

        # enqueue the current keys, one sample at a time, into their class segments
        for i, label in enumerate(labels.tolist()):
            self._dequeue_and_enqueue(h_k[i:i + 1], label)

        # logits: 1 + K + K * (num_classes - 1)
        pgc_logits = torch.cat([logits_pos, logits_pos_list, logits_neg_list], dim=1)
        pgc_logits = torch.log_softmax(pgc_logits / self.T, dim=1)

        # uniform target over the own key and the K same-class queued keys
        pgc_labels = torch.zeros(batch_size, 1 + self.K * self.num_classes, device=device)
        pgc_labels[:, :self.K + 1] = 1.0 / (self.K + 1)
        return pgc_logits, pgc_labels, y_q