-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_pool.py
66 lines (60 loc) · 2.68 KB
/
image_pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Image pool to train discriminators on samples of past images.
Reference:
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/util/image_pool.py
"""
import random
import jax.numpy as jnp
class ImagePool:
    """History buffer of previously generated images.

    Discriminators can be updated with a mix of freshly generated images and
    images produced by earlier generator states, rather than only with the
    output of the latest generator.
    """

    def __init__(self, pool_size: int) -> None:
        """Initialize the ImagePool class.

        Parameters:
            pool_size (int) -- the size of the image buffer; if pool_size <= 0,
                               pooling is disabled and ``query`` returns its
                               input unchanged
        """
        self.pool_size = pool_size
        # Always create the (possibly unused) state so that ``query`` never
        # touches a missing attribute, whatever value ``pool_size`` has.
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """Return a batch of images drawn from the current batch and the pool.

        Parameters:
            images -- the latest generated images from the generator; iterated
                      along its first axis, one image at a time

        Returns:
            The input itself when pooling is disabled (pool_size <= 0) or when
            the batch is empty. Otherwise a new array built by concatenating,
            along axis 0, one image per input image:

            * while the buffer is not full, the input image is stored in the
              buffer and also returned;
            * once the buffer is full, with probability 0.5 a randomly chosen
              stored image is returned and replaced in the buffer by the input
              image; otherwise the input image is returned and the buffer is
              left untouched.
        """
        if self.pool_size <= 0:  # pooling disabled: pass the batch through
            return images

        return_images = []
        for image in images:
            # Restore the leading batch axis that iteration stripped off.
            image = jnp.expand_dims(image, 0)
            if self.num_imgs < self.pool_size:
                # Buffer not full yet: keep inserting current images.
                self.num_imgs += 1
                self.images.append(image)
                return_images.append(image)
            elif random.uniform(0, 1) > 0.5:
                # Swap: hand out a stored image and keep the current one.
                # random.randint is inclusive on both ends.
                random_id = random.randint(0, self.pool_size - 1)
                # JAX arrays are immutable, so no defensive copy is needed.
                stored = self.images[random_id]
                self.images[random_id] = image
                return_images.append(stored)
            else:
                # Return the current image without touching the buffer.
                return_images.append(image)

        if not return_images:
            # Empty batch: jnp.concatenate rejects an empty sequence, so hand
            # the (empty) input back as is.
            return images

        # Collect all the images into a single batch and return it.
        return jnp.concatenate(return_images, 0)