Skip to content

Commit

Permalink
running rpo custom environment
Browse files Browse the repository at this point in the history
  • Loading branch information
Michelle Ho committed May 15, 2024
1 parent 1e1a80f commit 9fa8fc4
Show file tree
Hide file tree
Showing 18 changed files with 36 additions and 1,594 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,6 @@ venv.bak/

# mypy
.mypy_cache/

# Your virtual environment
cleanrl_venv/
15 changes: 11 additions & 4 deletions cleanrl/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
import time
from dataclasses import dataclass

import gymnasium as gym
import gym_examples
import gym
# import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
import gym_examples


@dataclass
Expand All @@ -34,7 +37,9 @@ class Args:
"""whether to capture videos of the agent performances (check out `videos` folder)"""

# Algorithm specific arguments
env_id: str = "CartPole-v1"
##TODO: change
# env_id: str = "CartPole-v1"
env_id: str = "RPO_DeTumbling2D-v0"
"""the id of the environment"""
total_timesteps: int = 500000
"""total timesteps of the experiments"""
Expand Down Expand Up @@ -81,10 +86,12 @@ class Args:
def make_env(env_id, idx, capture_video, run_name):
def thunk():
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
# env = gym.make(env_id, render_mode="rgb_array")
env = gym.make('gym_examples/RPO_Detumble2DEnv-v0', render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
# env = gym.make(env_id)
env = gym.make('gym_examples/RPO_Detumble2DEnv-v0')
env = gym.wrappers.RecordEpisodeStatistics(env)
return env

Expand Down
13 changes: 9 additions & 4 deletions cleanrl/ppo_continuous_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import time
from dataclasses import dataclass

import gymnasium as gym
# import gymnasium as gym
import gym
import gym_examples
import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -40,7 +42,8 @@ class Args:
"""the user or org name of the model repository from the Hugging Face Hub"""

# Algorithm specific arguments
env_id: str = "HalfCheetah-v4"
# env_id: str = "HalfCheetah-v4"
env_id: str = "gym_examples/RPO_Detumble2DEnv-v0"
"""the id of the environment"""
total_timesteps: int = 1000000
"""total timesteps of the experiments"""
Expand Down Expand Up @@ -87,10 +90,12 @@ class Args:
def make_env(env_id, idx, capture_video, run_name, gamma):
def thunk():
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
# env = gym.make(env_id, render_mode="rgb_array")
env = gym.make('gym_examples/RPO_Detumble2DEnv-v0', render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
# env = gym.make(env_id)
env = gym.make('gym_examples/RPO_Detumble2DEnv-v0')
env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space
env = gym.wrappers.RecordEpisodeStatistics(env)
env = gym.wrappers.ClipAction(env)
Expand Down
13 changes: 9 additions & 4 deletions cleanrl/sac_continuous_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import time
from dataclasses import dataclass

import gymnasium as gym
# import gymnasium as gym
import gym
import gym_examples
import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -35,7 +37,8 @@ class Args:
"""whether to capture videos of the agent performances (check out `videos` folder)"""

# Algorithm specific arguments
env_id: str = "Hopper-v4"
# env_id: str = "Hopper-v4"
env_id: str = "gym_examples/RPO_Detumble2DEnv-v0"
"""the environment id of the task"""
total_timesteps: int = 1000000
"""total timesteps of the experiments"""
Expand Down Expand Up @@ -68,10 +71,12 @@ class Args:
def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
# env = gym.make(env_id, render_mode="rgb_array")
env = gym.make('gym_examples/RPO_Detumble2DEnv-v0', render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
# env = gym.make(env_id)
env = gym.make('gym_examples/RPO_Detumble2DEnv-v0')
env = gym.wrappers.RecordEpisodeStatistics(env)
env.action_space.seed(seed)
return env
Expand Down
Empty file removed env_original/__init__.py
Empty file.
Empty file.
67 changes: 0 additions & 67 deletions env_original/rpo_detubmling/dynamics/dyn_misc.py

This file was deleted.

136 changes: 0 additions & 136 deletions env_original/rpo_detubmling/dynamics/dynamics_6dof.py

This file was deleted.

Loading

0 comments on commit 9fa8fc4

Please sign in to comment.