This repository contains a collection of goal-conditioned reinforcement learning algorithms. It is compatible with the latest Gymnasium API and uses recent versions of JAX, Flax and Optax. Multiprocessing is supported via mpi4jax, similar to the now-deprecated OpenAI Baselines. The following algorithms are implemented:
- Deep Deterministic Policy Gradient (DDPG paper link)
- Soft Actor-Critic (SAC paper link)
- DroQ (paper link)
All algorithms make use of Hindsight Experience Replay (HER paper link).
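Hindsight Experience Replay relabels stored transitions with goals that were actually reached later in the same episode, so that even failed episodes provide a useful learning signal under sparse rewards. The plain-Python sketch below illustrates the "future" relabelling strategy. It is not this repository's implementation: the episode layout (`obs`, `ag`, `g`, `actions`), the `replay_k=4` relabelling ratio and the 0.05 Fetch-style success threshold are assumptions made for illustration, and all function names are hypothetical.

```python
import math
import random
from typing import Dict, List, Optional, Sequence, Tuple

Goal = Tuple[float, ...]

# Assumption: Fetch-style sparse reward with a 5 cm success radius.
DISTANCE_THRESHOLD = 0.05


def compute_reward(achieved_goal: Goal, desired_goal: Goal,
                   threshold: float = DISTANCE_THRESHOLD) -> float:
    """Sparse goal-conditioned reward: 0.0 on success, -1.0 otherwise."""
    return 0.0 if math.dist(achieved_goal, desired_goal) <= threshold else -1.0


def sample_her_transitions(episode: Dict[str, Sequence], batch_size: int,
                           replay_k: int = 4,
                           rng: Optional[random.Random] = None) -> List[Dict]:
    """Sample transitions from one episode, relabelling goals with the 'future' strategy.

    Hypothetical layout: episode["obs"] and episode["ag"] hold T + 1 entries,
    episode["g"] and episode["actions"] hold T entries.
    """
    if rng is None:
        rng = random.Random()
    T = len(episode["actions"])
    # With replay_k = 4, about 80% of sampled transitions get a relabelled goal.
    future_p = 1.0 - 1.0 / (1.0 + replay_k)
    batch = []
    for _ in range(batch_size):
        t = rng.randrange(T)
        goal = episode["g"][t]
        if rng.random() < future_p:
            # Replace the goal with an achieved goal from a later step (t+1..T).
            future_t = rng.randint(t + 1, T)
            goal = episode["ag"][future_t]
        # Recompute the reward for the (possibly new) goal from the next achieved goal.
        reward = compute_reward(episode["ag"][t + 1], goal)
        batch.append({
            "t": t,
            "obs": episode["obs"][t],
            "action": episode["actions"][t],
            "goal": goal,
            "reward": reward,
            "next_obs": episode["obs"][t + 1],
        })
    return batch
```

For example, an episode in which the object never reaches the desired goal still yields transitions with reward 0.0 whenever the goal is replaced by the very next achieved position.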
git clone https://github.com/frankroeder/goal_conditioned_rl.git
- pip users:
pip install -r requirements.txt
- conda users:
conda env create --file conda_env.yaml
- system libraries (Open MPI, required by mpi4jax):
apt install libopenmpi-dev
To install JAX on a machine with an NVIDIA GPU (see https://github.com/google/jax#installation), run:
# install packages
pip install -r requirements.txt
# if necessary, replace the CPU-only jaxlib with the CUDA version
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# SAC
python train.py n_epochs=10 agent=sac env_name=FetchPush-v2 hindsight=her agent.critic.dropout=0.0
# DDPG
python train.py n_epochs=10 agent=ddpg env_name=FetchPush-v2 hindsight=her
# DroQ
python train.py n_epochs=10 agent=sac env_name=FetchPush-v2 hindsight=her agent.critic.dropout=0.01
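Judging from the commands above, DroQ is run as SAC with a non-zero `agent.critic.dropout`. In the DroQ paper, each hidden layer of the Q-network applies dropout followed by layer normalization before the activation. The plain-Python sketch below shows one such hidden block. It illustrates the technique only and is not this repository's Flax critic; the function names are hypothetical, and we have not checked whether the repository's critic also applies layer normalization.

```python
import math
import random
from typing import List, Sequence


def linear(x: Sequence[float], weights: Sequence[Sequence[float]],
           bias: Sequence[float]) -> List[float]:
    """Dense layer: one output per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def dropout(x: Sequence[float], rate: float, rng: random.Random,
            training: bool = True) -> List[float]:
    """Inverted dropout: zero each unit with probability `rate`, rescale survivors."""
    if not training or rate == 0.0:
        return list(x)
    keep = 1.0 - rate
    return [xi / keep if rng.random() < keep else 0.0 for xi in x]


def layer_norm(x: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Normalize a vector to zero mean and unit variance (no learned scale or shift)."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]


def relu(x: Sequence[float]) -> List[float]:
    return [max(0.0, xi) for xi in x]


def droq_hidden_block(x, weights, bias, dropout_rate, rng, training=True):
    """Linear -> Dropout -> LayerNorm -> ReLU; dropout_rate=0.0 disables dropout."""
    h = linear(x, weights, bias)
    h = dropout(h, dropout_rate, rng, training)
    h = layer_norm(h)
    return relu(h)
```

Dropout is only active while training the critic; with `training=False` the block is deterministic, which is what target computation at evaluation time would use in this sketch.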
# SAC with 4 MPI processes
mpirun -np 4 python -u train.py n_epochs=10 agent=sac env_name=FetchPush-v2 hindsight=her
python demo.py --demo_path <path to the trial folder>
# or
python demo.py --wandb_url <wandb trial url>
More results will follow.