# data.py
import torchlib

from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms


class OnlyImage(Dataset):
    """Wrap an (image, label) dataset so that indexing returns only the image."""

    def __init__(self, img_label_dataset):
        self.img_label_dataset = img_label_dataset

    def __len__(self):
        return len(self.img_label_dataset)

    def __getitem__(self, i):
        return self.img_label_dataset[i][0]


def make_32x32_dataset(dataset, batch_size, drop_remainder=True, shuffle=True, num_workers=4, pin_memory=False):
    if dataset == 'mnist':
        transform = transforms.Compose([
            transforms.Resize(size=(32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        dataset = datasets.MNIST('data/MNIST', transform=transform, download=True)
        img_shape = [32, 32, 1]

    elif dataset == 'fashion_mnist':
        transform = transforms.Compose([
            transforms.Resize(size=(32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        dataset = datasets.FashionMNIST('data/FashionMNIST', transform=transform, download=True)
        img_shape = [32, 32, 1]

    elif dataset == 'cifar10':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        dataset = datasets.CIFAR10('data/CIFAR10', transform=transform, download=True)
        img_shape = [32, 32, 3]

    else:
        raise NotImplementedError

    dataset = OnlyImage(dataset)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_remainder, pin_memory=pin_memory)

    return data_loader, img_shape
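
# Usage sketch for make_32x32_dataset (commented out, in the style of the debug
# block at the bottom of this file): the loader yields (batch_size, C, H, W)
# tensors normalized to [-1, 1], while the returned img_shape is (H, W, C).
#
# data_loader, img_shape = make_32x32_dataset('cifar10', batch_size=64)
# print(img_shape)                         # [32, 32, 3]
# img_batch = next(iter(data_loader))
# print(img_batch.shape)                   # torch.Size([64, 3, 32, 32])
# print(img_batch.min(), img_batch.max())  # roughly -1.0 and 1.0
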

def make_celeba_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, num_workers=4, pin_memory=False):
    # center-crop the 178x218 aligned CelebA images to 108x108 before resizing
    crop_size = 108
    offset_height = (218 - crop_size) // 2
    offset_width = (178 - crop_size) // 2
    crop = lambda x: x[:, offset_height:offset_height + crop_size, offset_width:offset_width + crop_size]

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(crop),
        transforms.ToPILImage(),
        transforms.Resize(size=(resize, resize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset = torchlib.DiskImageDataset(img_paths, map_fn=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_remainder, pin_memory=pin_memory)
    img_shape = (resize, resize, 3)

    return data_loader, img_shape
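
# Usage sketch for make_celeba_dataset (commented out). It assumes the standard
# 178x218 aligned CelebA JPEGs under 'data/img_align_celeba', and uses pylib's
# glob as in the debug block at the bottom of this file.
#
# import pylib as py
# celeba_paths = py.glob('data/img_align_celeba', '*.jpg')
# data_loader, img_shape = make_celeba_dataset(celeba_paths, batch_size=64, resize=64)
# print(img_shape)  # (64, 64, 3)
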

def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, num_workers=4, pin_memory=False):
    transform = transforms.Compose([
        transforms.Resize(size=(resize, resize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset = torchlib.DiskImageDataset(img_paths, map_fn=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_remainder, pin_memory=pin_memory)
    img_shape = (resize, resize, 3)

    return data_loader, img_shape
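
# Usage sketch for make_anime_dataset (commented out). Unlike make_celeba_dataset
# there is no fixed center crop: images are resized directly to (resize, resize),
# so aspect ratios are not preserved for non-square inputs. The folder path below
# is illustrative only.
#
# import glob
# anime_paths = glob.glob('data/anime_faces/*.jpg')
# data_loader, img_shape = make_anime_dataset(anime_paths, batch_size=64)
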
# ==============================================================================
# =                               custom dataset                               =
# ==============================================================================

def make_custom_datset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, num_workers=4, pin_memory=False):
    transform = transforms.Compose([
        # ======================================
        # =               custom               =
        # ======================================
        ...,  # custom preprocessings
        # ======================================
        # =               custom               =
        # ======================================
        transforms.Resize(size=(resize, resize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset = torchlib.DiskImageDataset(img_paths, map_fn=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_remainder, pin_memory=pin_memory)
    img_shape = (resize, resize, 3)

    return data_loader, img_shape
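
# Illustrative example of filling in the custom slot above (assumption: a random
# horizontal flip and a 160x160 center crop stand in for the `...` placeholder;
# pick preprocessing that matches your own data).
#
# transform = transforms.Compose([
#     transforms.RandomHorizontalFlip(),
#     transforms.CenterCrop(size=160),
#     transforms.Resize(size=(64, 64)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# ])
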

# ==============================================================================
# =                                    debug                                   =
# ==============================================================================

# import imlib as im
# import numpy as np
# import pylib as py

# data_loader, _ = make_celeba_dataset(py.glob('data/img_align_celeba', '*.jpg'), batch_size=64)
# for img_batch in data_loader:
#     for img in img_batch.numpy():
#         img = np.transpose(img, (1, 2, 0))
#         im.imshow(img)
#         im.show()