Separate Implementation for Handling Truncation with PPO
Anonymous committed May 29, 2024
1 parent 31a91a2 commit 625aba7
Showing 2 changed files with 388 additions and 35 deletions.
57 changes: 22 additions & 35 deletions cleanrl/ppo_continuous_action.py
@@ -184,21 +184,18 @@ def get_action_and_value(self, x, action=None):

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
-    next_obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    next_dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
-    next_terminations = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
-    next_ob, _ = envs.reset(seed=args.seed)
-    next_ob = torch.Tensor(next_ob).to(device)
+    next_obs, _ = envs.reset(seed=args.seed)
+    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
-    next_termination = torch.zeros(args.num_envs).to(device)

    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so.
@@ -209,53 +206,43 @@ def get_action_and_value(self, x, action=None):

        for step in range(0, args.num_steps):
            global_step += args.num_envs
+            obs[step] = next_obs
+            dones[step] = next_done

-            ob = next_ob
            # ALGO LOGIC: action logic
            with torch.no_grad():
-                action, logprob, _, value = agent.get_action_and_value(ob)
+                action, logprob, _, value = agent.get_action_and_value(next_obs)
+                values[step] = value.flatten()
+            actions[step] = action
+            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
-            next_ob, reward, next_termination, next_truncation, info = envs.step(action.cpu().numpy())
-
-            # Correct next observation (for vec gym)
-            real_next_ob = next_ob.copy()
-            for idx, trunc in enumerate(next_truncation):
-                if trunc:
-                    real_next_ob[idx] = info["final_observation"][idx]
-            next_ob = torch.Tensor(next_ob).to(device)
-
-            # Collect trajectory
-            obs[step] = torch.Tensor(ob).to(device)
-            next_obs[step] = torch.Tensor(real_next_ob).to(device)
-            actions[step] = torch.Tensor(action).to(device)
-            logprobs[step] = torch.Tensor(logprob).to(device)
-            values[step] = torch.Tensor(value.flatten()).to(device)
-            next_terminations[step] = torch.Tensor(next_termination).to(device)
-            next_dones[step] = torch.Tensor(np.logical_or(next_termination, next_truncation)).to(device)
+            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
+            next_done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
+            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

-            if "final_info" in info:
-                for info in info["final_info"]:
+            if "final_info" in infos:
+                for info in infos["final_info"]:
                    if info and "episode" in info:
                        print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                        writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

        # bootstrap value if not done
        with torch.no_grad():
-            next_values = torch.zeros_like(values[0]).to(device)
+            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
-                    next_values = agent.get_value(next_obs[t]).flatten()
+                    nextnonterminal = 1.0 - next_done
+                    nextvalues = next_value
                else:
-                    value_mask = next_dones[t].bool()
-                    next_values[value_mask] = agent.get_value(next_obs[t][value_mask]).flatten()
-                    next_values[~value_mask] = values[t + 1][~value_mask]
-                delta = rewards[t] + args.gamma * next_values * (1 - next_terminations[t]) - values[t]
-                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * (1 - next_dones[t]) * lastgaelam
+                    nextnonterminal = 1.0 - dones[t + 1]
+                    nextvalues = values[t + 1]
+                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
+                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
@@ -363,4 +350,4 @@ def get_action_and_value(self, x, action=None):
            push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
-    writer.close()
+    writer.close()
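
For context on what this diff does: ppo_continuous_action.py returns to the standard bootstrap, where a single done flag (termination OR truncation) both zeroes the bootstrapped value and cuts the GAE trace, and the truncation-aware logic in the removed lines is split out into a separate script. The sketch below restates the removed logic as two standalone helpers, assuming (num_steps, num_envs) tensors, a Gymnasium vector env that reports the pre-reset observation under infos["final_observation"], and next-state values precomputed per step; the helper names are illustrative and are not the ones used in the new file added by this commit.

# Illustrative sketch of the removed truncation handling, not the committed code.
import numpy as np
import torch


def real_next_observation(next_obs, truncations, infos):
    # Gymnasium vector envs auto-reset a finished sub-env and return the first
    # observation of the new episode; the observation that actually ended the
    # old episode is kept in infos["final_observation"]. Bootstrapping a
    # truncated episode should use that real observation.
    real_next = np.array(next_obs, copy=True)
    for idx, trunc in enumerate(truncations):
        if trunc:
            real_next[idx] = infos["final_observation"][idx]
    return real_next


def truncation_aware_gae(rewards, values, next_values, next_terminations, next_dones, gamma, gae_lambda):
    # All inputs are (num_steps, num_envs) tensors; next_values[t] is V(s_{t+1})
    # evaluated on the real next observation (see helper above).
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.shape[0])):
        # (1 - next_terminations[t]) gates the value bootstrap: only a true
        # terminal state has zero future value; a time-limit truncation still
        # bootstraps from V of the real final observation.
        delta = rewards[t] + gamma * next_values[t] * (1.0 - next_terminations[t]) - values[t]
        # (1 - next_dones[t]) cuts the GAE trace at every episode boundary,
        # truncated or terminated, so advantages never mix episodes.
        lastgaelam = delta + gamma * gae_lambda * (1.0 - next_dones[t]) * lastgaelam
        advantages[t] = lastgaelam
    returns = advantages + values
    return advantages, returns

Keeping the two masks separate matters for time-limited tasks: a truncated episode still has future value, so its target bootstraps from the value of the real final observation instead of being forced to zero, while the single done-based mask restored above treats truncation and termination identically.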