Commit

push
vwxyzjn committed Oct 16, 2023
1 parent cd4851e commit b97d54f
Showing 8 changed files with 415 additions and 460 deletions.
125 changes: 60 additions & 65 deletions cleanrl/c51.py
@@ -1,83 +1,77 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/c51/#c51py
import argparse
import os
import random
import time
from distutils.util import strtobool
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
track: bool = False
"""if toggled, this experiment will be tracked with Weights and Biases"""
wandb_project_name: str = "cleanRL"
"""the wandb's project name"""
wandb_entity: str = None
"""the entity (team) of wandb's project"""
capture_video: bool = False
"""whether to capture videos of the agent performances (check out `videos` folder)"""
save_model: bool = False
"""whether to save model into the `runs/{run_name}` folder"""
upload_model: bool = False
"""whether to upload the saved model to huggingface"""
hf_entity: str = ""
"""the user or org name of the model repository from the Hugging Face Hub"""

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="CartPole-v1",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=500000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--n-atoms", type=int, default=101,
help="the number of atoms")
parser.add_argument("--v-min", type=float, default=-100,
help="the return lower bound")
parser.add_argument("--v-max", type=float, default=100,
help="the return upper bound")
parser.add_argument("--buffer-size", type=int, default=10000,
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--target-network-frequency", type=int, default=500,
help="the timesteps it takes to update the target network")
parser.add_argument("--batch-size", type=int, default=128,
help="the batch size of sample from the reply memory")
parser.add_argument("--start-e", type=float, default=1,
help="the starting epsilon for exploration")
parser.add_argument("--end-e", type=float, default=0.05,
help="the ending epsilon for exploration")
parser.add_argument("--exploration-fraction", type=float, default=0.5,
help="the fraction of `total-timesteps` it takes from start-e to go end-e")
parser.add_argument("--learning-starts", type=int, default=10000,
help="timestep to start learning")
parser.add_argument("--train-frequency", type=int, default=10,
help="the frequency of training")
args = parser.parse_args()
# fmt: on
assert args.num_envs == 1, "vectorized envs are not supported at the moment"

return args
env_id: str = "CartPole-v1"
"""the id of the environment"""
total_timesteps: int = 500000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 1
"""the number of parallel game environments"""
n_atoms: int = 101
"""the number of atoms"""
v_min: float = -100
"""the return lower bound"""
v_max: float = 100
"""the return upper bound"""
buffer_size: int = 10000
"""the replay memory buffer size"""
gamma: float = 0.99
"""the discount factor gamma"""
target_network_frequency: int = 500
"""the timesteps it takes to update the target network"""
batch_size: int = 128
"""the batch size of sample from the reply memory"""
start_e: float = 1
"""the starting epsilon for exploration"""
end_e: float = 0.05
"""the ending epsilon for exploration"""
exploration_fraction: float = 0.5
"""the fraction of `total-timesteps` it takes from start-e to go end-e"""
learning_starts: int = 10000
"""timestep to start learning"""
train_frequency: int = 10
"""the frequency of training"""


def make_env(env_id, seed, idx, capture_video, run_name):
@@ -136,7 +130,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = parse_args()
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb
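The pattern of this change, distilled: each `parser.add_argument` call becomes a typed dataclass field whose docstring doubles as its `--help` text, and `tyro.cli(Args)` replaces `parse_args()`. The sketch below shows the same pattern in isolation; it is illustrative, not a verbatim excerpt of the diff, and assumes tyro's default behavior (underscored field names map to hyphenated flags, field docstrings become help strings).

# Minimal sketch of the argparse -> tyro pattern applied in this commit.
# Illustrative only; assumes tyro's default flag naming and help generation.
from dataclasses import dataclass

import tyro


@dataclass
class Args:
    seed: int = 1
    """seed of the experiment"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""


if __name__ == "__main__":
    # e.g. `python demo.py --seed 2 --learning-rate 1e-4 --track`
    args = tyro.cli(Args)
    print(args)

Two behavioral details ride along with the migration. Booleans change shape: the old `type=lambda x: bool(strtobool(x))` flags accepted `--track True`, whereas tyro by default generates paired `--track` / `--no-track` switches. And the `exp_name` default quietly fixes a bug: `os.path.basename(__file__).rstrip(".py")` strips any trailing run of the characters `.`, `p`, `y` (a file named `happy.py` would yield `ha`), while the new `[: -len(".py")]` slice removes only the suffix.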
125 changes: 60 additions & 65 deletions cleanrl/c51_atari.py
@@ -1,15 +1,15 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/c51/#c51_ataripy
import argparse
import os
import random
import time
from distutils.util import strtobool
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from stable_baselines3.common.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
@@ -21,70 +21,64 @@
from torch.utils.tensorboard import SummaryWriter


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
track: bool = False
"""if toggled, this experiment will be tracked with Weights and Biases"""
wandb_project_name: str = "cleanRL"
"""the wandb's project name"""
wandb_entity: str = None
"""the entity (team) of wandb's project"""
capture_video: bool = False
"""whether to capture videos of the agent performances (check out `videos` folder)"""
save_model: bool = False
"""whether to save model into the `runs/{run_name}` folder"""
upload_model: bool = False
"""whether to upload the saved model to huggingface"""
hf_entity: str = ""
"""the user or org name of the model repository from the Hugging Face Hub"""

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=10000000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--n-atoms", type=int, default=51,
help="the number of atoms")
parser.add_argument("--v-min", type=float, default=-10,
help="the return lower bound")
parser.add_argument("--v-max", type=float, default=10,
help="the return upper bound")
parser.add_argument("--buffer-size", type=int, default=1000000,
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--target-network-frequency", type=int, default=10000,
help="the timesteps it takes to update the target network")
parser.add_argument("--batch-size", type=int, default=32,
help="the batch size of sample from the reply memory")
parser.add_argument("--start-e", type=float, default=1,
help="the starting epsilon for exploration")
parser.add_argument("--end-e", type=float, default=0.01,
help="the ending epsilon for exploration")
parser.add_argument("--exploration-fraction", type=float, default=0.10,
help="the fraction of `total-timesteps` it takes from start-e to go end-e")
parser.add_argument("--learning-starts", type=int, default=80000,
help="timestep to start learning")
parser.add_argument("--train-frequency", type=int, default=4,
help="the frequency of training")
args = parser.parse_args()
# fmt: on
assert args.num_envs == 1, "vectorized envs are not supported at the moment"

return args
env_id: str = "BreakoutNoFrameskip-v4"
"""the id of the environment"""
total_timesteps: int = 10000000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 1
"""the number of parallel game environments"""
n_atoms: int = 51
"""the number of atoms"""
v_min: float = -10
"""the return lower bound"""
v_max: float = 10
"""the return upper bound"""
buffer_size: int = 1000000
"""the replay memory buffer size"""
gamma: float = 0.99
"""the discount factor gamma"""
target_network_frequency: int = 10000
"""the timesteps it takes to update the target network"""
batch_size: int = 32
"""the batch size of sample from the reply memory"""
start_e: float = 1
"""the starting epsilon for exploration"""
end_e: float = 0.01
"""the ending epsilon for exploration"""
exploration_fraction: float = 0.10
"""the fraction of `total-timesteps` it takes from start-e to go end-e"""
learning_starts: int = 80000
"""timestep to start learning"""
train_frequency: int = 4
"""the frequency of training"""


def make_env(env_id, seed, idx, capture_video, run_name):
@@ -158,7 +152,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
"""
)
args = parse_args()
args = tyro.cli(Args)
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb
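Both files keep the `linear_schedule` helper named in the hunk headers above, which implements the epsilon decay that `start_e`, `end_e`, and `exploration_fraction` parameterize. For context, here is a sketch consistent with that signature: epsilon falls linearly from `start_e` to `end_e` over `duration` steps (set to `exploration_fraction * total_timesteps` by the caller) and is clamped thereafter. The body is a plausible reconstruction, not quoted from this diff.

# Sketch matching the signature `linear_schedule(start_e, end_e, duration, t)`
# shown in the hunk headers; a plausible reconstruction, not quoted from the diff.
def linear_schedule(start_e: float, end_e: float, duration: int, t: int) -> float:
    slope = (end_e - start_e) / duration  # negative while annealing downward
    return max(slope * t + start_e, end_e)  # hold at end_e once fully decayed


# With c51_atari.py defaults: start_e=1, end_e=0.01,
# duration = 0.10 * 10_000_000 = 1_000_000 steps.
assert linear_schedule(1.0, 0.01, 1_000_000, 0) == 1.0
assert abs(linear_schedule(1.0, 0.01, 1_000_000, 500_000) - 0.505) < 1e-9
assert linear_schedule(1.0, 0.01, 1_000_000, 2_000_000) == 0.01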
(Diffs for the remaining 6 changed files are not shown.)