diff --git a/cleanrl/dqn.py b/cleanrl/dqn.py index b2b8e8555..d71a3bddb 100644 --- a/cleanrl/dqn.py +++ b/cleanrl/dqn.py @@ -53,7 +53,7 @@ class Args: """the replay memory buffer size""" gamma: float = 0.99 """the discount factor gamma""" - tau: float = 1. + tau: float = 1.0 """the target network update rate""" target_network_frequency: int = 500 """the timesteps it takes to update the target network""" diff --git a/cleanrl/dqn_atari.py b/cleanrl/dqn_atari.py index 832d027aa..ba32b197d 100644 --- a/cleanrl/dqn_atari.py +++ b/cleanrl/dqn_atari.py @@ -60,7 +60,7 @@ class Args: """the replay memory buffer size""" gamma: float = 0.99 """the discount factor gamma""" - tau: float = 1. + tau: float = 1.0 """the target network update rate""" target_network_frequency: int = 1000 """the timesteps it takes to update the target network""" diff --git a/cleanrl/ppg_procgen.py b/cleanrl/ppg_procgen.py index 46bcca78f..ad9dc58fb 100644 --- a/cleanrl/ppg_procgen.py +++ b/cleanrl/ppg_procgen.py @@ -85,8 +85,6 @@ class Args: n_aux_grad_accum: int = 1 """the number of gradient accumulation in mini batch""" - - # to be filled in runtime batch_size: int = 0 """the batch size (computed in runtime)""" @@ -283,7 +281,6 @@ def get_pi(self, x): start_time = time.time() next_obs = torch.Tensor(envs.reset()).to(device) next_done = torch.zeros(args.num_envs).to(device) - for phase in range(1, args.num_phases + 1): @@ -397,7 +394,7 @@ def get_pi(self, x): optimizer.step() if args.target_kl is not None and approx_kl > args.target_kl: - break + break y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() var_y = np.var(y_true) diff --git a/cleanrl/qdagger_dqn_atari_jax_impalacnn.py b/cleanrl/qdagger_dqn_atari_jax_impalacnn.py index b199b45cc..143ec66ec 100644 --- a/cleanrl/qdagger_dqn_atari_jax_impalacnn.py +++ b/cleanrl/qdagger_dqn_atari_jax_impalacnn.py @@ -1,10 +1,8 @@ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/qdagger/#qdagger_dqn_atari_jax_impalacnnpy -import argparse import os import random import time from collections import deque -from distutils.util import strtobool from typing import Sequence os.environ[