-
Notifications
You must be signed in to change notification settings - Fork 23
/
demo.py
97 lines (62 loc) · 2.57 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import sys
import argparse
import numpy as np
from PIL import Image
import torch
from torch.autograd import Variable
from torchvision import transforms
import Layout
import Utils
import config as cf
from Model import DuLaNet, E2P
import postproc
# Command-line interface of the inference script. Parsing happens at module
# level, so `args` and `device` are available to every function below.
parser = argparse.ArgumentParser(description='DuLa-Net inference scripts')
parser.add_argument(
    '--backbone',
    default='resnet18',
    choices=['resnet18', 'resnet34', 'resnet50'],
    help='backbone network',
)
parser.add_argument(
    '--ckpt',
    default='./Model/ckpt/res18_realtor.pkl',
    help='path to the model ckpt file',
)
parser.add_argument(
    '--input',
    type=str,
    help='input panorama image',
)
# NOTE(review): --output and --processes are parsed but not referenced by the
# code visible in this file -- confirm whether another module reads them.
parser.add_argument(
    '--output',
    default='./output',
    type=str,
    help='output path',
)
parser.add_argument(
    '--cpu',
    action='store_true',
    help='using cpu or not',
)
parser.add_argument(
    '--seed',
    default=224,
    type=int,
    help='manual random seed',
)
parser.add_argument(
    '--processes',
    default=8,
    type=int,
    help='processes number',
)
args = parser.parse_args()

# Fall back to the CPU when CUDA is unavailable or when --cpu was given.
if not torch.cuda.is_available() or args.cpu:
    device = torch.device('cpu')
else:
    device = torch.device('cuda')
def predict(model, input_path):
    """Run layout inference on one panorama and save the results beside it.

    The image at ``input_path`` is passed through ``model`` and the
    post-processing step. Two files sharing the input's path stem are then
    written: ``<stem>_vis.jpg`` (the input blended 50/50 with the predicted
    layout edge map) and ``<stem>_res.json`` (the predicted scene).

    NOTE(review): the ``--output`` CLI option is not consulted here; the
    outputs always land next to the input file.

    Args:
        model: network invoked as ``model(color)`` and expected to return
            three outputs ``(fp, fc, h)``. It is switched to eval mode here.
        input_path: path of the panorama image to process.

    Returns:
        None. The results are written to disk.
    """
    print('predict')
    model.eval()

    img = Image.open(input_path).convert("RGB")
    trans = transforms.Compose([
        transforms.Resize(cf.pano_size),
        transforms.ToTensor(),
    ])
    # Add a leading batch dimension of size 1 before moving to the device.
    color = torch.unsqueeze(trans(img), 0).to(device)

    e2p = E2P(cf.pano_size, cf.fp_size, cf.fp_fov)
    # Pure inference: disable autograd so no graph is built or kept alive
    # for the forward pass and the E2P projection.
    with torch.no_grad():
        [fp, fc, h] = model(color)
        [fc_up, fc_down] = e2p(fc)

    [fp, fc_up, fc_down, h] = Utils.var2np([fp, fc_up, fc_down, h])
    # fp_pred is returned by the post-processing but not used below.
    fp_pts, fp_pred = postproc.run(fp, fc_up, fc_down, h)

    # Visualization
    scene_pred = Layout.pts2scene(fp_pts, h)
    edge = Layout.genLayoutEdgeMap(scene_pred, [512, 1024, 3], dilat=2, blur=0)

    # PIL takes (width, height); the resulting array is 512 rows x 1024
    # columns, matching the edge-map size requested above.
    img = img.resize((1024, 512))
    img = np.array(img, np.float32) / 255
    vis = img * 0.5 + edge * 0.5
    vis = Image.fromarray(np.uint8(vis * 255))
    vis.save(os.path.splitext(input_path)[0] + "_vis.jpg")

    # Save output 3d layout as json
    Layout.saveSceneAsJson(os.path.splitext(input_path)[0] + "_res.json", scene_pred)
def demo():
    """Build the network, load its checkpoint and run a single prediction.

    Reads its configuration from the module-level ``args`` (``ckpt``,
    ``input``, ``seed``, ``backbone``) and places the model on the
    module-level ``device``.

    Raises:
        AssertionError: if no checkpoint path or no input image was given.
    """
    # Validate the arguments first so a missing input fails immediately,
    # before the model is constructed.
    assert args.ckpt is not None, "need pretrained model"
    assert args.input, "need an input for prediction"

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    model = DuLaNet(args.backbone).to(device)

    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved from a GPU can still be loaded on a CPU-only run.
    state_dict = torch.load(args.ckpt, map_location=str(device))
    model.load_state_dict(state_dict)

    predict(model, args.input)
# Script entry point: run the demo only when executed directly.
if __name__ == "__main__":
    demo()