-
Notifications
You must be signed in to change notification settings - Fork 147
/
bd_cspn.py
77 lines (64 loc) · 2.66 KB
/
bd_cspn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from torch import Tensor, nn
from .few_shot_classifier import FewShotClassifier
class BDCSPN(FewShotClassifier):
    """
    Jinlu Liu, Liang Song, Yongqiang Qin
    "Prototype Rectification for Few-Shot Learning" (ECCV 2020)
    https://arxiv.org/abs/1911.10713

    Prototypes are rectified in two steps:
      1. feature shifting: the query features are translated so that their
         mean coincides with the mean of the support features;
      2. label propagation: each prototype is recomputed as a weighted
         average of the support samples of its class and of the (shifted)
         query samples currently predicted as that class.
    Queries are then scored with cosine_distance_to_prototypes against the
    rectified prototypes.
    This is a transductive method: the whole query batch takes part in the
    rectification.
    """

    def rectify_prototypes(self, query_features: Tensor) -> None:
        """
        Overwrite self.prototypes with rectified prototypes built from the
        support set and the given query batch.

        Args:
            query_features: query features of shape (n_query, feature_dimension)
        """
        # pylint: disable=not-callable
        # (nn.functional.one_hot is a known pylint not-callable false positive)
        support_features = self.support_features
        support_labels = self.support_labels

        # NOTE(review): n_classes is the number of distinct support labels, so
        # labels are presumably exactly 0..n_classes-1 (one_hot rejects
        # values >= n_classes) — confirm against the caller.
        n_classes = support_labels.unique().size(0)
        support_mask = nn.functional.one_hot(support_labels, n_classes)

        # Feature shifting: move the query cloud onto the support mean.
        # This rebinds a local name only; the caller's tensor is untouched.
        shift = support_features.mean(0, keepdim=True) - query_features.mean(
            0, keepdim=True
        )
        shifted_queries = query_features + shift

        # Positive, unnormalized weights: exp of the scores against the
        # current (pre-rectification) prototypes.
        support_weights = self.cosine_distance_to_prototypes(support_features).exp()
        query_weights = self.cosine_distance_to_prototypes(shifted_queries).exp()

        # Each query is assigned to its highest-scoring class.
        query_mask = nn.functional.one_hot(query_weights.argmax(-1), n_classes)

        # Keep only the weight each sample gives to its own (true or predicted)
        # class; every other entry becomes zero.
        masked_support = support_mask * support_weights  # [n_support, n_classes]
        masked_query = query_mask * query_weights  # [n_query, n_classes]

        # Per-class normalizer so that the coefficients of each prototype sum to 1.
        per_class_total = masked_support.sum(0) + masked_query.sum(0)  # [n_classes]
        support_coefficients = masked_support / per_class_total
        query_coefficients = masked_query / per_class_total

        # Weighted average of support samples and assigned (shifted) queries.
        support_part = support_coefficients.t() @ support_features
        query_part = query_coefficients.t() @ shifted_queries
        self.prototypes = support_part + query_part

    def forward(
        self,
        query_images: Tensor,
    ) -> Tensor:
        """
        Override of FewShotClassifier.forward.

        Extract query features, rectify the prototypes with them, then score
        the (unshifted) query features against the rectified prototypes.

        Args:
            query_images: batch of query images, fed to self.compute_features
        Returns:
            cosine_distance_to_prototypes scores of the queries against the
            rectified prototypes, passed through softmax_if_specified
        """
        features = self.compute_features(query_images)
        self.rectify_prototypes(query_features=features)
        scores = self.cosine_distance_to_prototypes(features)
        return self.softmax_if_specified(scores)

    @staticmethod
    def is_transductive() -> bool:
        """BD-CSPN uses the whole query batch at inference time."""
        return True