Surjection layers for density estimation with normalizing flows
Surjectors is a lightweight library for density estimation using inference and generative surjective normalizing flows, i.e., flows that can reduce or increase dimensionality. Surjectors makes use of
- Haiku's module system for neural networks,
- Distrax for probability distributions and some base bijectors,
- Optax for gradient-based optimization,
- JAX for autodiff and XLA computation.
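As a rough illustration of what a dimension-reducing (inference) surjection computes, the following pure-JAX sketch scores a vector by slicing it. The first `n_keep` coordinates `z` are evaluated under a standard normal base distribution, and the remaining coordinates `y` under a Gaussian conditioned on `z`, so that log p(x) = log p(z) + log p(y | z). The function `slice_log_prob` and its affine conditioner are hypothetical stand-ins (Surjectors uses a learned decoder, for example an MLP); this is a sketch of the idea, not the library's implementation.

```python
from jax import numpy as jnp, random as jr
from jax.scipy.stats import norm


def slice_log_prob(x, n_keep, cond_params):
    """Hypothetical slice surjection: log p(x) = log p(z) + log p(y | z)."""
    # keep the first n_keep coordinates, slice off the rest
    z, y = x[..., :n_keep], x[..., n_keep:]
    # hypothetical affine conditioner (stand-in for a learned MLP decoder)
    W, b = cond_params
    means, log_scales = jnp.split(z @ W + b, 2, axis=-1)
    # log p(z): standard normal base distribution on the kept coordinates
    base_lp = jnp.sum(norm.logpdf(z), axis=-1)
    # log p(y | z): conditional density of the sliced-off coordinates
    cond_lp = jnp.sum(
        norm.logpdf(y, loc=means, scale=jnp.exp(log_scales)), axis=-1
    )
    return base_lp + cond_lp


n_dim, n_keep = 10, 5
key_x, key_w = jr.split(jr.PRNGKey(0))
x = jr.normal(key_x, (3, n_dim))
W = 0.1 * jr.normal(key_w, (n_keep, 2 * (n_dim - n_keep)))
b = jnp.zeros(2 * (n_dim - n_keep))
lp = slice_log_prob(x, n_keep, (W, b))
print(lp.shape)  # (3,)
```

With `W` and `b` set to zero the conditional becomes a standard normal, and the sketch reduces to a standard normal log-density over all ten coordinates.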
You can, for instance, construct a simple normalizing flow like this:
```python
import distrax
import haiku as hk
from jax import numpy as jnp, random as jr
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp


def decoder_fn(n_dim):
    # conditional distribution of the n_dim sliced-off coordinates
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))

    return _fn


@hk.without_apply_rng
@hk.transform
def flow(x):
    # 5-dimensional standard normal base distribution
    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1
    )
    # slice surjection (10 -> 5 dimensions) and an LU-decomposed linear bijection
    transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
    pushforward = TransformedDistribution(base_distribution, transform)
    return pushforward.log_prob(x)


# a batch containing one 10-dimensional data point
x = jr.normal(jr.PRNGKey(1), (1, 10))
params = flow.init(jr.PRNGKey(2), x)
lp = flow.apply(params, x)
```
More self-contained examples can be found in examples.
Documentation can be found here.
Make sure to have a working JAX installation. Depending on whether you want to use CPU/GPU/TPU, please follow these instructions.
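A quick way to check that JAX works and which backend it picked up is the short snippet below (a convenience suggestion, not part of the official installation instructions):

```python
import jax
from jax import numpy as jnp

# installed JAX version and the default backend, e.g. "cpu" or "gpu"
print(jax.__version__)
print(jax.default_backend())
print(jax.devices())

# a tiny jit-compiled computation to confirm XLA compilation works
y = jax.jit(lambda v: jnp.sum(v ** 2))(jnp.arange(3.0))
print(y)  # 5.0
```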
To install the package from PyPI, call:

```shell
pip install surjectors
```
To install the latest GitHub <RELEASE>, call the following on the command line:

```shell
pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
```
Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".
In order to contribute:

- Clone Surjectors and install `hatch` via `pip install hatch`,
- create a new branch locally via `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`,
- implement your contribution and ideally a test case,
- test it by calling `hatch run test` on the (Unix) command line,
- submit a PR 🙂
If you find our work relevant to your research, please consider citing:
@article{dirmeier2024surjectors,
author = {Simon Dirmeier},
title = {Surjectors: surjection layers for density estimation with normalizing flows},
year = {2024},
journal = {Journal of Open Source Software},
publisher = {The Open Journal},
volume = {9},
number = {94},
pages = {6188},
doi = {10.21105/joss.06188}
}
Simon Dirmeier sfyrbnd @ pm me