JAX-based environments and RL baselines for studying goal misgeneralisation.
Note: jaxgmg is a work in progress. This is a pre-1.0 research codebase. Not everything is finished, tested, documented, or stable. See also the roadmap.
Install the latest main version from GitHub:
pip install git+https://github.com/matomatical/jaxgmg.git
Install from a local clone:
git clone git@github.com:matomatical/jaxgmg.git
cd jaxgmg
pip install -e .
After installing, run the following:
jaxgmg --help
You can try the various subcommands to see demonstrations of the library's functionality. For example:
- To play with an interactive demonstration of the environments, try `jaxgmg play ENV_NAME` (see `jaxgmg play --help` for options).
- To procedurally generate mazes, try `jaxgmg mazegen LAYOUT` (see `jaxgmg mazegen --help` for options).
Note: Most of the demos display colour images in the terminal using ANSI control codes, which may not render correctly in some environments (possibly including Windows terminals).
The following environments are provided.
Animations in this table were produced with `jaxgmg play ENV_NAME`; the actions are chosen by a human.
Procedural level generation, environment updates, and rendering are all fully JAX-accelerated. The speed of level generation depends on the generator configuration and on whether level solutions need to be cached, as explored in the next sections. By contrast, the speed of the environment's step method depends only on the size of the level and the number of entities (such as keys or chests), not on the procedurally generated layout. The following table reports speeds from rollouts of a random policy (steps per second, including rendering a Boolean or RGB observation) for each environment and level size.
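To make the cost argument concrete, here is a minimal plain-Python sketch of a step function and a Boolean renderer for a hypothetical cheese-in-the-corner-style level. The `Level` and `State` classes, the action encoding, the three-channel layout, and the reward rule are all assumptions made for illustration, not jaxgmg's API. The step does a constant number of lookups whatever the wall layout, while rendering touches every cell of every channel, so both costs scale with level size and the number of entity types rather than with the layout.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical action encoding (an assumption): 0 = up, 1 = left, 2 = down, 3 = right.
MOVES = [(-1, 0), (0, -1), (1, 0), (0, 1)]


@dataclass(frozen=True)
class Level:
    walls: List[List[bool]]  # True where there is a wall; assumes a solid outer border
    cheese: Tuple[int, int]  # goal position, fixed for the episode


@dataclass(frozen=True)
class State:
    mouse: Tuple[int, int]


def step(level: Level, state: State, action: int) -> Tuple[State, float]:
    """Move the mouse one cell unless a wall blocks it; reward 1.0 on the cheese.

    A constant number of lookups, whatever the wall layout. The solid border
    means a move never leaves the grid.
    """
    dr, dc = MOVES[action]
    r, c = state.mouse[0] + dr, state.mouse[1] + dc
    if level.walls[r][c]:
        r, c = state.mouse  # blocked: stay in place
    reward = 1.0 if (r, c) == level.cheese else 0.0
    return State(mouse=(r, c)), reward


def render_bool(level: Level, state: State) -> List[List[List[bool]]]:
    """Boolean observation: one H x W plane per channel (walls, mouse, cheese).

    Cost scales with H * W * (number of channels), not with where the walls are.
    """
    h, w = len(level.walls), len(level.walls[0])
    obs = [[[False] * w for _ in range(h)] for _ in range(3)]
    for r in range(h):
        for c in range(w):
            obs[0][r][c] = level.walls[r][c]
    obs[1][state.mouse[0]][state.mouse[1]] = True
    obs[2][level.cheese[0]][level.cheese[1]] = True
    return obs
```

For example, on a 5x5 level whose interior 3x3 cells are open, `step(level, State(mouse=(1, 1)), 2)` moves the mouse down to `(2, 1)` with zero reward.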
On an M2 MacBook Air (without using Metal):
Environment | 13x13 bool | 25x25 bool | 13x13 rgb | 25x25 rgb |
---|---|---|---|---|
Cheese in the Corner | 3.20M | 968K | 50.2K | 14.0K |
Cheese on a Dish | 6.90M | 2.78M | 47.3K | 13.3K |
Follow Me | 2.10M | 620K | 44.3K | 12.4K |
Keys and Chests | 1.77M | 580K | 47.2K | 13.1K |
Lava Land | 7.29M | 2.94M | 46.7K | 13.4K |
Monster World | 597K | 315K | 46.7K | 13.9K |
TODO: Test on a GPU.
Rates collected with `jaxgmg speedtest envstep-ENVIRONMENT`, batch size 32, 128 steps per trial, showing the mean (and standard deviation in parentheses) as calculated after discarding the first trial, since that includes the compilation step.
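The measurement protocol described above (time several trials, discard the first because it includes compilation, then report the mean and standard deviation) can be sketched as follows. `benchmark` is a hypothetical helper written for this example, not the code behind `jaxgmg speedtest`; for the environment benchmarks, one trial would cover 32 × 128 = 4096 environment steps.

```python
import statistics
import time
from typing import Callable, Tuple


def benchmark(run_trial: Callable[[], None], items_per_trial: int,
              num_trials: int = 5) -> Tuple[float, float]:
    """Return (mean, stdev) of items per second, excluding the first trial.

    The first trial is discarded because, with JAX, it includes tracing and
    compilation time rather than steady-state throughput.
    """
    if num_trials < 3:
        raise ValueError("need at least 3 trials (1 warm-up + 2 timed)")
    rates = []
    for _ in range(num_trials):
        start = time.perf_counter()
        run_trial()
        elapsed = time.perf_counter() - start
        rates.append(items_per_trial / elapsed)
    timed = rates[1:]  # drop the warm-up (compilation) trial
    return statistics.mean(timed), statistics.stdev(timed)
```

When timing JAX code, `run_trial` should call `.block_until_ready()` on its outputs: JAX dispatches work asynchronously, so without it the timer may stop before the computation has finished.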
Each environment supports a wide distribution of 'levels', and the library includes easily configurable tools for procedural level generation.
At the core of these level generators is a suite of configurable procedural maze generation methods, some outputs of which are depicted in the mural below. There are currently five maze generation methods, any of which can be paired with any of the above environments.
- Tree mazes: acyclic mazes based on spanning trees of a grid graph, generated using Kruskal's random spanning tree algorithm (Mural row 1).
- Edge mazes: a grid maze where each edge is independently determined to be traversable with configurable probability (Mural rows 2 and 3, edge probabilities 75% and 85% respectively).
- Block mazes: wherein each cell is determined to have a block/wall independently with configurable probability (Mural row 4, block probability 25%).
- Noise mazes: based on thresholding Perlin noise (or associated fractal noise) with a gradient grid of configurable cell size (Mural rows 5, 6, and 7, with gradient cell sizes 2, 3, and 8 respectively; row 8 depicts fractal noise with three octaves, which means Perlin noise with cell sizes 4 and 2 is superimposed onto the base noise with cell size 8).
- Open mazes: an empty maze with no obstacles and no procedural variation (Mural row 9). This case is trivial; nevertheless, it is useful for testing RL algorithms and as a starting point for RL algorithms that build their own maze layouts.
Mural produced with `jaxgmg mazegen mural --num_cols=16`.
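The tree-maze method can be illustrated with a plain-Python sketch of randomised Kruskal's algorithm on a wall grid. The grid convention (odd side lengths such as 13, 25, and 49, with graph nodes at odd coordinates and walls in between) and the function name are assumptions for this example; the library's own generator is JAX-based and may represent mazes differently.

```python
import random
from typing import List


def kruskal_tree_maze(height: int, width: int, seed: int = 0) -> List[List[bool]]:
    """Generate an acyclic (tree) maze as a wall grid; True marks a wall.

    Assumes odd height/width: cells at odd (row, col) coordinates are graph
    nodes, and the cell between two adjacent nodes is opened when the edge
    joining them is added to the spanning tree.
    """
    assert height % 2 == 1 and width % 2 == 1 and height >= 3 and width >= 3
    rng = random.Random(seed)
    walls = [[True] * width for _ in range(height)]
    nodes = [(r, c) for r in range(1, height, 2) for c in range(1, width, 2)]
    for r, c in nodes:
        walls[r][c] = False

    # Candidate edges between horizontally or vertically adjacent nodes.
    edges = [((r, c), (r, c + 2)) for r, c in nodes if c + 2 < width]
    edges += [((r, c), (r + 2, c)) for r, c in nodes if r + 2 < height]
    rng.shuffle(edges)  # a random edge order gives a random spanning tree

    parent = {n: n for n in nodes}

    def find(n):
        # Union-find root lookup with path halving.
        while parent[n] != n:
            parent[n] = parent[parent[n]]
            n = parent[n]
        return n

    for a, b in edges:
        ra, rb = find(a), find(b)
        if ra != rb:  # joining two separate components cannot create a cycle
            parent[ra] = rb
            walls[(a[0] + b[0]) // 2][(a[1] + b[1]) // 2] = False
    return walls
```

Because every accepted edge joins two previously separate components, a 13x13 maze ends up with its 36 nodes connected by exactly 35 opened edges, so there is exactly one path between any two open cells.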
Given a generator configuration and a target maze size, these maze generation algorithms are fully JAX-accelerated. The following table reports rates of maze generation (mazes per second, with standard deviation in parentheses) for different generator configurations and maze sizes.
On an M2 MacBook Air (without using Metal):
Generator | 13x13 mazes | 25x25 mazes | 49x49 mazes |
---|---|---|---|
Tree | 119K (430) | 22.0K (41) | 4.56K (16) |
Tree (alt. Kruskal impl.) | 90.4K (300) | 13.6K (25) | 860 (1.5) |
Edges | 1.55M (17K) | 389K (13K) | 108K (4.7K) |
Blocks | 963K (12K) | 221K (11K) | 106K (1.2K) |
Noise (2x2 cells) | 651K (5.0K) | 141K (970) | 37.4K (130) |
Noise (8x8 cells) | 1.08M (10K) | 366K (4.6K) | 82.9K (920) |
Noise (8x8 cells, 3 octaves) | 244K (1.5K) | 78K (160) | 19.9K (61) |
Open | 13.8M (592K) | 7.22M (340K) | 2.10M (720K) |
TODO: test on a GPU.
Rates collected with `jaxgmg speedtest mazegen-METHOD`, batch size 32, 512 iterations per trial, mean (and standard deviation) calculated after discarding the first trial, since that includes the compilation step.
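The edge and block generators described in the list of maze generation methods decide each edge or cell independently, which is one reason generators of this kind map well onto fixed-shape, vectorised computation. The sketch below uses plain Python loops for readability; the odd-sized wall-grid convention (nodes at odd coordinates), the function names, and the choice to keep a solid outer border are assumptions, not jaxgmg's implementation.

```python
import random
from typing import List


def edge_maze(height: int, width: int, edge_prob: float, seed: int = 0) -> List[List[bool]]:
    """Open each edge between adjacent nodes independently with probability edge_prob."""
    assert height % 2 == 1 and width % 2 == 1
    rng = random.Random(seed)
    walls = [[True] * width for _ in range(height)]
    for r in range(1, height, 2):
        for c in range(1, width, 2):
            walls[r][c] = False  # every node is open
            if c + 2 < width and rng.random() < edge_prob:
                walls[r][c + 1] = False  # edge to the right-hand neighbour
            if r + 2 < height and rng.random() < edge_prob:
                walls[r + 1][c] = False  # edge to the neighbour below
    return walls


def block_maze(height: int, width: int, block_prob: float, seed: int = 0) -> List[List[bool]]:
    """Place a wall in each interior cell independently with probability block_prob."""
    rng = random.Random(seed)
    return [
        [r in (0, height - 1) or c in (0, width - 1) or rng.random() < block_prob
         for c in range(width)]
        for r in range(height)
    ]
```

Neither function guarantees that the open cells form a single connected region; that is a property of these independent-sampling schemes, unlike the tree mazes.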
Some environments (Follow Me, Monster World) require more than a maze layout: they have transition dynamics (NPCs) that depend on optimal navigation within the arbitrary generated layout. This library's solution is to compute all-pairs shortest path information during level generation and cache this navigation information as part of the level struct for quick use during rollouts. Caching the 'solution' to a maze layout at level generation time should be faster overall than running a shortest path algorithm during each step, but we still want level generation to be accelerated, so we need a JAX-accelerated all-pairs shortest path algorithm.
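For reference, the cached quantity can be computed in plain Python with one breadth-first search per source cell, since every move costs one step. This is a sketch under assumptions: the `dist[sr][sc][dr][dc]` index order and the use of `math.inf` for unreachable pairs are illustrative choices, and a JAX implementation would need fixed-shape array operations rather than a queue, so jaxgmg's own algorithm need not work this way.

```python
import math
from collections import deque
from typing import List

Grid = List[List[bool]]  # True marks a wall


def bfs_distances(walls: Grid, source_r: int, source_c: int) -> List[List[float]]:
    """Shortest-path distance (in steps) from one source cell to every cell."""
    h, w = len(walls), len(walls[0])
    dist = [[math.inf] * w for _ in range(h)]
    if walls[source_r][source_c]:
        return dist  # a wall cell is not a valid source
    dist[source_r][source_c] = 0
    queue = deque([(source_r, source_c)])
    while queue:
        r, c = queue.popleft()
        for dr, dc in ((-1, 0), (0, -1), (1, 0), (0, 1)):
            nr, nc = r + dr, c + dc
            if (0 <= nr < h and 0 <= nc < w and not walls[nr][nc]
                    and dist[nr][nc] == math.inf):
                dist[nr][nc] = dist[r][c] + 1  # first visit is the shortest
                queue.append((nr, nc))
    return dist


def all_pairs_distances(walls: Grid) -> List[List[List[List[float]]]]:
    """dist[sr][sc][dr][dc]: steps from source (sr, sc) to destination (dr, dc)."""
    h, w = len(walls), len(walls[0])
    return [[bfs_distances(walls, sr, sc) for sc in range(w)] for sr in range(h)]
```

The table has H × W × H × W entries, so it is computed once per level and then only read during rollouts.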
The module `jaxgmg.procgen.maze_solving` provides a JAX-accelerated all-pairs shortest path algorithm, with methods returning a tensor that encodes the distance or optimal direction to move from any source node to any destination node. The following is a visualisation of the result for a tree maze (the algorithms work for arbitrary maze layouts).
How to read: The position in the 'macromaze' represents the source, and the position in the 'micromaze' (the small copy of the maze drawn at that position) represents the destination. The colour indicates the shortest path distance (in the left maze) or the optimal direction (in the right maze) for reaching the destination position from the source position.
Visualisation generated with `jaxgmg mazesoln distance-direction`.
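Optimal directions can be derived from a cached distance table: from a given source, a move is optimal exactly when it lands on a cell one step closer to the destination. The sketch below assumes a hypothetical action encoding that follows the 'uldr+stay' naming used in the benchmarks (0 = up, 1 = left, 2 = down, 3 = right, 4 = stay) and a distance table indexed as `dist[sr][sc][dr][dc]` with `math.inf` for unreachable pairs; neither is confirmed to match jaxgmg's internal layout.

```python
import math
from typing import List

# Hypothetical encoding following the "uldr+stay" naming (an assumption).
DELTAS = [(-1, 0), (0, -1), (1, 0), (0, 1)]  # up, left, down, right
STAY = 4


def optimal_direction(dist: List[List[List[List[float]]]],
                      sr: int, sc: int, dr: int, dc: int) -> int:
    """Pick a first move on a shortest path from (sr, sc) to (dr, dc).

    Returns STAY when already at the destination or when it is unreachable.
    Ties between equally good moves are broken by the fixed u, l, d, r order.
    """
    here = dist[sr][sc][dr][dc]
    if here == 0 or here == math.inf:
        return STAY
    h, w = len(dist), len(dist[0])
    for action, (dy, dx) in enumerate(DELTAS):
        nr, nc = sr + dy, sc + dx
        # Wall cells have infinite distance everywhere, so they never qualify.
        if 0 <= nr < h and 0 <= nc < w and dist[nr][nc][dr][dc] == here - 1:
            return action
    return STAY  # only reachable if the table is inconsistent
```

With such a table cached in the level, an NPC that must move towards some position needs only a constant number of lookups per step, instead of a fresh shortest-path search.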
The maze solution algorithms are also fully JAX-accelerated. The following table reports rates of maze solving (mazes per second, with standard deviation in parentheses) for different solution types, generators, and maze sizes.
Note: the time taken to generate the mazes is not included, and not all generators are tested because the solution methods run similar operations independent of the maze contents.
On an M2 MacBook Air (without using Metal):
Type of solution | Generator | 13x13 mazes | 25x25 mazes | 49x49 mazes |
---|---|---|---|---|
Distance | Tree | 1.35K (21) | TODO | TODO |
Distance | Edges | 1.34K (17) | TODO | TODO |
Directional distances | Edges | 1.16K (38) | TODO | TODO |
Direction (uldr) | Edges | 1.06K (7.4) | TODO | TODO |
Direction (uldr+stay) | Edges | 1.01K (3.1) | TODO | TODO |
TODO: test on a GPU.
Rates collected with `jaxgmg speedtest mazesoln-METHOD`, batch size 32, 128 iterations per trial, mean (and standard deviation) calculated after discarding the first trial, since that includes the compilation step.
Some environments also support optimal or heuristic policies (using these same maze solving methods as a foundation).
TODO: Document.
The repository includes a PPO implementation that runs in a subset of environments (so far, only "Cheese in the Corner", with support for more environments a work in progress).
Most environment and training hyperparameters can be set from the command line. To see the API for running an experiment, run:
jaxgmg train corner --help
Include the flag `--wandb-log` to log results to wandb.
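For readers unfamiliar with PPO, the core of its policy update is the clipped surrogate objective from Schulman et al. (2017). The per-sample sketch below illustrates that standard formula only; it is not taken from this repository's implementation, and the default `clip_eps=0.2` is a commonly used value rather than jaxgmg's setting.

```python
import math


def ppo_clip_objective(logp_new: float, logp_old: float, advantage: float,
                       clip_eps: float = 0.2) -> float:
    """Per-sample PPO clipped surrogate objective (to be maximised).

    ratio = pi_new(a|s) / pi_old(a|s), computed from log-probabilities.
    Taking the minimum of the unclipped and clipped terms removes any
    incentive to push the ratio outside [1 - clip_eps, 1 + clip_eps].
    """
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return min(ratio * advantage, clipped_ratio * advantage)
```

In practice this is averaged over a batch of (state, action, advantage) samples, and its negation is minimised alongside value and entropy terms.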
TODO: Include some results here
Procedural generation methods:
- Kruskal's algorithm
- Random block mazes
- Perlin noise and fractal noise
Environments (JAX accelerated):
- Cheese in the corner
- Keys and chests
- Monster world
- Cheese on a dish
- Lava land
- Follow the leader (simplified 'cultural transmission')
Environment features:
- Boolean rendering
- 8x8 RGB rendering
- Rendering in other resolutions (1x1, 3x3, 4x4)
RL baselines:
- Train PPO agents in Cheese in the Corner (symbolic environment)
- Qualitative and quantitative demonstration of goal misgeneralisation
Packaging:
- Create this repository
- Format project as an installable Python package
- CLI easily demonstrating core features
- Animation/images of core environments, procedural generation methods
- Speedtests of generation methods, environments
- Speedtests of baselines
- Document speedtests and RL experiments in a report (arXiv)
More procedural generation methods (see notes here):
- Simple room placement?
- BSP?
- Tunnellers?
- Cellular automata?
- Drunkard's walk?
More environments:
- Forest recovery
- Coin at the end (simplified 'coinrun'-style platformer)
- Survivor ('crafter'-style mining/farming grid world)
- Dungeon (a simple roguelike)
- More games inspired by Procgen
More environment features:
- Procgen-style variable-size mazes
- Procgen-style sprite and background diversity
- Partially observable versions
- Gymnax API wrappers and registration
More RL baselines:
- Train PPO agents in other environments (symbolic and pixels)
- Train DQN agents in all environments (symbolic and pixels)