-
Notifications
You must be signed in to change notification settings - Fork 1
/
helper.py
99 lines (82 loc) · 3.05 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import numpy as np
def as_one_hot(make_one_hot, num):
    """Return a length-``num`` integer vector with 1 at the given index/indices.

    Args:
        make_one_hot: Index to set to 1; any NumPy index works, e.g. a list
            of indices sets several positions.
        num: Length of the returned vector.

    Returns:
        ``np.ndarray`` of shape ``(num,)`` with NumPy's default integer dtype,
        zeros everywhere except the indexed position(s).

    Raises:
        IndexError: If an index is out of range for a vector of length ``num``.
    """
    # ``np.int`` was only a deprecated alias of the builtin ``int`` and was
    # removed in NumPy 1.24; ``int`` selects exactly the same dtype.
    out = np.zeros(num, dtype=int)
    out[make_one_hot] = 1
    return out
def one_hot_to_id(one_hot):
    """Return the position of the first entry equal to 1 in ``one_hot``.

    Scans left to right and stops at the first match; returns -1 when no
    entry equals 1 (including for an empty sequence).
    """
    matches = (position for position, value in enumerate(one_hot) if value == 1)
    return next(matches, -1)
def get_rotated_quadrant(quad, state, size):
    """Return ``state[0]`` rotated into board quadrant ``quad``.

    Quadrant 1 returns an unmodified copy of ``state[0]``. Quadrants 2, 3 and
    4 rotate the top-left ``size`` x ``size`` cells by 1, 2 and 3 quarter
    turns (``np.rot90`` over the first two axes). They also shift the four
    wall flags held in the first four entries of each cell's last axis by the
    same amount: wall ``i`` becomes wall ``(i + turns) % 4``. Only wall
    entries equal to 1 survive the shift. Any further channels are copied
    unchanged. Cells outside the ``size`` x ``size`` region are zero.

    Args:
        quad: Target quadrant, 1-4.
        state: Array-like whose first element is a grid indexed
            ``[x][y][channel]``, or a filename passed to ``load_state``.
        size: Edge length of the square region of ``state[0]`` to rotate.

    Returns:
        A new array with the same shape and dtype as ``state[0]``.

    Raises:
        ValueError: If ``quad`` is not 1, 2, 3 or 4.
    """
    # Previously an out-of-range quad fell through to ``return out`` with
    # ``out`` never assigned; fail loudly instead.
    if not 1 <= quad <= 4:
        raise ValueError(f"quad must be 1, 2, 3 or 4, got {quad!r}")
    if isinstance(state, str):
        state = load_state(state)
    grid = np.asarray(state[0])
    if quad == 1:
        return np.array(grid)

    turns = quad - 1
    n = int(size)
    out = np.zeros_like(grid)
    # One turn gives out[x][y] = grid[y][n-1-x], two give grid[n-1-x][n-1-y],
    # three give grid[n-1-y][x]: exactly np.rot90(..., k=turns).
    out[:n, :n] = np.rot90(grid[:n, :n], k=turns)
    # np.roll moves entry i to (i + turns) % 4, rotating the wall flags with
    # the cells; comparing to 1 keeps only set walls (others become 0).
    walls_set = out[:n, :n, :4] == 1
    out[:n, :n, :4] = np.roll(walls_set, turns, axis=-1)
    return out
def combine_quadrants(file_name):
    """Stack the two saved halves of ``file_name`` into a single saved array.

    Loads ``file_name + "_0.npy"`` and ``file_name + "_1.npy"`` via
    ``load_state``, combines them with ``np.array`` (first half at index 0),
    and writes the result with ``np.save(file_name, ...)``.
    """
    halves = [load_state(file_name + suffix) for suffix in ("_0.npy", "_1.npy")]
    stacked = np.array(halves)
    np.save(file_name, stacked)
def load_state(filename):
    """Read ``filename`` with ``np.load`` and return what it yields."""
    loaded = np.load(filename)
    return loaded
def to_wrkdir(path="/home/nic/Dokumente/ricochet"):
    """Change the process working directory to ``path``, best effort.

    Prints the working directory before and after the change. If a
    ``FileNotFoundError`` occurs (e.g. ``path`` does not exist), prints
    "error" and returns without raising, leaving the working directory as
    it was.

    Args:
        path: Target directory. The default is the original author's
            checkout location; pass another directory on other machines.
    """
    import os

    try:
        print(os.getcwd())
        os.chdir(path)
        print(os.getcwd())
    except FileNotFoundError:
        # Deliberately best effort: a missing directory is reported, not raised.
        print("error")
"""for i in range(8):
for j in range(0):
states = load_state("quadrants/pre_"+str(i)+"_"+str(j)+".npy")
lis = [states[0], get_rotated_quadrant(0, states, 8), get_rotated_quadrant(3, states, 8), get_rotated_quadrant(4, states, 8)]
np.save("quadrants/pre_"+str(i)+"_"+str(j)+"", np.array(lis))"""
"""for i in range(8):
combine_quadrants("quadrants/pre_"+str(i))"""
"""lis = []
for i in range(8):
lis.append(load_state("quadrants/pre_" + str(i) + ".npy"))
np.save("quadrants/pre_all", np.array(lis))"""