-
Notifications
You must be signed in to change notification settings - Fork 6
/
wct.py
50 lines (33 loc) · 1.33 KB
/
wct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
def covsqrt_mean(feature, inverse=False, tolerance=1e-14):
    """Return the (inverse) matrix square root of each sample's channel covariance.

    The first two dimensions of ``feature`` are treated as (batch, channels).
    All remaining dimensions are flattened into one "pixel" axis (for the usual
    (b, c, h, w) input that is the h*w spatial positions). The covariance is
    built from the mean-centered features and decomposed with a symmetric
    eigendecomposition.

    Args:
        feature: tensor of shape (b, c, ...), typically (b, c, h, w).
        inverse: if True, return the inverse square root instead of the square
            root. Eigenvalues at or below ``tolerance`` are discarded rather
            than inverted, so this acts as a pseudo-inverse.
        tolerance: eigenvalues ``<= tolerance`` are treated as zero and
            dropped. The original author chose the default by reference to
            MATLAB's default SVD tolerance.

    Returns:
        (covsqrt, mean):
            covsqrt: shape (b, c, c), the matrix (inverse) square root.
            mean: shape (b, c, 1), the per-channel mean over the pixel axis.

    Note:
        The covariance is the unnormalized scatter matrix Z @ Z^T. It is not
        divided by the pixel count.
        NOTE(review): when whitening and coloring features of different
        spatial sizes, this introduces a scale mismatch. Confirm callers
        expect that.
    """
    b, c = feature.shape[:2]
    # reshape (not view) so non-contiguous inputs are accepted as well.
    flat = feature.reshape(b, c, -1)
    mean = flat.mean(dim=2, keepdim=True)
    zeromean = flat - mean
    cov = torch.bmm(zeromean, zeromean.transpose(1, 2))

    # torch.symeig was removed from PyTorch. torch.linalg.eigh is its
    # replacement and likewise returns ascending eigenvalues, with the
    # eigenvectors as columns.
    evals, evects = torch.linalg.eigh(cov)

    p = -0.5 if inverse else 0.5
    # Eigenvalues are ascending, so "everything above the tolerance" is the
    # same set the original per-sample search-then-slice selected.
    keep = evals > tolerance
    # Substitute 1 for discarded eigenvalues before pow, so zero or negative
    # values never produce inf/nan (not even in the unused branch, which would
    # poison gradients). Then zero them out.
    safe = torch.where(keep, evals, torch.ones_like(evals))
    scaled = torch.where(keep, safe.pow(p), torch.zeros_like(evals))

    # V @ diag(lambda^p) @ V^T, batched over all samples at once. Samples
    # with no eigenvalue above tolerance yield an all-zero matrix.
    covsqrt = evects @ torch.diag_embed(scaled) @ evects.transpose(1, 2)
    return covsqrt, mean
def whitening(feature):
    """Whiten ``feature`` using its own channel covariance.

    The per-channel mean is subtracted and is not added back. The centered
    features are then multiplied by the inverse square root of the channel
    covariance returned by ``covsqrt_mean(feature, inverse=True)``.

    Args:
        feature: 4-D tensor, unpacked as (b, c, h, w).

    Returns:
        Tensor of shape (b, c, h, w) holding the whitened features.
    """
    b, c, h, w = feature.size()
    inv_covsqrt, mean = covsqrt_mean(feature, inverse=True)
    # reshape (not view): view raises on non-contiguous inputs.
    centered = feature.reshape(b, c, -1) - mean
    normalized_feature = torch.matmul(inv_covsqrt, centered)
    return normalized_feature.reshape(b, c, h, w)
def coloring(feature, target):
    """Impose ``target``'s channel covariance and mean onto ``feature``.

    ``feature`` is multiplied by the square root of ``target``'s channel
    covariance, and then ``target``'s per-channel mean is added.

    Args:
        feature: 4-D tensor, unpacked as (b, c, h, w). Presumably already
            whitened (e.g. the output of ``whitening``); assumption, confirm
            at call sites.
        target: tensor whose covariance and mean are imposed. Its channel
            count must equal ``c``, and its batch dimension must broadcast
            against ``b``. Its spatial size may differ from ``feature``'s,
            because only its (b, c, c) covariance square root is used.

    Returns:
        Tensor of shape (b, c, h, w) holding the colored features.

    NOTE(review): ``covsqrt_mean`` uses an unnormalized scatter matrix. If
    ``feature`` and ``target`` have different pixel counts, the output scale
    will not match ``target``'s per-pixel covariance. Verify this is intended.
    """
    b, c, h, w = feature.size()
    covsqrt, mean = covsqrt_mean(target)
    # reshape (not view): view raises on non-contiguous inputs.
    flat = feature.reshape(b, c, -1)
    colored_feature = torch.matmul(covsqrt, flat) + mean
    return colored_feature.reshape(b, c, h, w)