main.py

import gym
from Brain import SACAgent
from Common import Play, Logger, get_params
import numpy as np
from tqdm import tqdm
import mujoco_py  # noqa: F401 -- loads the MuJoCo bindings required by MuJoCo-based Gym envs
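
# Entry point: trains or evaluates a skill-conditioned SAC agent. Each episode is
# driven by a single latent skill z, and the one-hot code of z is appended to every
# observation so the policy and critics are conditioned on the active skill.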


def concat_state_latent(s, z_, n):
    """Return the state s with a one-hot encoding of skill index z_ appended."""
    z_one_hot = np.zeros(n)
    z_one_hot[z_] = 1
    return np.concatenate([s, z_one_hot])
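
# Example: with n = 3 skills and z_ = 1, a state [0.5, -0.2] becomes
# [0.5, -0.2, 0.0, 1.0, 0.0].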


if __name__ == "__main__":
    params = get_params()

    # Probe the environment once to read the state/action sizes and action bounds,
    # then fold them into the hyperparameter dict and discard the probe env.
    test_env = gym.make(params["env_name"])
    n_states = test_env.observation_space.shape[0]
    n_actions = test_env.action_space.shape[0]
    action_bounds = [test_env.action_space.low[0], test_env.action_space.high[0]]

    params.update({"n_states": n_states,
                   "n_actions": n_actions,
                   "action_bounds": action_bounds})
    print("params:", params)
    test_env.close()
    del test_env, n_states, n_actions, action_bounds
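
    # Re-create the environment for the actual run; p_z is a uniform categorical
    # prior over the n_skills latent skills.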
    env = gym.make(params["env_name"])
    p_z = np.full(params["n_skills"], 1 / params["n_skills"])
    agent = SACAgent(p_z=p_z, **params)
    logger = Logger(agent, **params)
if params["do_train"]:
if not params["train_from_scratch"]:
episode, last_logq_zs, np_rng_state, *env_rng_states, torch_rng_state, random_rng_state = logger.load_weights()
agent.hard_update_target_network()
min_episode = episode
np.random.set_state(np_rng_state)
env.np_random.set_state(env_rng_states[0])
env.observation_space.np_random.set_state(env_rng_states[1])
env.action_space.np_random.set_state(env_rng_states[2])
agent.set_rng_states(torch_rng_state, random_rng_state)
print("Keep training from previous run.")
else:
min_episode = 0
last_logq_zs = 0
np.random.seed(params["seed"])
env.seed(params["seed"])
env.observation_space.seed(params["seed"])
env.action_space.seed(params["seed"])
print("Training from scratch.")
        logger.on()
        for episode in tqdm(range(1 + min_episode, params["max_n_episodes"] + 1)):
            # Sample one skill per episode from the uniform prior and condition on it.
            z = np.random.choice(params["n_skills"], p=p_z)
            state = env.reset()
            state = concat_state_latent(state, z, params["n_skills"])
            episode_reward = 0
            logq_zses = []

            # Cap the episode at both the configured limit and the env's own time limit.
            max_n_steps = min(params["max_episode_len"], env.spec.max_episode_steps)
            for step in range(1, 1 + max_n_steps):
                action = agent.choose_action(state)
                next_state, reward, done, _ = env.step(action)
                next_state = concat_state_latent(next_state, z, params["n_skills"])
                agent.store(state, z, done, action, next_state)
                logq_zs = agent.train()
                # agent.train() may return None (e.g. while the replay buffer is
                # still filling); fall back to the last known value so the episode
                # average below stays well-defined.
                if logq_zs is None:
                    logq_zses.append(last_logq_zs)
                else:
                    logq_zses.append(logq_zs)
                episode_reward += reward
                state = next_state
                if done:
                    break
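
            # `step` keeps its final value after the loop, so it is logged below
            # as the episode length.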
            logger.log(episode,
                       episode_reward,
                       z,
                       sum(logq_zses) / len(logq_zses),
                       step,
                       np.random.get_state(),
                       env.np_random.get_state(),
                       env.observation_space.np_random.get_state(),
                       env.action_space.np_random.get_state(),
                       *agent.get_rng_states(),
                       )
    else:
        # Evaluation mode: load the trained weights and roll out each skill.
        logger.load_weights()
        player = Play(env, agent, n_skills=params["n_skills"])
        player.evaluate()
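
# Run this file directly to train or evaluate. The parameters used above
# (env_name, n_skills, seed, do_train, train_from_scratch, max_n_episodes,
# max_episode_len) all come from Common.get_params, which presumably parses
# them from the command line or a config.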