-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualize.py
95 lines (75 loc) · 2.61 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import matplotlib.pyplot as plt
import time
import random
import tensorflow as tf
from model import Model
from loader import TestLoader
import numpy as np
import math
# Batch size passed to the TestLoader's get_test() in visualize() each time a
# new chunk of test input is requested (presumably an upper bound on rows per
# batch — confirm against TestLoader).
MAX_BATCH_SIZE = 16
def eulerAnglesToRotationMatrix(theta):
    """Return the 3x3 rotation matrix for three Euler angles.

    ``theta[0]``, ``theta[1]`` and ``theta[2]`` are the angles (in the units
    expected by ``math.cos``/``math.sin``, i.e. radians) about the x, y and z
    axes respectively. The result is ``R_z . (R_y . R_x)``, so the x rotation
    is applied first and the z rotation last.
    """
    angle_x, angle_y, angle_z = theta[0], theta[1], theta[2]

    # Evaluate each trig term once instead of inline in every matrix slot.
    cos_x, sin_x = math.cos(angle_x), math.sin(angle_x)
    cos_y, sin_y = math.cos(angle_y), math.sin(angle_y)
    cos_z, sin_z = math.cos(angle_z), math.sin(angle_z)

    about_x = np.array([
        [1, 0, 0],
        [0, cos_x, -sin_x],
        [0, sin_x, cos_x],
    ])
    about_y = np.array([
        [cos_y, 0, sin_y],
        [0, 1, 0],
        [-sin_y, 0, cos_y],
    ])
    about_z = np.array([
        [cos_z, -sin_z, 0],
        [sin_z, cos_z, 0],
        [0, 0, 1],
    ])

    # Same association as the product R_z . (R_y . R_x).
    y_then_x = np.dot(about_y, about_x)
    return np.dot(about_z, y_then_x)
def _pose_to_matrix(pose):
    """Convert a 12-value (3x4) pose into a 4x4 homogeneous matrix."""
    return np.vstack([np.reshape(pose, (3, 4)), [0, 0, 0, 1]])


def _vector_to_transform(v):
    """Build a 4x4 homogeneous transform from a 6-value prediction.

    ``v[:3]`` becomes the translation column and ``v[3:]`` is handed to
    ``eulerAnglesToRotationMatrix`` as the three rotation angles.
    """
    rotation = eulerAnglesToRotationMatrix(v[3:])
    translation = np.expand_dims(v[:3], axis=1)
    return np.vstack([np.hstack([rotation, translation]), [0, 0, 0, 1]])


def visualize(model, sess, pred, x, training, sequence, *,
              xlim=(-300, 300), ylim=(-50, 550)):
    """Animate the predicted trajectory for ``sequence`` over its ground truth.

    The ground-truth path (blue) is drawn once. The predicted path (red) is
    built incrementally by chaining each predicted 6-value motion vector onto
    the previous pose, and is redrawn after every step. Both paths plot the
    x and z translation components of each pose.

    Args:
        model: Object whose ``predict(sess, pred, x, training, batch)``
            returns an iterable of 6-value vectors (translation in ``v[:3]``,
            Euler angles in ``v[3:]``).
        sess, pred, x, training: Forwarded unchanged to ``model.predict``
            (presumably the TF session, output tensor, input placeholder and
            training flag — confirm against ``Model``).
        sequence: Passed to ``TestLoader`` to select the data to replay.
        xlim, ylim: Axis limits for the plot. The defaults are the values
            that used to be hard-coded.
    """
    xdata = []
    ydata = []
    plt.show()
    axes = plt.gca()
    axes.set_xlim(*xlim)
    axes.set_ylim(*ylim)
    line, = axes.plot(xdata, ydata, 'r-')
    line2, = axes.plot(xdata, ydata, 'b-')

    data = TestLoader(sequence)
    truth = data.get_truth()

    # Ground-truth path: translation column of every flattened 3x4 pose.
    true_positions = [np.reshape(tru, (3, 4))[:3, 3] for tru in truth]
    line2.set_xdata([p[0] for p in true_positions])
    line2.set_ydata([p[2] for p in true_positions])

    # Seed the integrated trajectory with the first ground-truth pose as a
    # full 4x4 matrix (identity when no ground truth is available), so every
    # step can be composed with a proper homogeneous pose.
    pose = _pose_to_matrix(truth[0]) if len(truth) > 0 else np.eye(4)

    dat = data.get_test(MAX_BATCH_SIZE)
    while dat is not None:
        print(dat.shape)
        vec = model.predict(sess, pred, x, training, dat)
        for v in vec:
            print(v)
            # Always compose with the most recent pose. The step transform is
            # inverted before it is applied.
            # NOTE(review): the inversion suggests the network predicts the
            # motion from frame k+1 back to frame k — confirm.
            pose = np.matmul(pose, np.linalg.inv(_vector_to_transform(v)))
            xdata.append(pose[0, 3])
            ydata.append(pose[2, 3])
            line.set_xdata(xdata)
            line.set_ydata(ydata)
            plt.draw()
            plt.pause(1e-17)
            time.sleep(0.01)
        dat = data.get_test(MAX_BATCH_SIZE)
    plt.show()