Commit

advantage normalization is off by default, added hidden layer after TrXL
MarcoMeter committed Jun 25, 2024
1 parent 8ff1edf commit b9f8131
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion cleanrl/ppo_trxl/ppo_trxl.py
@@ -62,7 +62,7 @@ class Args:
     """the number of mini-batches"""
     update_epochs: int = 3
     """the K epochs to update the policy"""
-    norm_adv: bool = True
+    norm_adv: bool = False
     """Toggles advantages normalization"""
     clip_coef: float = 0.1
     """the surrogate clipping coefficient"""
@@ -273,6 +273,11 @@ def __init__(self, args, observation_space, action_space_shape, max_episode_step
             args.trxl_num_blocks, args.trxl_dim, args.trxl_num_heads, self.max_episode_steps, args.trxl_positional_encoding
         )
 
+        self.hidden_post_trxl = nn.Sequential(
+            layer_init(nn.Linear(args.trxl_dim, args.trxl_dim)),
+            nn.ReLU(),
+        )
+
         self.actor_branches = nn.ModuleList(
             [
                 layer_init(nn.Linear(args.trxl_dim, out_features=num_actions), np.sqrt(0.01))
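The new `hidden_post_trxl` block is a single linear layer plus ReLU that keeps the TrXL embedding width. A self-contained sketch of the same pattern; the `layer_init` helper and the dimension value below are stand-ins for the ones defined elsewhere in ppo_trxl.py:

```python
import numpy as np
import torch
import torch.nn as nn


def layer_init(layer: nn.Linear, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Linear:
    # Orthogonal weight init with a constant bias, CleanRL-style (stand-in for the script's helper).
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


trxl_dim = 384  # illustrative width; the script takes it from Args.trxl_dim
hidden_post_trxl = nn.Sequential(
    layer_init(nn.Linear(trxl_dim, trxl_dim)),
    nn.ReLU(),
)
```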
@@ -300,6 +305,7 @@ def get_value(self, x, memory, memory_mask, memory_indices):
         else:
             x = self.encoder(x)
         x, _ = self.transformer(x, memory, memory_mask, memory_indices)
+        x = self.hidden_post_trxl(x)
         return self.critic(x).flatten()
 
     def get_action_and_value(self, x, memory, memory_mask, memory_indices, action=None):
@@ -308,6 +314,7 @@ def get_action_and_value(self, x, memory, memory_mask, memory_indices, action=No
         else:
             x = self.encoder(x)
         x, memory = self.transformer(x, memory, memory_mask, memory_indices)
+        x = self.hidden_post_trxl(x)
         self.x = x
         probs = [Categorical(logits=branch(x)) for branch in self.actor_branches]
         if action is None:
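Both forward paths now route the transformer output through the extra hidden layer before the value and policy heads. A schematic of the revised ordering with hypothetical stand-in modules and dimensions (the real TrXL call and memory handling are omitted):

```python
import torch
import torch.nn as nn

# Stand-in modules with illustrative dimensions; in ppo_trxl.py these are built in Agent.__init__.
trxl_dim, obs_dim, num_actions = 384, 16, 4
encoder = nn.Linear(obs_dim, trxl_dim)
hidden_post_trxl = nn.Sequential(nn.Linear(trxl_dim, trxl_dim), nn.ReLU())
critic = nn.Linear(trxl_dim, 1)
actor_branch = nn.Linear(trxl_dim, num_actions)

obs = torch.randn(8, obs_dim)    # batch of flattened observations
h = encoder(obs)                 # observation embedding
# ... the TrXL stack (self.transformer) would transform h here ...
h = hidden_post_trxl(h)          # the step added by this commit
value = critic(h).flatten()      # value head, as in get_value
logits = actor_branch(h)         # one actor branch, as in get_action_and_value
```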
