-
-
Notifications
You must be signed in to change notification settings - Fork 410
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add initial tutorial files for SB3, adapted from jordan's article (#1015
- Loading branch information
1 parent
7714f51
commit 516eed7
Showing
8 changed files
with
254 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
--- | ||
title: "Stable-Baselines3" | ||
--- | ||
|
||
# SB3 Tutorial | ||
|
||
These tutorials show you how to use [SB3](https://stable-baselines3.readthedocs.io/en/master/) to train agents in PettingZoo environments. | ||
|
||
* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in a parallel environment_ | ||
|
||
* [PPO for Rock-Paper-Scissors](/tutorials/sb3/rps/) _Train a PPO model in an AEC environment_ | ||
|
||
|
||
```{figure} https://docs.ray.io/en/latest/_images/rllib-stack.svg | ||
:alt: RLlib stack | ||
:width: 80% | ||
``` | ||
|
||
```{toctree} | ||
:hidden: | ||
:caption: RLlib | ||
pistonball | ||
holdem | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
--- | ||
title: "SB3: PPO for Pistonball (Parallel)" | ||
--- | ||
|
||
# RLlib: PPO for Pistonball | ||
|
||
This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([parallel](https://pettingzoo.farama.org/api/parallel/)). | ||
|
||
After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. | ||
|
||
|
||
## Environment Setup | ||
To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. | ||
```{eval-rst} | ||
.. literalinclude:: ../../../tutorials/SB3/requirements.txt | ||
:language: text | ||
``` | ||
|
||
## Code | ||
The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). | ||
|
||
### Training the RL agent | ||
|
||
```{eval-rst} | ||
.. literalinclude:: ../../../tutorials/SB3/sb3_pistonball.py | ||
:language: python | ||
``` | ||
|
||
### Watching the trained RL agent play | ||
|
||
```{eval-rst} | ||
.. literalinclude:: ../../../tutorials/SB3/render_sb3_pistonball.py | ||
:language: python | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
--- | ||
title: "SB3: PPO for Rock-Paper-Scissors (AEC)" | ||
--- | ||
|
||
# RLlib: PPO for Rock-Paper-Scissors | ||
|
||
This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/classic/rps/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). | ||
|
||
After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. | ||
|
||
|
||
## Environment Setup | ||
To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. | ||
```{eval-rst} | ||
.. literalinclude:: ../../../tutorials/SB3/requirements.txt | ||
:language: text | ||
``` | ||
|
||
## Code | ||
The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). | ||
|
||
### Training the RL agent | ||
|
||
```{eval-rst} | ||
.. literalinclude:: ../../../tutorials/SB3/sb3_rps.py | ||
:language: python | ||
``` | ||
|
||
### Watching the trained RL agent play | ||
|
||
```{eval-rst} | ||
.. literalinclude:: ../../../tutorials/SB3/render_sb3_rps.py | ||
:language: python | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
"""Uses Stable-Baselines3 to view trained agents playing Pistonball. | ||
Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b | ||
Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) | ||
""" | ||
import glob | ||
import os | ||
|
||
import supersuit as ss | ||
from stable_baselines3 import PPO | ||
|
||
from pettingzoo.butterfly import pistonball_v6 | ||
|
||
env = pistonball_v6.env(render_mode="human") | ||
|
||
env = ss.color_reduction_v0(env, mode="B") | ||
env = ss.resize_v1(env, x_size=84, y_size=84) | ||
env = ss.frame_stack_v1(env, 3) | ||
|
||
latest_policy = max(glob.glob("rps_*.zip"), key=os.path.getctime) | ||
model = PPO.load(latest_policy) | ||
|
||
env.reset() | ||
for agent in env.agent_iter(): | ||
obs, reward, termination, truncation, info = env.last() | ||
act = ( | ||
model.predict(obs, deterministic=True)[0] | ||
if not termination or truncation | ||
else None | ||
) | ||
env.step(act) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
"""Uses Stable-Baselines3 to view trained agents playing Rock-Paper-Scissors. | ||
Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b | ||
Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) | ||
""" | ||
|
||
import glob | ||
import os | ||
|
||
from stable_baselines3 import PPO | ||
|
||
from pettingzoo.classic import rps_v2 | ||
|
||
env = rps_v2.env(render_mode="human") | ||
|
||
latest_policy = max(glob.glob("rps_*.zip"), key=os.path.getctime) | ||
model = PPO.load(latest_policy) | ||
|
||
env.reset() | ||
for agent in env.agent_iter(): | ||
obs, reward, termination, truncation, info = env.last() | ||
if termination or truncation: | ||
act = None | ||
else: | ||
act = model.predict(obs, deterministic=True)[0] | ||
env.step(act) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
stable-baselines3 >= 2.0.0 | ||
pettingzoo >= 1.23.1 | ||
supersuit >= 3.8.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""Uses Stable-Baselines3 to train agents to play Pistonball. | ||
Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b | ||
Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) | ||
""" | ||
import time | ||
|
||
import supersuit as ss | ||
from stable_baselines3 import PPO | ||
from stable_baselines3.ppo import CnnPolicy | ||
|
||
from pettingzoo.butterfly import pistonball_v6 | ||
|
||
env = pistonball_v6.parallel_env( | ||
n_pistons=20, | ||
time_penalty=-0.1, | ||
continuous=True, | ||
random_drop=True, | ||
random_rotate=True, | ||
ball_mass=0.75, | ||
ball_friction=0.3, | ||
ball_elasticity=1.5, | ||
max_cycles=125, | ||
) | ||
|
||
env = ss.color_reduction_v0(env, mode="B") | ||
env = ss.resize_v1(env, x_size=84, y_size=84) | ||
env = ss.frame_stack_v1(env, 3) | ||
|
||
|
||
env = ss.pettingzoo_env_to_vec_env_v1(env) | ||
env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") | ||
|
||
model = PPO( | ||
CnnPolicy, | ||
env, | ||
verbose=3, | ||
gamma=0.95, | ||
n_steps=256, | ||
ent_coef=0.0905168, | ||
learning_rate=0.00062211, | ||
vf_coef=0.042202, | ||
max_grad_norm=0.9, | ||
gae_lambda=0.99, | ||
n_epochs=5, | ||
clip_range=0.3, | ||
batch_size=256, | ||
) | ||
|
||
model.learn(total_timesteps=2_000_000) | ||
|
||
model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") | ||
|
||
|
||
print("Model has been saved.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""Uses Stable-Baselines3 to train agents to play Rock-Paper-Scissors. | ||
Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b | ||
Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) | ||
""" | ||
import time | ||
|
||
import supersuit as ss | ||
from stable_baselines3 import PPO | ||
from stable_baselines3.ppo import MlpPolicy | ||
|
||
from pettingzoo.classic import rps_v2 | ||
from pettingzoo.utils import turn_based_aec_to_parallel | ||
|
||
env = rps_v2.env() | ||
env = turn_based_aec_to_parallel(env) | ||
|
||
env = ss.pettingzoo_env_to_vec_env_v1(env) | ||
env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") | ||
|
||
# TODO: find hyperparameters that make the model actually learn | ||
model = PPO( | ||
MlpPolicy, | ||
env, | ||
verbose=3, | ||
gamma=0.95, | ||
n_steps=256, | ||
ent_coef=0.0905168, | ||
learning_rate=0.00062211, | ||
vf_coef=0.042202, | ||
max_grad_norm=0.9, | ||
gae_lambda=0.99, | ||
n_epochs=5, | ||
clip_range=0.3, | ||
batch_size=256, | ||
) | ||
|
||
model.learn(total_timesteps=2_000_000) | ||
|
||
model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") | ||
|
||
print("Model has been saved.") |