
Commit

wip
roger-creus committed Jul 17, 2024
1 parent 6e89a5c commit e9f158f
Showing 3 changed files with 282 additions and 298 deletions.
26 changes: 7 additions & 19 deletions cleanrl/pqn.py
@@ -15,7 +15,6 @@
 from stable_baselines3.common.buffers import ReplayBuffer
 from torch.utils.tensorboard import SummaryWriter
 from collections import deque
-from IPython import embed

 @dataclass
 class Args:
@@ -103,27 +102,15 @@ def step(self, action):

 # ALGO LOGIC: initialize agent here:
 class QNetwork(nn.Module):
-    def __init__(self, env, norm_type="layer_norm", norm_input=False):
+    def __init__(self, env):
         super().__init__()

-        # wether to normalize the input or not
-        if norm_input:
-            self.norm_in = nn.BatchNorm1d(np.array(env.single_observation_space.shape).prod())
-
-        # wether to use layer norm or batch norm for internal layers
-        if norm_type == "layer_norm":
-            self.norm = nn.LayerNorm
-        elif norm_type == "batch_norm":
-            self.norm = nn.BatchNorm1d
-        else:
-            self.norm = lambda x: x
-
         self.network = nn.Sequential(
             nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
-            self.norm(120),
+            nn.LayerNorm(120),
             nn.ReLU(),
             nn.Linear(120, 84),
-            self.norm(84),
+            nn.LayerNorm(84),
             nn.ReLU(),
             nn.Linear(84, env.single_action_space.n),
         )
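
For orientation, a minimal sketch of QNetwork as it stands after this hunk: the configurable norm_type/norm_input options are gone and nn.LayerNorm is applied after each hidden linear layer. The forward method is not part of the hunk and is assumed to be CleanRL's usual pass-through.

```python
import numpy as np
import torch.nn as nn


class QNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        obs_dim = int(np.array(env.single_observation_space.shape).prod())
        self.network = nn.Sequential(
            nn.Linear(obs_dim, 120),
            nn.LayerNorm(120),  # normalization is now hard-coded to LayerNorm
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.LayerNorm(84),
            nn.ReLU(),
            nn.Linear(84, env.single_action_space.n),
        )

    def forward(self, x):  # assumed: not shown in the diff
        return self.network(x)
```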
Expand Down Expand Up @@ -215,8 +202,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
if random.random() < epsilon:
action = torch.randint(0, envs.single_action_space.n, (envs.num_envs,))
else:
q_values = q_network(next_obs)
action = torch.argmax(q_values, dim=1)
with torch.no_grad():
q_values = q_network(next_obs)
action = torch.argmax(q_values, dim=1)
actions[step] = action

# TRY NOT TO MODIFY: execute the game and log data.
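
The change above wraps the greedy action selection in torch.no_grad(), so the forward pass used only for acting does not build an autograd graph. A small self-contained illustration of the effect (toy network and inputs, not the script's variables):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)          # stand-in for the Q-network
x = torch.randn(8, 4)          # stand-in for a batch of observations

with torch.no_grad():
    q_values = net(x)          # no computation graph is recorded here
action = torch.argmax(q_values, dim=1)

assert not q_values.requires_grad  # inference output carries no grad history
print(action.shape)                # torch.Size([8])
```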
@@ -233,7 +221,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
                     writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
                     writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)

-        # bootstrap value if not done
+        # Compute Q(lambda) targets
         with torch.no_grad():
             returns = torch.zeros_like(rewards).to(device)
             for t in reversed(range(args.num_steps)):
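
The body of this loop is truncated in the view above. For context, a minimal sketch of how Peng's Q(lambda) targets are commonly computed over a stored rollout; the variable names (obs, rewards, dones, next_obs, next_done, args.gamma, args.q_lambda) are assumptions about the surrounding script, not necessarily what this file uses:

```python
# Sketch: backward recursion for Q(lambda) targets over a rollout of length num_steps.
with torch.no_grad():
    returns = torch.zeros_like(rewards).to(device)
    for t in reversed(range(args.num_steps)):
        if t == args.num_steps - 1:
            # bootstrap from the state that follows the last stored transition
            next_value = q_network(next_obs).max(dim=1).values
            nonterminal = 1.0 - next_done
            returns[t] = rewards[t] + args.gamma * nonterminal * next_value
        else:
            next_value = q_network(obs[t + 1]).max(dim=1).values
            nonterminal = 1.0 - dones[t + 1]
            # blend the one-step bootstrap with the lambda-return from t+1
            returns[t] = rewards[t] + args.gamma * nonterminal * (
                (1.0 - args.q_lambda) * next_value + args.q_lambda * returns[t + 1]
            )
```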
279 changes: 0 additions & 279 deletions cleanrl/pqn_atari.py

This file was deleted.


0 comments on commit e9f158f
