-
Notifications
You must be signed in to change notification settings - Fork 1
/
dknn.py
128 lines (107 loc) · 4.28 KB
/
dknn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from sklearn.neighbors import NearestNeighbors
import torch.nn as nn
import numpy as np
from tqdm import tqdm
class DKNN:
    """Deep k-Nearest Neighbors (DkNN) classifier.

    Wraps a trained torch ``model`` so each forward pass also returns the
    flattened hidden representations of selected layers, fits one sklearn
    ``NearestNeighbors`` index per selected layer over ``train_data``, and
    classifies new inputs by counting the labels of the neighbors retrieved
    at every layer.
    """

    def __init__(
        self,
        model,
        train_data,
        train_targets,
        n_class=10,
        hidden_layers=-1,
        n_neighbors=5,
        metric="l2",
        batch_size=128,
        device=torch.device("cpu")
    ):
        """
        Args:
            model: trained torch model; must expose a ``classifier`` module
                and hidden blocks (an optional ``feature`` attribute plus
                ``nn.Sequential`` children whose names contain "layer").
            train_data: tensor of training inputs used to build the indices.
            train_targets: training labels (any sequence convertible to a
                numpy array).
            n_class: number of classes in the label space.
            hidden_layers: list of hidden-layer indices to use, or -1 for all.
            n_neighbors: number of neighbors retrieved per layer.
            metric: "l2" or "cosine"; cosine is implemented by L2-normalizing
                the representations before indexing with the default
                Euclidean metric.
            batch_size: batch size for the forward passes.
            device: device on which the wrapped model runs.
        """
        self.hidden_layers = hidden_layers  # raw spec; resolved by _wrap_model
        self._model = self._wrap_model(model)
        self.train_data = train_data
        self.train_targets = np.array(train_targets)
        # Replace the raw spec with the wrapper's resolved, offset-adjusted
        # list of layer indices.
        self.hidden_layers = self._model.hidden_layers
        self.device = device
        self._model.eval()  # inference only: freeze dropout / batch-norm stats
        self._model.to(self.device)
        self.n_class = n_class
        self.metric = metric
        self.n_neighbors = n_neighbors
        self.batch_size = batch_size
        self._nns = self._build_nns()

    def _get_hidden_repr(self, x, return_targets=False):
        """Run ``x`` through the wrapped model in batches.

        Returns:
            (hidden_reprs, targets): ``hidden_reprs`` is a list with one
            array of shape (N, features) per selected layer; ``targets`` is
            the argmax of the model's outputs when ``return_targets`` is
            True, else None.
        """
        hidden_reprs = []
        targets = None
        if return_targets:
            outs = []
        for i in range(0, x.size(0), self.batch_size):
            x_batch = x[i:i + self.batch_size]
            if return_targets:
                hidden_reprs_batch, outs_batch = self._model(x_batch.to(self.device))
            else:
                hidden_reprs_batch, _ = self._model(x_batch.to(self.device))
            if self.metric == "cosine":
                # L2-normalize each row so the index's Euclidean distance
                # ranks points identically to cosine distance.
                hidden_reprs_batch = [
                    hidden_repr_batch / hidden_repr_batch.pow(2).sum(dim=1, keepdim=True).sqrt()
                    for hidden_repr_batch in hidden_reprs_batch
                ]
            hidden_reprs.append(hidden_reprs_batch)
            if return_targets:
                outs.append(outs_batch)
        # Re-group from per-batch lists into one array per selected layer.
        hidden_reprs = [
            np.concatenate([hidden_batch[i] for hidden_batch in hidden_reprs], axis=0)
            for i in range(len(self.hidden_layers))
        ]
        if return_targets:
            # NOTE(review): outs_batch comes straight from the classifier and
            # may still be a device tensor — verify before using this path.
            outs = np.concatenate(outs, axis=0)
            targets = outs.argmax(axis=1)
        return hidden_reprs, targets

    def _wrap_model(self, model):
        """Wrap ``model`` so ``forward`` returns (hidden_reprs, output).

        Hidden blocks are collected from an optional ``model.feature`` module
        followed by every ``nn.Sequential`` child whose name contains
        "layer"; ``model.classifier`` maps the last block's output to the
        final prediction.
        """
        class ModelWrapper(nn.Module):
            def __init__(self, model, hidden_layers):
                super(ModelWrapper, self).__init__()
                self._model = model
                self.hidden_mappings = []
                start_layer = 0
                if hasattr(model, "feature"):
                    # The feature extractor occupies slot 0, shifting the
                    # user-specified layer indices by one.
                    start_layer = 1
                    self.hidden_mappings.append(model.feature)
                self.hidden_mappings.extend([
                    m[1] for m in model.named_children()
                    if isinstance(m[1], nn.Sequential) and "layer" in m[0]
                ])
                if hidden_layers == -1:
                    # -1 selects every collected hidden block.
                    self.hidden_layers = list(range(len(self.hidden_mappings)))
                else:
                    self.hidden_layers = [hl + start_layer for hl in hidden_layers]
                self.classifier = self._model.classifier

            def forward(self, x):
                hidden_reprs = []
                for mp in self.hidden_mappings:
                    x = mp(x)
                    # Detach so DkNN indexing never keeps the autograd graph.
                    hidden_reprs.append(x.detach().cpu())
                out = self.classifier(x)
                return [hidden_reprs[i].flatten(start_dim=1) for i in self.hidden_layers], out

        return ModelWrapper(model, self.hidden_layers)

    def _build_nns(self):
        """Fit one NearestNeighbors index per selected layer on train_data."""
        hidden_reprs, _ = self._get_hidden_repr(self.train_data)
        return [
            NearestNeighbors(n_neighbors=self.n_neighbors, n_jobs=-1).fit(hidden_repr)
            for hidden_repr in tqdm(hidden_reprs)
        ]

    def predict(self, x):
        """Return per-class neighbor frequencies for each input in ``x``.

        Output shape is (N, n_class); entry [i, c] is the fraction of all
        neighbors retrieved for input i (across every selected layer) whose
        training label is c, so each row sums to 1.
        """
        hidden_reprs, _ = self._get_hidden_repr(x)
        nn_indices = [
            knn.kneighbors(hidden_repr, return_distance=False)
            for knn, hidden_repr in zip(self._nns, hidden_reprs)
        ]
        # (N, n_layers * n_neighbors) neighbor indices into the training set.
        nn_indices = np.concatenate(nn_indices, axis=1)
        nn_labels = self.train_targets[nn_indices]
        # Bug fix: minlength was hard-coded to 10, which produced wrongly
        # shaped (or class-truncated) counts whenever n_class != 10.
        nn_labels_count = np.stack([
            np.bincount(row, minlength=self.n_class) for row in nn_labels
        ])
        return nn_labels_count / len(self.hidden_layers) / self.n_neighbors

    def __call__(self, x):
        return self.predict(x)