update next_obs, rewards, terminations, truncations, infos
sdpkjc committed Oct 15, 2023
1 parent ca3d0ce · commit d1d28c7
Showing 12 changed files with 66 additions and 56 deletions.
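The diff below applies one pattern across all of the listed scripts: the 5-tuple returned by Gymnasium vector envs is unpacked as `terminations`/`truncations` instead of the singular `terminated`/`truncated`, and those names are then used consistently when recovering `final_observation` and when writing to the replay buffer (off-policy scripts) or combining the flags into `done` (ppo/rpo). Below is a minimal sketch of that pattern, not code from the repository; it assumes the pre-1.0 Gymnasium autoreset API that CleanRL targets here, uses `CartPole-v1` with a short `max_episode_steps` and random actions purely for illustration, and the `rb` replay buffer named in the comments is a stand-in for the stable-baselines3 buffer the real scripts use.

```python
import gymnasium as gym
import numpy as np

# Four toy sub-environments with a short time limit so truncations actually occur.
envs = gym.vector.SyncVectorEnv(
    [lambda: gym.make("CartPole-v1", max_episode_steps=20) for _ in range(4)]
)
obs, _ = envs.reset(seed=1)

for global_step in range(100):
    actions = envs.action_space.sample()  # stand-in for the epsilon-greedy / learned policy

    # Gymnasium vector API: the batched done flags are named terminations / truncations.
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)

    # Handle `final_observation`: autoreset has already replaced next_obs for finished
    # sub-envs, so recover the true last observation for truncated ones before storing.
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncations):
        if trunc:
            real_next_obs[idx] = infos["final_observation"][idx]

    # Off-policy scripts (dqn*, c51*, qdagger*) would now call something like:
    #   rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
    # On-policy scripts (ppo/rpo_continuous_action) instead combine the flags:
    done = np.logical_or(terminations, truncations)

    obs = next_obs
```

Only truncated sub-envs need the recovery step in the off-policy scripts: for terminated transitions the bootstrap term is already multiplied by `(1 - termination)`, so the reset observation stored in their slot never enters the TD target.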
11 changes: 6 additions & 5 deletions cleanrl/c51.py
@@ -196,7 +196,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = actions.cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -208,13 +208,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -257,7 +258,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
loss.backward()
optimizer.step()

# update the target network
# update target network
if global_step % args.target_network_frequency == 0:
target_network.load_state_dict(q_network.state_dict())

14 changes: 8 additions & 6 deletions cleanrl/c51_atari.py
@@ -95,6 +95,7 @@ def thunk():
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)

env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
@@ -104,8 +105,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.action_space.seed(seed)

env.action_space.seed(seed)
return env

return thunk
@@ -218,7 +219,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = actions.cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -230,13 +231,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -279,7 +281,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
loss.backward()
optimizer.step()

# update the target network
# update target network
if global_step % args.target_network_frequency == 0:
target_network.load_state_dict(q_network.state_dict())

14 changes: 8 additions & 6 deletions cleanrl/c51_atari_jax.py
@@ -98,6 +98,7 @@ def thunk():
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)

env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
@@ -107,8 +108,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.action_space.seed(seed)

env.action_space.seed(seed)
return env

return thunk
@@ -278,7 +279,7 @@ def get_action(q_state, obs):
actions = jax.device_get(actions)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -290,13 +291,14 @@ def get_action(q_state, obs):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -319,7 +321,7 @@ def get_action(q_state, obs):
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

# update the target network
# update target network
if global_step % args.target_network_frequency == 0:
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

11 changes: 6 additions & 5 deletions cleanrl/c51_jax.py
@@ -242,7 +242,7 @@ def loss(q_params, observations, actions, target_pmfs):
actions = jax.device_get(actions)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -254,13 +254,14 @@ def loss(q_params, observations, actions, target_pmfs):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -283,7 +284,7 @@ def loss(q_params, observations, actions, target_pmfs):
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

# update the target network
# update target network
if global_step % args.target_network_frequency == 0:
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

9 changes: 5 additions & 4 deletions cleanrl/dqn.py
@@ -183,7 +183,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -195,13 +195,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
12 changes: 7 additions & 5 deletions cleanrl/dqn_atari.py
@@ -92,6 +92,7 @@ def thunk():
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)

env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
@@ -101,8 +102,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.action_space.seed(seed)

env.action_space.seed(seed)
return env

return thunk
@@ -205,7 +206,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -217,13 +218,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
12 changes: 7 additions & 5 deletions cleanrl/dqn_atari_jax.py
@@ -94,6 +94,7 @@ def thunk():
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)

env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
@@ -103,8 +104,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.action_space.seed(seed)

env.action_space.seed(seed)
return env

return thunk
@@ -236,7 +237,7 @@ def mse_loss(params):
actions = jax.device_get(actions)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -248,13 +249,14 @@ def mse_loss(params):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
11 changes: 5 additions & 6 deletions cleanrl/dqn_jax.py
@@ -156,9 +156,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

obs, _ = envs.reset(seed=args.seed)

q_network = QNetwork(action_dim=envs.single_action_space.n)

q_state = TrainState.create(
apply_fn=q_network.apply,
params=q_network.init(q_key, obs),
@@ -208,7 +206,7 @@ def mse_loss(params):
actions = jax.device_get(actions)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -220,13 +218,14 @@ def mse_loss(params):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
4 changes: 2 additions & 2 deletions cleanrl/ppo_continuous_action.py
@@ -207,8 +207,8 @@ def get_action_and_value(self, x, action=None):
logprobs[step] = logprob

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
done = np.logical_or(terminated, truncated)
next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
done = np.logical_or(terminations, truncations)
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

16 changes: 8 additions & 8 deletions cleanrl/qdagger_dqn_atari_impalacnn.py
@@ -290,12 +290,12 @@ def kl_divergence_with_logits(target_logits, prediction_logits):
else:
q_values = teacher_model(torch.Tensor(obs).to(device))
actions = torch.argmax(q_values, dim=1).cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
teacher_rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
teacher_rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
obs = next_obs

# offline training phase: train the student model using the qdagger loss
@@ -377,7 +377,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -394,10 +394,10 @@ def kl_divergence_with_logits(target_logits, prediction_logits):

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
4 changes: 2 additions & 2 deletions cleanrl/rpo_continuous_action.py
@@ -216,8 +216,8 @@ def get_action_and_value(self, x, action=None):
logprobs[step] = logprob

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy())
done = np.logical_or(terminated, truncated)
next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
done = np.logical_or(terminations, truncations)
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
