-
Notifications
You must be signed in to change notification settings - Fork 415
/
generate.lua
98 lines (86 loc) · 3.42 KB
/
generate.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
require 'image'
require 'nn'
local optnet = require 'optnet'
torch.setdefaulttensortype('torch.FloatTensor')

-- Generation options. Every field can be overridden through an environment
-- variable of the same name (e.g. `batchSize=64 net=<path>`); values that
-- parse as numbers are converted with tonumber, everything else stays a string.
opt = {
   batchSize = 32,        -- number of samples to produce
   noisetype = 'normal',  -- type of noise distribution (uniform / normal).
   net = '',              -- path to the generator network
   imsize = 1,            -- used to produce larger images. 1 = 64px. 2 = 80px, 3 = 96px, ...
   noisemode = 'random',  -- random / line / linefull1d / linefull
   name = 'generation1',  -- name of the file saved
   gpu = 1,               -- gpu mode. 0 = CPU, 1 = GPU
   display = 1,           -- Display image: 0 = false, 1 = true
   nz = 100,              -- length of each noise vector (dim 2 of the noise tensor)
}
-- Only existing keys are reassigned, which is allowed while iterating with pairs.
for k, _ in pairs(opt) do
   opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k]
end
print(opt)
if opt.display == 0 then opt.display = false end
-- Validate the option itself: the global `net` is not assigned until the model
-- is loaded, so testing `net` here compared nil with '' and could never fail.
assert(opt.net ~= '', 'provide a generator model')
if opt.gpu > 0 then
   require 'cunn'
   require 'cudnn'
end
-- Allocate the noise batch (batchSize x nz x imsize x imsize); it is filled
-- with values later in the script.
noise = torch.Tensor(opt.batchSize, opt.nz, opt.imsize, opt.imsize)

-- Load the serialized generator from the path given in opt.net.
net = torch.load(opt.net)

-- Older saved models start with an nn.View layer. It is unnecessary and
-- hinders convolutional generation, so drop it when present.
do
   local first_layer = torch.type(net:get(1))
   if first_layer == 'nn.View' then
      net:remove(1)
   end
end
print(net)
-- Fill the noise batch from the requested distribution. torch.Tensor(...)
-- does not initialise its storage, so an unrecognised noisetype is rejected
-- instead of passing uninitialised memory to the generator.
if opt.noisetype == 'uniform' then
   noise:uniform(-1, 1)
elseif opt.noisetype == 'normal' then
   noise:normal(0, 1)
else
   error("unknown noisetype '" .. tostring(opt.noisetype) .. "' (expected uniform / normal)")
end

-- Two endpoints in Z space, drawn unconditionally (in every mode) so the
-- random-number stream is consumed the same way regardless of noisemode.
local noiseL = torch.FloatTensor(opt.nz):uniform(-1, 1)
local noiseR = torch.FloatTensor(opt.nz):uniform(-1, 1)

-- Point at fraction t of the segment: t = 0 gives noiseR, t = 1 gives noiseL.
local function lerp(t)
   return noiseL * t + noiseR * (1 - t)
end

if opt.noisemode == 'line' then
   -- Linear interpolation in Z space between the two endpoints; each sample
   -- in the mini-batch is one point on the line. copy() needs matching
   -- element counts, so this mode expects imsize == 1.
   local line = torch.linspace(0, 1, opt.batchSize)
   for i = 1, opt.batchSize do
      noise:select(1, i):copy(lerp(line[i]))
   end
elseif opt.noisemode == 'linefull1d' then
   -- Same interpolation, but laid out spatially so the generator produces one
   -- wide image. With batchSize == 1 the noise becomes 1 x nz x 1 x imsize.
   assert(opt.batchSize == 1, 'for linefull1d mode, give batchSize(1) and imsize > 1')
   noise = noise:narrow(3, 1, 1):clone()
   local line = torch.linspace(0, 1, opt.imsize)
   for i = 1, opt.imsize do
      noise:narrow(4, i, 1):copy(lerp(line[i]))
   end
elseif opt.noisemode == 'linefull' then
   -- 2D variant of linefull1d: only the diagonal (i, i) positions receive
   -- interpolated points; the other positions keep the noise drawn above.
   assert(opt.batchSize == 1, 'for linefull mode, give batchSize(1) and imsize > 1')
   local line = torch.linspace(0, 1, opt.imsize)
   for i = 1, opt.imsize do
      noise:narrow(3, i, 1):narrow(4, i, 1):copy(lerp(line[i]))
   end
end
-- Dummy input used only by optnet.optimizeMemory below. Its channel count must
-- match the generator's input width (opt.nz); the previous hard-coded 100
-- broke any generator trained with a different nz.
local sample_input = torch.randn(2, opt.nz, 1, 1)
if opt.gpu > 0 then
   net:cuda()
   cudnn.convert(net, cudnn)
   noise = noise:cuda()
   sample_input = sample_input:cuda()
else
   sample_input = sample_input:float()
   net:float()
end
-- Set up double-buffering across the network; this drastically reduces the
-- memory needed to generate samples.
optnet.optimizeMemory(net, sample_input)
-- Run the generator over the whole noise batch.
local images = net:forward(noise)

-- Report the four output dimensions joined with ' x '.
local dims = {}
for d = 1, 4 do
   dims[d] = images:size(d)
end
print('Images size: ', table.concat(dims, ' x '))

-- Map values in place via (x + 1) / 2; presumably the generator emits values
-- in [-1, 1] (e.g. a tanh output), which this rescales to [0, 1] — confirm.
images:add(1):mul(0.5)
print('Min, Max, Mean, Stdv', images:min(), images:max(), images:mean(), images:std())

-- Tile the batch into a single grid image and write it as <opt.name>.png.
local filename = opt.name .. '.png'
image.save(filename, image.toDisplayTensor(images))
print('Saved image to: ', filename)

-- Optionally push the batch to the `display` package's viewer.
if opt.display then
   disp = require 'display'
   disp.image(images)
   print('Displayed image')
end