-
Notifications
You must be signed in to change notification settings - Fork 49
/
inference.py
119 lines (95 loc) · 3.64 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# -*- coding: utf-8 -*-
# @Time : 2020-02-26 17:53
# @Author : Zonas
# @Email : zonas.wang@gmail.com
# @File : inference.py
"""
"""
import argparse
import logging
import os
import os.path as osp
# Default to GPU index 1, but keep a device selection the caller already
# exported instead of silently overriding it. This sits before `import torch`
# so the variable is in place before torch is loaded.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from unet import NestedUNet
from unet import UNet
from utils.dataset import BasicDataset
from config import UNetConfig
# Shared, module-level configuration. inference_one reads scale,
# deepsupervision, n_classes and out_threshold from it; the __main__ block
# reads model (the name of the network class to build) and n_classes.
cfg = UNetConfig()
def inference_one(net, image, device):
    """Predict a binary mask (or one mask per class) for a single PIL image.

    The image is preprocessed with ``BasicDataset.preprocess`` at ``cfg.scale``,
    run through ``net`` without gradient tracking, and turned into
    probabilities: softmax over dim 1 when ``cfg.n_classes > 1``, sigmoid
    otherwise. The probabilities are resized back to the input image's
    resolution and thresholded at ``cfg.out_threshold``.

    Args:
        net: segmentation network; it is switched to eval mode here.
        image: PIL image to segment (``image.size`` is ``(width, height)``).
        device: torch device the network lives on; the input is moved there.

    Returns:
        If ``cfg.n_classes == 1``: a boolean numpy array at the input image's
        ``(height, width)``. Otherwise: a list of such boolean arrays, one per
        class channel.
    """
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(image, cfg.scale))
    img = img.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        if cfg.deepsupervision:
            # With deep supervision the network returns several outputs;
            # only the last one is used.
            output = output[-1]

        if cfg.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        # Drop the batch dimension and copy to the host once. The original code
        # transferred each class slice separately. Everything below is CPU/PIL
        # work. NOTE(review): assumes net returns a (1, C, h, w) tensor.
        probs = probs.squeeze(0).cpu()

    # Resize back to the original resolution. PIL's .size is (w, h) while
    # Resize expects (h, w).
    # NOTE(review): ToPILImage converts float tensors to 8-bit, so the threshold
    # is applied to probabilities quantized to steps of 1/255.
    tf = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((image.size[1], image.size[0])),
        transforms.ToTensor(),
    ])

    if cfg.n_classes == 1:
        return tf(probs).squeeze().numpy() > cfg.out_threshold

    # Iterating the (C, h, w) tensor yields one (h, w) probability map per class.
    return [tf(prob).squeeze().numpy() > cfg.out_threshold for prob in probs]
def get_args():
    """Parse the command-line options for batch inference.

    Returns:
        argparse.Namespace with:
          * ``model``: path of the file passed to ``torch.load`` and then
            ``load_state_dict``.
          * ``input``: directory whose entries are segmented.
          * ``output``: directory that receives the predicted masks.
    """
    parser = argparse.ArgumentParser(description='Predict masks from input images',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model', '-m', default='MODEL.pth',
                        metavar='FILE',
                        help="Specify the file in which the model is stored")
    # Required: the directory is listed with os.listdir, and the former default
    # '' could only ever fail there with an unhelpful FileNotFoundError.
    parser.add_argument('--input', '-i', dest='input', type=str, required=True,
                        help='Directory of input images')
    # An empty output directory makes output paths relative to the CWD.
    parser.add_argument('--output', '-o', dest='output', type=str, default='',
                        help='Directory of output images')
    return parser.parse_args()
if __name__ == "__main__":
    # Nothing in this script configures logging, so at the default WARNING
    # level every logging.info() below would be dropped. basicConfig() is a
    # no-op if an imported module has already installed handlers.
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')

    args = get_args()

    # Only regular files are inputs. Sub-directories (e.g. per-image output
    # folders from an earlier run) would make Image.open fail. Sorting makes
    # the processing order deterministic.
    input_imgs = sorted(
        name for name in os.listdir(args.input)
        if osp.isfile(osp.join(args.input, name))
    )

    # NOTE(review): eval() turns the configured class name into one of the
    # imported classes (UNet / NestedUNet). This is only safe while the config
    # is trusted; an explicit name -> class mapping would be stricter.
    net = eval(cfg.model)(cfg)
    logging.info("Loading model {}".format(args.model))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))
    logging.info("Model loaded !")

    # Passing the list itself (not an enumerate object) lets tqdm show a total.
    for img_name in tqdm(input_imgs):
        logging.info("\nPredicting image {} ...".format(img_name))
        img_path = osp.join(args.input, img_name)
        # The context manager closes the file handle after prediction; the
        # original never closed it.
        with Image.open(img_path) as img:
            mask = inference_one(net=net,
                                 image=img,
                                 device=device)

        # Each image gets its own output sub-directory named after it.
        img_name_no_ext = osp.splitext(img_name)[0]
        output_img_dir = osp.join(args.output, img_name_no_ext)
        os.makedirs(output_img_dir, exist_ok=True)

        if cfg.n_classes == 1:
            # Always save as PNG, like the multi-class branch. Reusing the
            # input's extension wrote lossy JPEG masks for .jpg inputs.
            image_idx = Image.fromarray((mask * 255).astype(np.uint8))
            image_idx.save(osp.join(output_img_dir, img_name_no_ext + ".png"))
        else:
            # One PNG per class channel: <name>_<class index>.png
            for idx, class_mask in enumerate(mask):
                img_name_idx = img_name_no_ext + "_" + str(idx) + ".png"
                image_idx = Image.fromarray((class_mask * 255).astype(np.uint8))
                image_idx.save(osp.join(output_img_dir, img_name_idx))