-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
39 lines (34 loc) · 1.59 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import wandb
import numpy as np
# Spatial size that `transform` resizes every input image to (passed to transforms.Resize).
IMAGE_SHAPE = (32, 32)
def plot_spatial_noise_distribution(noise, predicted_noise):
    """Show ground-truth and predicted noise side by side as images and log them to W&B.

    Both inputs are converted with ``reverse_transform`` before being drawn.

    Args:
        noise: Ground-truth noise. Presumably a single CHW tensor, since
            ``reverse_transform`` permutes three axes -- confirm against callers.
        predicted_noise: Predicted noise, expected in the same layout as ``noise``.
    """
    # plt.subplots creates the figure itself; a separate plt.figure() call
    # would only leave an extra empty figure open.
    fig, axes = plt.subplots(1, 2, figsize=(5, 5))
    axes[0].imshow(reverse_transform(noise))
    axes[0].set_title("ground truth noise", fontsize=10)
    axes[1].imshow(reverse_transform(predicted_noise))
    axes[1].set_title("predicted noise", fontsize=10)
    plt.show()
    wandb.log({'spatial noise distributions': wandb.Image(fig)})
    # Release the figure so repeated calls (e.g. once per logging step) do not
    # accumulate open figures in pyplot's global registry.
    plt.close(fig)
def plot_noise_distribution(noise, predicted_noise):
    """Overlay density histograms of ground-truth and predicted noise and log the plot to W&B.

    Args:
        noise: Ground-truth noise tensor; it is flattened, so any shape works.
        predicted_noise: Predicted noise tensor; flattened the same way.
    """
    # plt.subplots creates the figure itself; a separate plt.figure() call
    # would only leave an extra empty figure open.
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    # detach() first: Tensor.numpy() raises on tensors that require grad,
    # which is the usual state of a model output during training.
    true_values = noise.detach().cpu().numpy().flatten()
    predicted_values = predicted_noise.detach().cpu().numpy().flatten()
    ax.hist(true_values, density=True, alpha=0.8, label="ground truth noise")
    ax.hist(predicted_values, density=True, alpha=0.8, label="predicted noise")
    ax.legend()
    plt.show()
    wandb.log({'noise distributions': wandb.Image(fig)})
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
# Preprocessing pipeline: resize to IMAGE_SHAPE, convert to a tensor, then
# shift the value range from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        # Bring every input image to the common spatial size.
        transforms.Resize(IMAGE_SHAPE),
        # Image -> torch tensor, with values scaled into [0, 1].
        transforms.ToTensor(),
        # Stretch [0, 1] to [-1, 1].
        transforms.Lambda(lambda img: 2 * img - 1),
    ]
)
# Inverse of `transform`'s value scaling: turns a CHW tensor nominally in
# [-1, 1] into an 8-bit PIL image.
reverse_transform = transforms.Compose([
    transforms.Lambda(lambda t: (t + 1) / 2),  # Scale data between [0,1]
    transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
    transforms.Lambda(lambda t: t * 255.),  # Scale data between [0.,255.]
    # Saturate out-of-range values before the integer cast. Inputs such as
    # noise are not confined to [-1, 1], and casting a float outside
    # [0, 255] to uint8 wraps around / is platform-dependent instead of
    # clipping.
    transforms.Lambda(lambda t: t.clamp(0., 255.)),
    # detach() so tensors that still require grad can be converted to numpy.
    transforms.Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),  # Convert into an uint8 numpy array
    transforms.ToPILImage(),  # Convert to PIL image
])