-
Notifications
You must be signed in to change notification settings - Fork 20
/
splitInference.py
148 lines (125 loc) · 4.59 KB
/
splitInference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from config import nDep as gen_ls
import torch
import torch.nn as nn
import sys
import gc
# Resolution ratio between the image tensor and the noise tensor.
nUp_n = 2 ** gen_ls
# Granularity (in pixels) of split points and overlap buffers; must be >= nUp_n.
nUp = 2 ** (gen_ls + 1)
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def splitH(Im_tri, noise, te, f, sH):
    """Run ``f`` strip-by-strip along dim 3 of ``Im_tri`` and stitch the results.

    Dim 3 of ``Im_tri`` is cut into ``sH`` strips whose inner boundaries are
    multiples of ``nUp``; the last strip always ends at ``Im_tri.shape[3]``.
    Each strip is widened by up to ``4 * nUp`` columns of context on each side,
    the matching slices of ``noise`` (offsets divided by ``nUp_n``) and ``te``
    are taken, and ``f`` is called on the widened chunk.  Only the un-padded
    centre of each result is copied back, so neighbouring strips overlap.

    Args:
        Im_tri: image tensor split along dim 3; its channels ``:3`` (and ``:1``)
            size the output buffers.
        noise: tensor sliced along dim 3 with offsets ``// nUp_n``.  Assumed to
            be ``Im_tri``'s width divided by ``nUp_n`` -- TODO confirm with caller.
        te: optional template tensor sliced along dim 4 in step with ``Im_tri``'s
            dim 3; ``te.shape[1]`` is the template count.  May be ``None``.
        f: callable ``f(image_chunk, noise_chunk, template_chunk, True)``
            returning ``(gen, alpha, blenda, mix)``.
        sH: requested number of strips; values below 1 are treated as 1.

    Returns:
        ``(orig, alpha, blenda, origM, err)``.  ``alpha``, ``blenda`` and
        ``origM`` are ``None`` when ``te`` is ``None``; ``err`` is always 0.
    """
    S = Im_tri.shape[3]
    # Output buffers share Im_tri's device and dtype.
    orig = 0 * Im_tri[:, :3]
    if te is not None:
        alpha = 0 * Im_tri[:, :3]
        origM = orig * 0
        # te.shape[1] is the template count N
        blenda = 0 * Im_tri[:, :1].repeat(1, te.shape[1], 1, 1)
    else:
        alpha = None
        origM = None
        blenda = None
    err = 0

    # Always use at least one strip: with sH < 1 no increments would be built
    # and the all-zero buffers would be returned silently.
    s = max(1, sH)
    increment = []
    oldS = 0
    for i in range(1, s + 1):
        if i == s:
            # The last strip always ends at S so trailing columns are covered
            # even when S is not a multiple of nUp.
            S1 = S
        else:
            S1 = int(i / float(s) * S)
            S1 = S1 - S1 % nUp  # inner split points are multiples of nUp
        increment += [(oldS, S1)]
        oldS = S1
    print("H increment", increment, "image", Im_tri.shape, "noise", noise.shape)

    def _proc(incr):
        """Run ``f`` on one padded strip and copy its centre into the buffers."""
        start, end = incr
        if start == end:
            # Empty strip (more strips than split points): nothing to write.
            return 0
        # Clamp the padded lower bound at 0: a negative start would make the
        # slice wrap around to the end of the tensor.
        li = max(0, start - 4 * nUp)
        lz = li // nUp_n
        if end == S:
            ui = S
            uz = S // nUp_n
        else:
            ui = end + 4 * nUp  # slicing past the end is clamped
            uz = ui // nUp_n
        Im_tri1 = Im_tri[:, :, :, li:ui].to(device)
        noise1 = noise[:, :, :, lz:uz]
        if te is not None:
            # Only this chunk of the template batch is cast and moved to
            # ``device``, not the full template tensor.
            te1 = te[:, :, :, :, li:ui].float().to(device)
        else:
            te1 = None
        gen1, alpha1, blenda1, mix1 = f(Im_tri1, noise1, te1, True)
        c0, c1 = start - li, end - li  # strip centre in chunk coordinates
        orig[:, :, :, start:end] = gen1[:, :, :, c0:c1]
        if te is not None:
            alpha[:, :, :, start:end] = alpha1[:, :, :, c0:c1]
            origM[:, :, :, start:end] = mix1[:, :, :, c0:c1]
            blenda[:, :, :, start:end] = blenda1[:, :, :, c0:c1]
        gc.collect()  # original author noted this does not seem to help
        torch.cuda.empty_cache()
        return 0

    for incr in increment:
        err += _proc(incr)
    sys.stdout.flush()
    return orig, alpha, blenda, origM, err
def splitW(Im_tri, noise, te, f):
    """Run ``f`` tile-wise over ``Im_tri``: bands along dim 2, each split by ``splitH``.

    The number of bands along dim 2 (``sW``) and strips along dim 3 (``sH``)
    follows a rough heuristic of one per 480 pixels, never fewer than one.
    Too few tiles and inference may run out of memory.  Each band is widened
    by up to ``4 * nUp`` rows of context on each side.  ``splitH`` is run on
    the widened band, and only the band's centre is copied back.

    Args:
        Im_tri: image tensor split along dims 2 and 3.
        noise: tensor sliced along dim 2 with offsets ``// nUp_n`` (assumed to be
            ``Im_tri``'s height divided by ``nUp_n`` -- TODO confirm with caller).
        te: optional template tensor sliced along dim 3 in step with ``Im_tri``'s
            dim 2; ``te.shape[1]`` is the template count.  May be ``None``.
        f: callable forwarded to ``splitH``.

    Returns:
        ``(orig, alpha, blenda, origM)``; the last three are ``None`` when
        ``te`` is ``None``.  Unlike ``splitH``, the accumulated error is not
        returned (kept as a 4-tuple for existing callers).
    """
    # Rough heuristic: one tile per 480 px.  Clamped to >= 1, because 0 would
    # produce no tiles at all and an all-zero result for small images.
    sW = max(1, Im_tri.shape[2] // 480)
    sH = max(1, Im_tri.shape[3] // 480)
    print("generated split ratios", sW, sH)
    S = Im_tri.shape[2]

    # The 4 output buffers share Im_tri's device and dtype.
    orig = 0 * Im_tri[:, :3]
    if te is not None:
        alpha = 0 * Im_tri[:, :3]
        origM = orig * 0
        # te.shape[1] is the template count
        blenda = 0 * Im_tri[:, :1].repeat(1, te.shape[1], 1, 1)
    else:
        alpha = None
        origM = None
        blenda = None
    err = 0

    increment = []
    s = sW
    oldS = 0
    for i in range(1, s + 1):
        if i == s:
            # The last band always ends at S so trailing rows are covered even
            # when S is not a multiple of nUp.
            S1 = S
        else:
            S1 = int(i / float(s) * S)
            S1 = S1 - S1 % nUp  # inner split points are multiples of nUp
        increment += [(oldS, S1)]
        oldS = S1
    print("W increment", increment, "image", Im_tri.shape, "noise", noise.shape)

    def _proc(incr):
        """Process one padded band via ``splitH`` and copy its centre back."""
        start, end = incr
        if start == end:
            # Empty band (more bands than split points): nothing to write.
            return 0
        # Clamp the padded lower bound at 0: a negative start would make the
        # slice wrap around to the end of the tensor.
        li = max(0, start - 4 * nUp)
        lz = li // nUp_n
        if end == S:
            ui = S
            uz = S // nUp_n
        else:
            ui = end + 4 * nUp  # slicing past the end is clamped
            uz = ui // nUp_n
        Im_tri1 = Im_tri[:, :, li:ui]
        noise1 = noise[:, :, lz:uz]
        te1 = te[:, :, :, li:ui] if te is not None else None
        gen1, alpha1, blenda1, mix1, error_full_1 = splitH(Im_tri1, noise1, te1, f, sH)
        c0, c1 = start - li, end - li  # band centre in chunk coordinates
        orig[:, :, start:end] = gen1[:, :, c0:c1]
        if te is not None:
            alpha[:, :, start:end] = alpha1[:, :, c0:c1]
            blenda[:, :, start:end] = blenda1[:, :, c0:c1]
            origM[:, :, start:end] = mix1[:, :, c0:c1]
        return error_full_1

    for incr in increment:
        err += _proc(incr)
    return orig, alpha, blenda, origM