-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer.py
executable file
·41 lines (27 loc) · 915 Bytes
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import argparse
import os

import numpy as np
import torch

import dataset
import model
def get_args(argv=None):
    """Parse command-line options for generator inference.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, so
            calling ``get_args()`` behaves exactly as before.

    Returns:
        An ``argparse.Namespace`` with a ``model_path`` attribute: the
        directory holding the trained ``generator.pth`` checkpoint.
    """
    parser = argparse.ArgumentParser(description="Eval options")
    parser.add_argument(
        "--model_path",
        type=str,
        default="./models/exp1/weights_50",
        # The directory is read from, not written to: main() loads
        # <model_path>/generator.pth from it.
        help="Directory containing the trained generator.pth to load",
    )
    return parser.parse_args(argv)
def make_some_noise(batch_size, latent_dim=100):
    """Sample a batch of latent noise vectors for the generator.

    Values are drawn uniformly from [0, 1) with ``torch.rand`` (not from a
    standard normal).
    NOTE(review): this should match the noise distribution used during
    training — confirm against the training code.

    Args:
        batch_size: Number of noise vectors to draw.
        latent_dim: Width of each vector. The default of 100 preserves the
            original behaviour; presumably it must equal the input size
            ``model.Generator`` expects — verify before changing it.

    Returns:
        A float tensor of shape ``(batch_size, latent_dim)``.
    """
    return torch.rand(batch_size, latent_dim)
def main():
    """Generate 100 coefficient samples with a trained generator and save them.

    Loads weights from ``<--model_path>/generator.pth`` into
    ``model.Generator``, switches it to eval mode, and feeds it 100
    independent noise vectors of shape (1, 100) under ``torch.no_grad()``.
    Each output is divided by 10, and the stacked results are written to
    ``generated_coeff.npy`` in the current working directory. The saved
    array has shape ``(100,) + <generator output shape for a batch of 1>``.
    """
    args = get_args()
    weights_path = os.path.join(args.model_path, "generator.pth")

    generator = model.Generator()
    # Nothing in this script moves tensors to a GPU, and .numpy() below
    # needs CPU tensors. map_location="cpu" therefore lets a checkpoint
    # that was saved from a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file, so only point
    # --model_path at trusted checkpoints; on torch >= 1.13 consider
    # passing weights_only=True.
    generator.load_state_dict(torch.load(weights_path, map_location="cpu"))
    generator.eval()

    samples = []
    with torch.no_grad():
        for _ in range(100):
            noise = make_some_noise(1)
            coeff = generator(noise)
            # NOTE(review): the /10 presumably undoes a x10 scaling of the
            # targets applied during training — confirm against the
            # training/dataset code.
            samples.append(coeff.numpy() / 10)

    np.save("generated_coeff.npy", np.array(samples))
# Run inference only when executed as a script, so that importing this
# module (e.g. to reuse make_some_noise) does not load weights or write files.
if __name__ == "__main__":
    main()