forked from RLado/STB-VMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_test.py
103 lines (83 loc) · 3.68 KB
/
run_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from PIL import Image
from models.modeldoubleheight import STBVMM
from torchvision.datasets.folder import default_loader
from utils. import ImageFromFolderDouble
def main(args):
    """Run a trained STBVMM model over a folder of images and save its outputs.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``device``, ``load_ckpt``, ``save_dir``, ``image_path``,
            ``num_data``, ``batch_size`` and ``workers``. ``start_epoch`` is
            written back onto ``args`` from the loaded checkpoint.

    Raises:
        FileNotFoundError: If ``args.load_ckpt`` is not an existing file.
    """
    # Device choice (auto)
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f'Using device: {device}')

    # Fail before the model is built if there is nothing to load. A real
    # exception is raised instead of ``assert`` so the check is not stripped
    # when Python runs with ``-O``.
    if not os.path.isfile(args.load_ckpt):
        raise FileNotFoundError(
            "=> no checkpoint found at '{}'".format(args.load_ckpt))

    # Create model
    model = STBVMM(img_size=384, patch_size=1, in_chans=3,
                   embed_dim=192, depths=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],
                   window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                   drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                   norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                   use_checkpoint=False, img_range=1., resi_connection='1conv',
                   manipulator_num_resblk=1).to(device)

    # Load checkpoint
    print("=> loading checkpoint '{}'".format(args.load_ckpt))
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint written on a GPU can still be restored on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args.load_ckpt, map_location=device)
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(args.load_ckpt, checkpoint['epoch']))

    # Check saving directory
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)
    print(save_dir)

    # Data loader
    # The original referenced ``ImageFromFolder``, which this file never
    # imports; ``ImageFromFolderDouble`` is the only dataset class imported.
    # NOTE(review): its constructor signature is not visible here - confirm it
    # accepts ``num_data`` and ``preprocessing``.
    dataset_mag = ImageFromFolderDouble(
        args.image_path, num_data=args.num_data, preprocessing=False)
    data_loader = data.DataLoader(dataset_mag,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.workers,
                                  pin_memory=False)

    # Generate frames
    model.eval()
    with torch.no_grad():
        # Only the first two items of each batch are fed to the model, so the
        # remaining three are not copied to the device.
        # NOTE(review): ``--mag`` / ``--mode`` and the loader's ``mag_factor``
        # are never passed to the model - confirm that is intended.
        for i, (xa, xb, _xc, _xd, _mag_factor) in enumerate(data_loader):
            xa = xa.to(device)
            xb = xb.to(device)

            # forward
            out_a, out_b = model(xa, xb)

            # Scale to 8-bit; clipping first keeps out-of-range values from
            # wrapping around when cast to uint8.
            frames_a = np.clip(out_a.cpu().numpy() * 255.0, 0.0, 255.0).astype(np.uint8)
            frames_b = np.clip(out_b.cpu().numpy() * 255.0, 0.0, 255.0).astype(np.uint8)

            # save output images
            # NOTE(review): Image.fromarray expects (H, W) or (H, W, C) data;
            # if the model returns channel-first arrays they need a transpose
            # before saving - confirm against the model's output layout.
            for j, (frame_a, frame_b) in enumerate(zip(frames_a, frames_b)):
                index = i * args.batch_size + j
                Image.fromarray(frame_a).save(
                    os.path.join(save_dir, f'output_a_{index}.png'))
                Image.fromarray(frame_b).save(
                    os.path.join(save_dir, f'output_b_{index}.png'))

    print("Finished processing all images.")


if __name__ == '__main__':
    # Command-line options as (flags, keyword arguments) pairs, registered in
    # this order so the generated ``--help`` listing stays the same.
    cli_options = (
        (('--load_ckpt',), dict(default='ckpt/ckpt_e09.pth.tar', type=str)),
        (('--save_dir',), dict(default='output/', type=str)),
        (('--image_path',), dict(default='input/', type=str)),
        (('--num_data',), dict(default=100, type=int)),
        (('--batch_size',), dict(default=1, type=int)),
        (('--workers',), dict(default=4, type=int)),
        (('-m', '--mag'), dict(metavar='N', default=20.0, type=float,
                               help='magnification factor (default: 20.0)')),
        (('--mode',), dict(default='static', type=str,
                           choices=['static', 'dynamic'],
                           help='magnification mode (static, dynamic)')),
        (('--device',), dict(default='auto', type=str)),
    )

    cli = argparse.ArgumentParser()
    for flags, options in cli_options:
        cli.add_argument(*flags, **options)

    # Parse sys.argv and hand the resulting namespace to the entry point.
    main(cli.parse_args())