Commit

default rollout hyperparameters in PQN are now equal to PPO. Same for PQN LSTM. Ran validation experiments. Reverting undesired changes in DQN. WIP documentation
roger-creus committed Oct 25, 2024
1 parent cc66288 commit 10334fa
Showing 5 changed files with 319 additions and 99 deletions.
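The commit message states that PQN's rollout defaults now match PPO's. For quick reference, here is a minimal sketch collecting the new defaults from the pqn.py diff below; the RolloutDefaults dataclass and the derived-size properties are illustrative (CleanRL-style scripts typically derive batch sizes this way), only the values come from this commit:

from dataclasses import dataclass

@dataclass
class RolloutDefaults:
    # New defaults from the pqn.py diff (old values in comments), matching PPO.
    num_envs: int = 4          # parallel environments (was 32)
    num_steps: int = 128       # steps per environment per rollout (was 64)
    num_minibatches: int = 4   # mini-batches per update (was 16)
    update_epochs: int = 4     # K epochs over each rollout (was 2)

    @property
    def batch_size(self) -> int:
        # One rollout yields num_envs * num_steps transitions.
        return self.num_envs * self.num_steps

    @property
    def minibatch_size(self) -> int:
        return self.batch_size // self.num_minibatches

The Atari variant (pqn_atari_envpool.py) follows the same pattern with num_envs = 8.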
4 changes: 1 addition & 3 deletions cleanrl/dqn.py
@@ -69,8 +69,6 @@ class Args:
"""timestep to start learning"""
train_frequency: int = 10
"""the frequency of training"""
rew_scale: float = 0.1
"""the reward scaling factor"""


def make_env(env_id, seed, idx, capture_video, run_name):
@@ -195,7 +193,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards * args.rew_scale, terminations, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
117 changes: 45 additions & 72 deletions cleanrl/pqn.py
@@ -2,11 +2,9 @@
import os
import random
import time
from collections import deque
from dataclasses import dataclass

import envpool
import gym
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
@@ -42,13 +40,13 @@ class Args:
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 32
num_envs: int = 4
"""the number of parallel game environments"""
num_steps: int = 64
num_steps: int = 128
"""the number of steps to run for each environment per update"""
num_minibatches: int = 16
num_minibatches: int = 4
"""the number of mini-batches"""
update_epochs: int = 2
update_epochs: int = 4
"""the K epochs to update the policy"""
anneal_lr: bool = True
"""Toggle learning rate annealing"""
@@ -62,44 +60,29 @@ class Args:
"""the fraction of `total_timesteps` it takes from start_e to end_e"""
max_grad_norm: float = 10.0
"""the maximum norm for the gradient clipping"""
rew_scale: float = 0.1
"""the reward scaling factor"""
q_lambda: float = 0.65
"""the lambda for Q(lambda)"""


class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env, deque_size=100):
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None

def reset(self, **kwargs):
observations = super().reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.lives = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations

def step(self, action):
observations, rewards, dones, infos = super().step(action)
self.episode_returns += rewards
self.episode_lengths += 1
self.returned_episode_returns[:] = self.episode_returns
self.returned_episode_lengths[:] = self.episode_lengths
self.episode_returns *= 1 - dones
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
return (
observations,
rewards,
dones,
infos,
)
def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
env.action_space.seed(seed)

return env

return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
return layer


# ALGO LOGIC: initialize agent here:
@@ -108,13 +91,13 @@ def __init__(self, env):
super().__init__()

self.network = nn.Sequential(
nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
layer_init(nn.Linear(np.array(env.single_observation_space.shape).prod(), 120)),
nn.LayerNorm(120),
nn.ReLU(),
nn.Linear(120, 84),
layer_init(nn.Linear(120, 84)),
nn.LayerNorm(84),
nn.ReLU(),
nn.Linear(84, env.single_action_space.n),
layer_init(nn.Linear(84, env.single_action_space.n)),
)

def forward(self, x):
@@ -159,35 +142,27 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = envpool.make(
args.env_id,
env_type="gym",
num_envs=args.num_envs,
seed=args.seed,
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
envs.num_envs = args.num_envs
envs.single_action_space = envs.action_space
envs.single_observation_space = envs.observation_space
envs = RecordEpisodeStatistics(envs)
assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

# agent setup
q_network = QNetwork(envs).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate)

# storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
avg_returns = deque(maxlen=20)

# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()

# TRY NOT TO MODIFY: start the game
next_obs = torch.Tensor(envs.reset()).to(device)
next_obs, _ = envs.reset(seed=args.seed)
next_obs = torch.Tensor(next_obs).to(device)
next_done = torch.zeros(args.num_envs).to(device)

for iteration in range(1, args.num_iterations + 1):
@@ -203,30 +178,28 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
dones[step] = next_done

epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)

random_actions = torch.randint(0, envs.single_action_space.n, (args.num_envs,)).to(device)
with torch.no_grad():
q_values = q_network(next_obs)
max_actions = torch.argmax(q_values, dim=1)
values[step] = q_values[torch.arange(args.num_envs), max_actions].flatten()

explore = (torch.rand((args.num_envs,)).to(device) < epsilon)
explore = torch.rand((args.num_envs,)).to(device) < epsilon
action = torch.where(explore, random_actions, max_actions)
actions[step] = action

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, next_done, info = envs.step(action.cpu().numpy())
rewards[step] = torch.tensor(reward).to(device).view(-1) * args.rew_scale
next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
next_done = np.logical_or(terminations, truncations)
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for idx, d in enumerate(next_done):
if d:
print(f"global_step={global_step}, episodic_return={info['r'][idx]}")
avg_returns.append(info["r"][idx])
writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step)
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)
if "final_info" in infos:
for info in infos["final_info"]:
if info and "episode" in info:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

# Compute Q(lambda) targets
with torch.no_grad():
@@ -242,7 +215,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
returns[t] = rewards[t] + args.gamma * (
args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal
)

# flatten the batch
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
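The Q(lambda) hunk above only shows the inner recursion. A self-contained sketch of the backward pass it belongs to, mirroring the variable names in the diff; the function wrapper and the bootstrap at the last step are assumptions for illustration, not code from this commit:

import torch

def q_lambda_returns(rewards, values, dones, next_value, next_done, gamma=0.99, q_lambda=0.65):
    # Backward Q(lambda) recursion over a rollout of shape (num_steps, num_envs).
    # values[t] holds max_a Q(s_t, a); next_value/next_done describe the state after
    # the last stored step (this bootstrap handling is an assumption).
    num_steps = rewards.shape[0]
    returns = torch.zeros_like(rewards)
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            # Bootstrap from the max-Q value of the observation after the rollout.
            returns[t] = rewards[t] + gamma * next_value * (1.0 - next_done)
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            # Blend the lambda return from t+1 with the one-step bootstrapped target,
            # exactly as in the recursion shown in the hunk above.
            returns[t] = rewards[t] + gamma * (
                q_lambda * returns[t + 1] + (1.0 - q_lambda) * values[t + 1] * nextnonterminal
            )
    return returns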
30 changes: 18 additions & 12 deletions cleanrl/pqn_atari_envpool.py
@@ -42,17 +42,17 @@ class Args:
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 128
num_envs: int = 8
"""the number of parallel game environments"""
num_steps: int = 32
num_steps: int = 128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr: bool = True
"""Toggle learning rate annealing for policy and value networks"""
gamma: float = 0.99
"""the discount factor gamma"""
num_minibatches: int = 32
num_minibatches: int = 4
"""the number of mini-batches"""
update_epochs: int = 2
update_epochs: int = 4
"""the K epochs to update the policy"""
max_grad_norm: float = 10.0
"""the maximum norm for the gradient clipping"""
@@ -107,25 +107,31 @@ def step(self, action):
infos,
)

# ALGO LOGIC: initialize agent here:

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
return layer


class QNetwork(nn.Module):
def __init__(self, env):
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(4, 32, 8, stride=4),
layer_init(nn.Conv2d(4, 32, 8, stride=4)),
nn.LayerNorm([32, 20, 20]),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2),
layer_init(nn.Conv2d(32, 64, 4, stride=2)),
nn.LayerNorm([64, 9, 9]),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1),
layer_init(nn.Conv2d(64, 64, 3, stride=1)),
nn.LayerNorm([64, 7, 7]),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, 512),
layer_init(nn.Linear(3136, 512)),
nn.LayerNorm(512),
nn.ReLU(),
nn.Linear(512, env.single_action_space.n),
layer_init(nn.Linear(512, env.single_action_space.n)),
)

def forward(self, x):
@@ -185,7 +191,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"

q_network = QNetwork(envs).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate)

# ALGO Logic: Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
@@ -221,7 +227,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
max_actions = torch.argmax(q_values, dim=1)
values[step] = q_values[torch.arange(args.num_envs), max_actions].flatten()

explore = (torch.rand((args.num_envs,)).to(device) < epsilon)
explore = torch.rand((args.num_envs,)).to(device) < epsilon
action = torch.where(explore, random_actions, max_actions)
actions[step] = action

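Both Atari variants (and pqn.py) now pass every convolutional and linear layer through the layer_init helper added in this diff, keep the existing LayerNorm + ReLU blocks, and swap Adam for RAdam. A minimal usage sketch; the toy head below is illustrative, only layer_init itself and the learning rate come from the diff:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    # Orthogonal weights with gain `std`, constant bias, as added in this commit.
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

# Illustrative block: initialized linear layer -> LayerNorm -> ReLU, the pattern PQN repeats.
head = nn.Sequential(
    layer_init(nn.Linear(512, 84)),
    nn.LayerNorm(84),
    nn.ReLU(),
    layer_init(nn.Linear(84, 6)),  # e.g. 6 discrete actions
)
optimizer = optim.RAdam(head.parameters(), lr=2.5e-4)  # RAdam replaces Adam in this commit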
cleanrl/pqn_atari_envpool_lstm.py
@@ -15,7 +15,6 @@
import tyro
from torch.utils.tensorboard import SummaryWriter


@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
@@ -52,7 +51,7 @@ class Args:
"""the discount factor gamma"""
num_minibatches: int = 4
"""the number of mini-batches"""
update_epochs: int = 2
update_epochs: int = 4
"""the K epochs to update the policy"""
max_grad_norm: float = 0.5
"""the maximum norm for the gradient clipping"""
@@ -108,29 +107,38 @@ def step(self, action):
)


# ALGO LOGIC: initialize agent here:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
return layer


class QNetwork(nn.Module):
def __init__(self, env):
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(1, 32, 8, stride=4),
layer_init(nn.Conv2d(1, 32, 8, stride=4)),
nn.LayerNorm([32, 20, 20]),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2),
layer_init(nn.Conv2d(32, 64, 4, stride=2)),
nn.LayerNorm([64, 9, 9]),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1),
layer_init(nn.Conv2d(64, 64, 3, stride=1)),
nn.LayerNorm([64, 7, 7]),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, 512),
layer_init(nn.Linear(3136, 512)),
nn.LayerNorm(512),
nn.ReLU(),
)
self.lstm = nn.LSTM(512, 128)
self.head = nn.Linear(128, env.single_action_space.n)
for name, param in self.lstm.named_parameters():
if "bias" in name:
nn.init.constant_(param, 0)
elif "weight" in name:
nn.init.orthogonal_(param, 1.0)
self.q_func = layer_init(nn.Linear(128, env.single_action_space.n))


def get_states(self, x, lstm_state, done):
hidden = self.network(x / 255.0)

@@ -153,7 +161,7 @@ def get_states(self, x, lstm_state, done):

def forward(self, x, lstm_state, done):
hidden, lstm_state = self.get_states(x, lstm_state, done)
return self.head(hidden), lstm_state
return self.q_func(hidden), lstm_state


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
@@ -209,7 +217,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"

q_network = QNetwork(envs).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate)

# ALGO Logic: Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
@@ -219,7 +227,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
avg_returns = deque(maxlen=20)


# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
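The LSTM variant's get_states body is mostly collapsed in the diff above. The usual CleanRL recurrent pattern it corresponds to resets the carried LSTM state wherever done marks an episode boundary; the helper below is a sketch of that unroll, an assumption based on CleanRL's other recurrent scripts rather than code from this commit:

import torch
import torch.nn as nn

def masked_lstm_unroll(lstm: nn.LSTM, hidden: torch.Tensor, lstm_state, done: torch.Tensor):
    # Unroll `lstm` over (seq_len * batch, features) inputs, zeroing the carried
    # hidden/cell state at every step where `done` is 1 (episode boundary).
    batch_size = lstm_state[0].shape[1]
    hidden = hidden.reshape((-1, batch_size, lstm.input_size))
    done = done.reshape((-1, batch_size))
    new_hidden = []
    for h, d in zip(hidden, done):
        h, lstm_state = lstm(
            h.unsqueeze(0),
            (
                (1.0 - d).view(1, -1, 1) * lstm_state[0],  # reset hidden state
                (1.0 - d).view(1, -1, 1) * lstm_state[1],  # reset cell state
            ),
        )
        new_hidden += [h]
    new_hidden = torch.flatten(torch.cat(new_hidden), 0, 1)
    return new_hidden, lstm_state

The layer_init'ed q_func head (renamed from head in the diff) is then applied to the unrolled features to produce Q-values.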
