-
Notifications
You must be signed in to change notification settings - Fork 147
/
tim.py
120 lines (101 loc) · 4.47 KB
/
tim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from torch import Tensor, nn
from .few_shot_classifier import FewShotClassifier
class TIM(FewShotClassifier):
    """
    Malik Boudiaf, Ziko Imtiaz Masud, Jérôme Rony, José Dolz, Pablo Piantanida, Ismail Ben Ayed.
    "Transductive Information Maximization For Few-Shot Learning" (NeurIPS 2020)
    https://arxiv.org/abs/2008.11297

    Fine-tune prototypes based on
        1) classification error on support images
        2) mutual information between query features and their label predictions
    Then classify query images w.r.t. their cosine similarity to the updated prototypes.

    As is, it is incompatible with episodic training because we freeze the backbone to perform
    fine-tuning.

    TIM is a transductive method: the whole query batch is used jointly to fine-tune the
    prototypes.
    """

    def __init__(
        self,
        *args,
        fine_tuning_steps: int = 50,
        fine_tuning_lr: float = 1e-4,
        cross_entropy_weight: float = 1.0,
        marginal_entropy_weight: float = 1.0,
        conditional_entropy_weight: float = 0.1,
        temperature: float = 10.0,
        **kwargs,
    ):
        """
        Args:
            *args: positional arguments forwarded to FewShotClassifier.__init__
            fine_tuning_steps: number of fine-tuning steps performed on the prototypes at
                each call to forward
            fine_tuning_lr: learning rate of the Adam optimizer used for fine-tuning
            cross_entropy_weight: weight given to the cross-entropy term of the loss
            marginal_entropy_weight: weight given to the marginal entropy term of the loss
            conditional_entropy_weight: weight given to the conditional entropy term of the loss
            temperature: scaling factor for the cosine similarities to the prototypes.
                During fine-tuning the similarities are multiplied by this value before the
                softmax / cross-entropy. A higher value therefore gives sharper (more
                confident) predictions, not softer ones. It is also passed to
                softmax_if_specified for the returned scores.
            **kwargs: keyword arguments forwarded to FewShotClassifier.__init__
        """
        super().__init__(*args, **kwargs)

        # Since we fine-tune the prototypes we need to make them leaf variables
        # i.e. we need to freeze the backbone.
        self.backbone.requires_grad_(False)

        self.fine_tuning_steps = fine_tuning_steps
        self.fine_tuning_lr = fine_tuning_lr
        self.cross_entropy_weight = cross_entropy_weight
        self.marginal_entropy_weight = marginal_entropy_weight
        self.conditional_entropy_weight = conditional_entropy_weight
        self.temperature = temperature

    def forward(
        self,
        query_images: Tensor,
    ) -> Tensor:
        """
        Overrides forward method of FewShotClassifier.
        Fine-tune prototypes based on support classification error and mutual information between
        query features and their label predictions.
        Then classify w.r.t. the cosine similarity to the fine-tuned prototypes.

        Note: self.prototypes is updated in place. The fine-tuned prototypes persist after this
        call, so a subsequent call starts fine-tuning from them.

        Args:
            query_images: images of the query set. All of them are used jointly to fine-tune
                the prototypes (the marginal prediction is averaged over the whole batch).
        Returns:
            detached classification scores of the query images w.r.t. the fine-tuned
            prototypes, passed through softmax_if_specified
        """
        query_features = self.compute_features(query_images)

        # NOTE(review): assumes support labels are 0..num_classes-1, as one_hot requires.
        num_classes = self.support_labels.unique().size(0)
        support_labels_one_hot = nn.functional.one_hot(  # pylint: disable=not-callable
            self.support_labels, num_classes
        )

        # Fine-tuning needs gradients even when forward is called under torch.no_grad().
        with torch.enable_grad():
            self.prototypes.requires_grad_()
            optimizer = torch.optim.Adam([self.prototypes], lr=self.fine_tuning_lr)

            for _ in range(self.fine_tuning_steps):
                # The cosine scores are used directly as logits (larger = closer prototype).
                support_logits = self.temperature * self.cosine_distance_to_prototypes(
                    self.support_features
                )
                query_logits = self.temperature * self.cosine_distance_to_prototypes(
                    query_features
                )

                support_cross_entropy = (
                    -(support_labels_one_hot * support_logits.log_softmax(1))
                    .sum(1)
                    .mean(0)
                )

                # H(Y|X): mean per-query entropy of the predictions.
                query_soft_probs = query_logits.softmax(1)
                query_conditional_entropy = (
                    -(query_soft_probs * torch.log(query_soft_probs + 1e-12))
                    .sum(1)
                    .mean(0)
                )

                # H(Y): entropy of the prediction averaged over the query batch. The epsilon
                # guards against log(0) -> nan if a class's marginal probability underflows,
                # consistently with the conditional entropy above.
                marginal_prediction = query_soft_probs.mean(0)
                marginal_entropy = -(
                    marginal_prediction * torch.log(marginal_prediction + 1e-12)
                ).sum(0)

                # Minimize support cross-entropy while maximizing the (weighted) mutual
                # information H(Y) - H(Y|X) on the query set.
                loss = self.cross_entropy_weight * support_cross_entropy - (
                    self.marginal_entropy_weight * marginal_entropy
                    - self.conditional_entropy_weight * query_conditional_entropy
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        return self.softmax_if_specified(
            self.cosine_distance_to_prototypes(query_features),
            temperature=self.temperature,
        ).detach()

    @staticmethod
    def is_transductive() -> bool:
        """TIM uses the whole query set to fine-tune its prototypes, hence is transductive."""
        return True