From b9f81311489307c630217ec91bc9d4d7603ec7a0 Mon Sep 17 00:00:00 2001
From: Marco Pleines
Date: Tue, 25 Jun 2024 13:10:03 +0200
Subject: [PATCH] advantage normalization is off per default, added hidden layer after TrXL

---
 cleanrl/ppo_trxl/ppo_trxl.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/cleanrl/ppo_trxl/ppo_trxl.py b/cleanrl/ppo_trxl/ppo_trxl.py
index 3a556d74..a1ea9cec 100644
--- a/cleanrl/ppo_trxl/ppo_trxl.py
+++ b/cleanrl/ppo_trxl/ppo_trxl.py
@@ -62,7 +62,7 @@ class Args:
     """the number of mini-batches"""
     update_epochs: int = 3
     """the K epochs to update the policy"""
-    norm_adv: bool = True
+    norm_adv: bool = False
     """Toggles advantages normalization"""
     clip_coef: float = 0.1
     """the surrogate clipping coefficient"""
@@ -273,6 +273,11 @@ def __init__(self, args, observation_space, action_space_shape, max_episode_step
             args.trxl_num_blocks, args.trxl_dim, args.trxl_num_heads, self.max_episode_steps, args.trxl_positional_encoding
         )
+        self.hidden_post_trxl = nn.Sequential(
+            layer_init(nn.Linear(args.trxl_dim, args.trxl_dim)),
+            nn.ReLU(),
+        )
+
         self.actor_branches = nn.ModuleList(
             [
                 layer_init(nn.Linear(args.trxl_dim, out_features=num_actions), np.sqrt(0.01))
@@ -300,6 +305,7 @@ def get_value(self, x, memory, memory_mask, memory_indices):
         else:
             x = self.encoder(x)
         x, _ = self.transformer(x, memory, memory_mask, memory_indices)
+        x = self.hidden_post_trxl(x)
         return self.critic(x).flatten()
 
     def get_action_and_value(self, x, memory, memory_mask, memory_indices, action=None):
@@ -308,6 +314,7 @@ def get_action_and_value(self, x, memory, memory_mask, memory_indices, action=No
         else:
             x = self.encoder(x)
         x, memory = self.transformer(x, memory, memory_mask, memory_indices)
+        x = self.hidden_post_trxl(x)
         self.x = x
         probs = [Categorical(logits=branch(x)) for branch in self.actor_branches]
         if action is None:
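
For reference, a minimal sketch of where the new hidden_post_trxl block sits in the forward path. This is not the patched file: TinyAgent, trxl_dim=64, and num_actions=4 are placeholder assumptions, and layer_init is reproduced here following CleanRL's usual orthogonal-init helper; in the real agent, trxl_out comes from the Transformer-XL module, and the same post-TrXL hidden output feeds both the actor branches and the critic head.

# Sketch only: illustrates the patched forward path, not the actual ppo_trxl.py.
import numpy as np
import torch
import torch.nn as nn


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    # Orthogonal weight init, as in CleanRL's usual helper
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class TinyAgent(nn.Module):
    # Placeholder agent; trxl_dim and num_actions are assumed values.
    def __init__(self, trxl_dim=64, num_actions=4):
        super().__init__()
        # The block added by this patch: one Linear + ReLU applied to the TrXL output
        self.hidden_post_trxl = nn.Sequential(
            layer_init(nn.Linear(trxl_dim, trxl_dim)),
            nn.ReLU(),
        )
        self.actor = layer_init(nn.Linear(trxl_dim, num_actions), np.sqrt(0.01))
        self.critic = layer_init(nn.Linear(trxl_dim, 1), 1.0)

    def forward(self, trxl_out):
        # Both heads consume the post-TrXL hidden representation
        h = self.hidden_post_trxl(trxl_out)
        return self.actor(h), self.critic(h).flatten()


if __name__ == "__main__":
    agent = TinyAgent()
    logits, value = agent(torch.randn(8, 64))
    print(logits.shape, value.shape)  # torch.Size([8, 4]) torch.Size([8])

The norm_adv change only flips the default: when the flag is enabled, CleanRL's PPO update normalizes each mini-batch's advantages to zero mean and unit variance before computing the clipped policy loss; with the new default of False that step is skipped.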