Add initial tutorial files for SB3, adapted from Jordan's article (#1015)
elliottower authored Jul 7, 2023
1 parent 7714f51 commit 516eed7
Showing 8 changed files with 254 additions and 0 deletions.
25 changes: 25 additions & 0 deletions docs/tutorials/sb3/index.md
@@ -0,0 +1,25 @@
---
title: "Stable-Baselines3"
---

# SB3 Tutorial

These tutorials show you how to use [SB3](https://stable-baselines3.readthedocs.io/en/master/) to train agents in PettingZoo environments.

* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in a parallel environment_

* [PPO for Rock-Paper-Scissors](/tutorials/sb3/rps/): _Train a PPO model in an AEC environment_



```{toctree}
:hidden:
:caption: SB3
pistonball
rps
```
34 changes: 34 additions & 0 deletions docs/tutorials/sb3/pistonball.md
@@ -0,0 +1,34 @@
---
title: "SB3: PPO for Pistonball (Parallel)"
---

# SB3: PPO for Pistonball

This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([parallel](https://pettingzoo.farama.org/api/parallel/)).
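SB3 trains on vectorized, Gymnasium-style environments, so the parallel PettingZoo environment is first preprocessed and wrapped with SuperSuit. The full training script below does exactly this; a minimal sketch of the wrapping looks like:

```python
import supersuit as ss

from pettingzoo.butterfly import pistonball_v6

# Create the parallel environment and preprocess the image observations
env = pistonball_v6.parallel_env(n_pistons=20, continuous=True, max_cycles=125)
env = ss.color_reduction_v0(env, mode="B")     # grayscale
env = ss.resize_v1(env, x_size=84, y_size=84)  # downscale to 84x84
env = ss.frame_stack_v1(env, 3)                # stack the last 3 frames

# Expose it to SB3 as a vector environment (8 copies, 4 worker processes)
env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3")
```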

After training, run the provided code to watch the trained agents play. See the SB3 [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models.
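As a quick reference, a minimal sketch of saving and loading in SB3 (here `pistonball_policy` is a placeholder filename; the tutorial scripts use a timestamped name instead):

```python
from stable_baselines3 import PPO

# `model` is the trained PPO instance produced by the training script;
# save() writes a .zip archive next to the script
model.save("pistonball_policy")

# Later (e.g. in the rendering script), restore the trained policy
model = PPO.load("pistonball_policy")
```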


## Environment Setup
To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts.
```{eval-rst}
.. literalinclude:: ../../../tutorials/SB3/requirements.txt
:language: text
```

## Code
The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX).

### Training the RL agent

```{eval-rst}
.. literalinclude:: ../../../tutorials/SB3/sb3_pistonball.py
:language: python
```

### Watching the trained RL agent play

```{eval-rst}
.. literalinclude:: ../../../tutorials/SB3/render_sb3_pistonball.py
:language: python
```
34 changes: 34 additions & 0 deletions docs/tutorials/sb3/rps.md
@@ -0,0 +1,34 @@
---
title: "SB3: PPO for Rock-Paper-Scissors (AEC)"
---

# SB3: PPO for Rock-Paper-Scissors

This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Rock-Paper-Scissors](https://pettingzoo.farama.org/environments/classic/rps/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)).
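Because SB3 expects a vectorized, parallel-style environment, the AEC environment is converted before training. The full training script below includes this step; a minimal sketch of the conversion looks like:

```python
import supersuit as ss

from pettingzoo.classic import rps_v2
from pettingzoo.utils import turn_based_aec_to_parallel

env = rps_v2.env()
# Convert the AEC environment to the parallel API
env = turn_based_aec_to_parallel(env)
# Wrap it as an SB3-compatible vector env (8 copies, 4 worker processes)
env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3")
```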

After training, run the provided code to watch your trained agent play against itself. See the SB3 [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models.


## Environment Setup
To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts.
```{eval-rst}
.. literalinclude:: ../../../tutorials/SB3/requirements.txt
:language: text
```

## Code
The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with SB3. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX).

### Training the RL agent

```{eval-rst}
.. literalinclude:: ../../../tutorials/SB3/sb3_rps.py
:language: python
```

### Watching the trained RL agent play

```{eval-rst}
.. literalinclude:: ../../../tutorials/SB3/render_sb3_rps.py
:language: python
```
32 changes: 32 additions & 0 deletions tutorials/SB3/render_sb3_pistonball.py
@@ -0,0 +1,32 @@
"""Uses Stable-Baselines3 to view trained agents playing Pistonball.
Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b
Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower)
"""
import glob
import os

import supersuit as ss
from stable_baselines3 import PPO

from pettingzoo.butterfly import pistonball_v6

env = pistonball_v6.env(render_mode="human")

# Apply the same preprocessing that was used during training
env = ss.color_reduction_v0(env, mode="B")
env = ss.resize_v1(env, x_size=84, y_size=84)
env = ss.frame_stack_v1(env, 3)

# Load the most recently saved policy (the training script saves it as pistonball_v6_<timestamp>.zip)
latest_policy = max(glob.glob("pistonball_v6*.zip"), key=os.path.getctime)
model = PPO.load(latest_policy)

env.reset()
for agent in env.agent_iter():
    obs, reward, termination, truncation, info = env.last()
    if termination or truncation:
        act = None  # finished agents must step with None
    else:
        act = model.predict(obs, deterministic=True)[0]
    env.step(act)
27 changes: 27 additions & 0 deletions tutorials/SB3/render_sb3_rps.py
@@ -0,0 +1,27 @@
"""Uses Stable-Baselines3 to view trained agents playing Rock-Paper-Scissors.
Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b
Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower)
"""

import glob
import os

from stable_baselines3 import PPO

from pettingzoo.classic import rps_v2

env = rps_v2.env(render_mode="human")

# Load the most recently saved policy (the training script saves it as rps_v2_<timestamp>.zip)
latest_policy = max(glob.glob("rps_*.zip"), key=os.path.getctime)
model = PPO.load(latest_policy)

env.reset()
for agent in env.agent_iter():
obs, reward, termination, truncation, info = env.last()
if termination or truncation:
act = None
else:
act = model.predict(obs, deterministic=True)[0]
env.step(act)
3 changes: 3 additions & 0 deletions tutorials/SB3/requirements.txt
@@ -0,0 +1,3 @@
stable-baselines3 >= 2.0.0
pettingzoo >= 1.23.1
supersuit >= 3.8.1
56 changes: 56 additions & 0 deletions tutorials/SB3/sb3_pistonball.py
@@ -0,0 +1,56 @@
"""Uses Stable-Baselines3 to train agents to play Pistonball.
Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b
Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower)
"""
import time

import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy

from pettingzoo.butterfly import pistonball_v6

env = pistonball_v6.parallel_env(
n_pistons=20,
time_penalty=-0.1,
continuous=True,
random_drop=True,
random_rotate=True,
ball_mass=0.75,
ball_friction=0.3,
ball_elasticity=1.5,
max_cycles=125,
)

# Preprocess the image observations: grayscale, downscale to 84x84, stack the last 3 frames
env = ss.color_reduction_v0(env, mode="B")
env = ss.resize_v1(env, x_size=84, y_size=84)
env = ss.frame_stack_v1(env, 3)

# Convert the PettingZoo parallel env into an SB3-compatible vector env,
# then run 8 copies of it with 4 worker processes
env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3")

# PPO hyperparameters adapted from the original Pistonball tutorial
model = PPO(
CnnPolicy,
env,
verbose=3,
gamma=0.95,
n_steps=256,
ent_coef=0.0905168,
learning_rate=0.00062211,
vf_coef=0.042202,
max_grad_norm=0.9,
gae_lambda=0.99,
n_epochs=5,
clip_range=0.3,
batch_size=256,
)

model.learn(total_timesteps=2_000_000)

# Save with the environment name and a timestamp so the rendering script can find the latest checkpoint
model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}")


print("Model has been saved.")
43 changes: 43 additions & 0 deletions tutorials/SB3/sb3_rps.py
@@ -0,0 +1,43 @@
"""Uses Stable-Baselines3 to train agents to play Rock-Paper-Scissors.
Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b
Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower)
"""
import time

import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

from pettingzoo.classic import rps_v2
from pettingzoo.utils import turn_based_aec_to_parallel

env = rps_v2.env()
# SB3 expects a parallel-style environment, so convert the AEC env to the parallel API
env = turn_based_aec_to_parallel(env)

# Wrap it as an SB3-compatible vector env and run 8 copies with 4 worker processes
env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3")

# TODO: find hyperparameters that make the model actually learn
model = PPO(
MlpPolicy,
env,
verbose=3,
gamma=0.95,
n_steps=256,
ent_coef=0.0905168,
learning_rate=0.00062211,
vf_coef=0.042202,
max_grad_norm=0.9,
gae_lambda=0.99,
n_epochs=5,
clip_range=0.3,
batch_size=256,
)

model.learn(total_timesteps=2_000_000)

model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}")

print("Model has been saved.")
