From 516eed747c889edd58d9610cf298a75f0f3978b5 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Fri, 7 Jul 2023 00:35:15 -0400 Subject: [PATCH] Add initial tutorial files for SB3, adapted from jordan's article (#1015) --- docs/tutorials/sb3/index.md | 25 ++++++++++++ docs/tutorials/sb3/pistonball.md | 34 ++++++++++++++++ docs/tutorials/sb3/rps.md | 34 ++++++++++++++++ tutorials/SB3/render_sb3_pistonball.py | 32 +++++++++++++++ tutorials/SB3/render_sb3_rps.py | 27 +++++++++++++ tutorials/SB3/requirements.txt | 3 ++ tutorials/SB3/sb3_pistonball.py | 56 ++++++++++++++++++++++++++ tutorials/SB3/sb3_rps.py | 43 ++++++++++++++++++++ 8 files changed, 254 insertions(+) create mode 100644 docs/tutorials/sb3/index.md create mode 100644 docs/tutorials/sb3/pistonball.md create mode 100644 docs/tutorials/sb3/rps.md create mode 100644 tutorials/SB3/render_sb3_pistonball.py create mode 100644 tutorials/SB3/render_sb3_rps.py create mode 100644 tutorials/SB3/requirements.txt create mode 100644 tutorials/SB3/sb3_pistonball.py create mode 100644 tutorials/SB3/sb3_rps.py diff --git a/docs/tutorials/sb3/index.md b/docs/tutorials/sb3/index.md new file mode 100644 index 000000000..dfd2ebb8f --- /dev/null +++ b/docs/tutorials/sb3/index.md @@ -0,0 +1,25 @@ +--- +title: "Stable-Baselines3" +--- + +# SB3 Tutorial + +These tutorials show you how to use [SB3](https://stable-baselines3.readthedocs.io/en/master/) to train agents in PettingZoo environments. + +* [PPO for Pistonball](/tutorials/sb3/pistonball/): _Train a PPO model in a parallel environment_ + +* [PPO for Rock-Paper-Scissors](/tutorials/sb3/rps/) _Train a PPO model in an AEC environment_ + + +```{figure} https://docs.ray.io/en/latest/_images/rllib-stack.svg + :alt: RLlib stack + :width: 80% +``` + +```{toctree} +:hidden: +:caption: RLlib + +pistonball +holdem +``` diff --git a/docs/tutorials/sb3/pistonball.md b/docs/tutorials/sb3/pistonball.md new file mode 100644 index 000000000..8e86e13e2 --- /dev/null +++ b/docs/tutorials/sb3/pistonball.md @@ -0,0 +1,34 @@ +--- +title: "SB3: PPO for Pistonball (Parallel)" +--- + +# RLlib: PPO for Pistonball + +This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/butterfly/pistonball/) environment ([parallel](https://pettingzoo.farama.org/api/parallel/)). + +After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. + + +## Environment Setup +To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/requirements.txt + :language: text +``` + +## Code +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). + +### Training the RL agent + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/sb3_pistonball.py + :language: python +``` + +### Watching the trained RL agent play + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/render_sb3_pistonball.py + :language: python +``` diff --git a/docs/tutorials/sb3/rps.md b/docs/tutorials/sb3/rps.md new file mode 100644 index 000000000..fa70d3c55 --- /dev/null +++ b/docs/tutorials/sb3/rps.md @@ -0,0 +1,34 @@ +--- +title: "SB3: PPO for Rock-Paper-Scissors (AEC)" +--- + +# RLlib: PPO for Rock-Paper-Scissors + +This tutorial shows how to train a [Proximal Policy Optimization](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) (PPO) model on the [Pistonball](https://pettingzoo.farama.org/environments/classic/rps/) environment ([AEC](https://pettingzoo.farama.org/api/aec/)). + +After training, run the provided code to watch your trained agent play vs itself. See the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html) for more information about saving and loading models. + + +## Environment Setup +To follow this tutorial, you will need to install the dependencies shown below. It is recommended to use a newly-created virtual environment to avoid dependency conflicts. +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/requirements.txt + :language: text +``` + +## Code +The following code should run without any issues. The comments are designed to help you understand how to use PettingZoo with RLLib. If you have any questions, please feel free to ask in the [Discord server](https://discord.gg/nhvKkYa6qX). + +### Training the RL agent + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/sb3_rps.py + :language: python +``` + +### Watching the trained RL agent play + +```{eval-rst} +.. literalinclude:: ../../../tutorials/SB3/render_sb3_rps.py + :language: python +``` diff --git a/tutorials/SB3/render_sb3_pistonball.py b/tutorials/SB3/render_sb3_pistonball.py new file mode 100644 index 000000000..794cf027a --- /dev/null +++ b/tutorials/SB3/render_sb3_pistonball.py @@ -0,0 +1,32 @@ +"""Uses Stable-Baselines3 to view trained agents playing Pistonball. + +Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b + +Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) +""" +import glob +import os + +import supersuit as ss +from stable_baselines3 import PPO + +from pettingzoo.butterfly import pistonball_v6 + +env = pistonball_v6.env(render_mode="human") + +env = ss.color_reduction_v0(env, mode="B") +env = ss.resize_v1(env, x_size=84, y_size=84) +env = ss.frame_stack_v1(env, 3) + +latest_policy = max(glob.glob("rps_*.zip"), key=os.path.getctime) +model = PPO.load(latest_policy) + +env.reset() +for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + act = ( + model.predict(obs, deterministic=True)[0] + if not termination or truncation + else None + ) + env.step(act) diff --git a/tutorials/SB3/render_sb3_rps.py b/tutorials/SB3/render_sb3_rps.py new file mode 100644 index 000000000..c07c15567 --- /dev/null +++ b/tutorials/SB3/render_sb3_rps.py @@ -0,0 +1,27 @@ +"""Uses Stable-Baselines3 to view trained agents playing Rock-Paper-Scissors. + +Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b + +Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) +""" + +import glob +import os + +from stable_baselines3 import PPO + +from pettingzoo.classic import rps_v2 + +env = rps_v2.env(render_mode="human") + +latest_policy = max(glob.glob("rps_*.zip"), key=os.path.getctime) +model = PPO.load(latest_policy) + +env.reset() +for agent in env.agent_iter(): + obs, reward, termination, truncation, info = env.last() + if termination or truncation: + act = None + else: + act = model.predict(obs, deterministic=True)[0] + env.step(act) diff --git a/tutorials/SB3/requirements.txt b/tutorials/SB3/requirements.txt new file mode 100644 index 000000000..72605e8cc --- /dev/null +++ b/tutorials/SB3/requirements.txt @@ -0,0 +1,3 @@ +stable-baselines3 >= 2.0.0 +pettingzoo >= 1.23.1 +supersuit >= 3.8.1 diff --git a/tutorials/SB3/sb3_pistonball.py b/tutorials/SB3/sb3_pistonball.py new file mode 100644 index 000000000..a00f88d08 --- /dev/null +++ b/tutorials/SB3/sb3_pistonball.py @@ -0,0 +1,56 @@ +"""Uses Stable-Baselines3 to train agents to play Pistonball. + +Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b + +Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) +""" +import time + +import supersuit as ss +from stable_baselines3 import PPO +from stable_baselines3.ppo import CnnPolicy + +from pettingzoo.butterfly import pistonball_v6 + +env = pistonball_v6.parallel_env( + n_pistons=20, + time_penalty=-0.1, + continuous=True, + random_drop=True, + random_rotate=True, + ball_mass=0.75, + ball_friction=0.3, + ball_elasticity=1.5, + max_cycles=125, +) + +env = ss.color_reduction_v0(env, mode="B") +env = ss.resize_v1(env, x_size=84, y_size=84) +env = ss.frame_stack_v1(env, 3) + + +env = ss.pettingzoo_env_to_vec_env_v1(env) +env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") + +model = PPO( + CnnPolicy, + env, + verbose=3, + gamma=0.95, + n_steps=256, + ent_coef=0.0905168, + learning_rate=0.00062211, + vf_coef=0.042202, + max_grad_norm=0.9, + gae_lambda=0.99, + n_epochs=5, + clip_range=0.3, + batch_size=256, +) + +model.learn(total_timesteps=2_000_000) + +model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + + +print("Model has been saved.") diff --git a/tutorials/SB3/sb3_rps.py b/tutorials/SB3/sb3_rps.py new file mode 100644 index 000000000..0439d698a --- /dev/null +++ b/tutorials/SB3/sb3_rps.py @@ -0,0 +1,43 @@ +"""Uses Stable-Baselines3 to train agents to play Rock-Paper-Scissors. + +Adapted from https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b + +Authors: Jordan (https://github.com/jkterry1), Elliot (https://github.com/elliottower) +""" +import time + +import supersuit as ss +from stable_baselines3 import PPO +from stable_baselines3.ppo import MlpPolicy + +from pettingzoo.classic import rps_v2 +from pettingzoo.utils import turn_based_aec_to_parallel + +env = rps_v2.env() +env = turn_based_aec_to_parallel(env) + +env = ss.pettingzoo_env_to_vec_env_v1(env) +env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class="stable_baselines3") + +# TODO: find hyperparameters that make the model actually learn +model = PPO( + MlpPolicy, + env, + verbose=3, + gamma=0.95, + n_steps=256, + ent_coef=0.0905168, + learning_rate=0.00062211, + vf_coef=0.042202, + max_grad_norm=0.9, + gae_lambda=0.99, + n_epochs=5, + clip_range=0.3, + batch_size=256, +) + +model.learn(total_timesteps=2_000_000) + +model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}") + +print("Model has been saved.")