The RAS (Reference-based Adversarial Sampling) algorithm enables adversarial learning for sampling from general unnormalized distributions, with demonstrations on constrained-domain sampling and soft Q-learning. This repository contains the source code to reproduce the results presented in the paper Adversarial Learning of a Sampler Based on an Unnormalized Distribution (AISTATS 2019):
```
@inproceedings{Li_RAS_2019_AISTATS,
  title={Adversarial Learning of a Sampler Based on an Unnormalized Distribution},
  author={Chunyuan Li and Ke Bai and Jianqiao Li and Guoyin Wang and Changyou Chen and Lawrence Carin},
  booktitle={AISTATS},
  year={2019}
}
```
The goal is to learn a neural sampler q to approximate a target distribution p, when only p's unnormalized form u (for RAS) or empirical samples p' (for GANs) are available.
| | RAS | GAN |
|---|---|---|
| Illustration | | |
| Method | We propose the "reference" p_r to bridge neural samples q and the unnormalized form u, making the evaluation of both the F_1 and F_2 terms feasible. | Directly matching neural samples q to empirical samples p' |
| Setup | Learning from the unnormalized form u | Learning from empirical samples p' |
| Generator | | |
| Discriminator | q vs p_r | q vs p' |
| Application to reinforcement learning | Learning to take optimal actions based on Q-functions | GAIL: learning to take optimal actions based on expert sample trajectories (a.k.a. imitation learning) |
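The "Discriminator" row above can be illustrated with a toy example. The sketch below is not the repository's implementation (all variable names are ours): it trains a 1-D logistic discriminator to distinguish samples from a candidate sampler q from samples from a reference p_r, which is the adversarial signal RAS uses in place of GAN's q-vs-p' comparison.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins: samples from the neural sampler q and the reference p_r.
q_samples = rng.normal(0.0, 1.0, size=1000)   # "q" samples
r_samples = rng.normal(4.0, 1.0, size=1000)   # "p_r" samples

x = np.concatenate([q_samples, r_samples])
y = np.concatenate([np.zeros(1000), np.ones(1000)])  # 0 = q, 1 = p_r

# A 1-D logistic discriminator trained by full-batch gradient descent.
w, b = 0.0, 0.0
lr = 0.1
for _ in range(500):
    p = 1.0 / (1.0 + np.exp(-(w * x + b)))   # predicted P(label = p_r | x)
    grad_w = np.mean((p - y) * x)            # cross-entropy gradients
    grad_b = np.mean(p - y)
    w -= lr * grad_w
    b -= lr * grad_b

acc = np.mean((p > 0.5) == (y == 1))
print(round(acc, 2))
```

In the full algorithm the sampler is then updated adversarially against this discriminator; here only the discriminator side is sketched.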
- In many applications (e.g., soft Q-learning), only u is known, and we are interested in drawing samples from it efficiently.
- The choice of p_r affects learning, so it should be chosen carefully.
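The effect of the reference choice can be seen even in a simpler setting than RAS. The sketch below (an illustration of ours, not the RAS objective) uses self-normalized importance sampling against an unnormalized density u on [0, 1]: a Beta reference matched to the support yields a much larger effective sample size than a Gaussian reference whose mass mostly falls outside it.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Unnormalized target on [0, 1]: u(x) ∝ x^2 (1 - x)^4, a Beta(3, 5) shape.
def u(x):
    return np.where((x >= 0) & (x <= 1), x**2 * (1 - x)**4, 0.0)

def snis_mean(ref_sample, ref_pdf, n=5000):
    """Self-normalized importance-sampling estimate of E_p[x], p ∝ u."""
    x = ref_sample(n)
    w = u(x) / ref_pdf(x)
    w = w / w.sum()
    return float((w * x).sum()), float((w**2).sum())

# Reference 1: Beta(2, 2) -- supported exactly on [0, 1].
beta_est, beta_w2 = snis_mean(lambda n: rng.beta(2, 2, n),
                              lambda x: stats.beta.pdf(x, 2, 2))

# Reference 2: Gaussian N(0.5, 1) -- most samples land outside [0, 1]
# and receive zero weight.
norm_est, norm_w2 = snis_mean(lambda n: rng.normal(0.5, 1.0, n),
                              lambda x: stats.norm.pdf(x, 0.5, 1.0))

print(beta_est, norm_est)      # both approach E[Beta(3,5)] = 3/8
print(1/beta_w2, 1/norm_w2)    # effective sample sizes: Beta reference wins
```

The same intuition carries over to RAS: a reference p_r that overlaps well with the target makes the learning signal much more informative.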
There are three steps to use this codebase to reproduce the results in the paper.
This code is based on Python 2.7, with TensorFlow==1.7.0 as the main dependency. Additional dependencies for running the experiments are: `numpy`, `cPickle`, `scipy`, `math`, and `gensim`.
We consider the following environments: `Hopper`, `Half-cheetah`, `Ant`, `Walker`, `Swimmer`, and `Humanoid`. All soft Q-learning code is in `sql`:
To run:

```
python mujoco_all_sql.py --env Hopper
```
It takes the following options (among others) as arguments:

- `--env` specifies the MuJoCo/rllab environment; default `Hopper`.
- `--log_dir` is the path where the training log is saved.
- For other arguments, please refer to the GitHub repo soft-q-learning.
Other related hyper-parameter settings are located in `sql/examples/mujoco_all_sql.py`. The default reference distribution is the Beta distribution; the reference-distribution option supports "beta" (Beta distribution) and "norm" (Gaussian distribution).
Swimmer (rllab) | Humanoid (rllab) | Hopper-v1 | Half-cheetah-v1 | Ant-v1 | Walker-v1 |
---|---|---|---|---|---|
Note: Humanoid has a higher-dimensional action space, which makes adversarial learning unstable; more future work is needed to make Humanoid run better.
To show that RAS can draw samples when the support is bounded, we apply it to sample from distributions with support [c1, c2]. Please see the code at `constrained_sampling`.
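A standard way to guarantee that a sampler's outputs land in a bounded interval [c1, c2] is to push unconstrained noise through a squashing function. The sketch below is illustrative only (the function name and constants are ours, not the repository's code):

```python
import numpy as np

def bounded_sampler(n, c1, c2, seed=0):
    """Map unconstrained Gaussian noise into (c1, c2) via a scaled sigmoid.

    A neural sampler can apply the same transform as its final layer so
    that every generated sample respects the support constraint by
    construction.
    """
    rng = np.random.default_rng(seed)
    z = rng.normal(size=n)              # unconstrained noise
    s = 1.0 / (1.0 + np.exp(-z))        # squash into (0, 1)
    return c1 + (c2 - c1) * s           # rescale into (c1, c2)

x = bounded_sampler(10000, c1=-2.0, c2=3.0)
print(x.min(), x.max())                 # always within (-2, 3)
```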
RAS: Beta ref. | RAS: Gaussian ref. | SVGD | Amortized SVGD |
---|---|---|---|
Please note that RAS with a Gaussian reference recovers AVB-AC (Adversarial Variational Bayes with Adaptive Contrast).
An entropy term H(x) is approximated to stabilize adversarial training. As examples, we consider regularizing the following GAN variants: GAN, SN-GAN, D2GAN, and Unrolled-GAN. All entropy-regularization code is in `entropy`:
To run:

```
python run_test.py --model gan_cc
```
It takes the following options (among others) as arguments:

- `--model` specifies the GAN variant to which the entropy regularizer is applied. It supports [`gan`, `d2gan`, `ALLgan`, `SNgan`]; default `gan`. To apply the entropy regularizer, change the `--model` argument to [`gan_cc`, `d2gan_cc`, `ALLgan_cc`, `SNgan_cc`].
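The entropy term H(x) must be approximated because only samples, not densities, are available from the generator. The repository's own approximation lives in `entropy`; purely as a generic, self-contained illustration of estimating entropy from samples alone, the classic Kozachenko–Leonenko nearest-neighbor estimator (not necessarily the regularizer used in this code) looks like:

```python
import numpy as np
from scipy.special import digamma

def knn_entropy_1d(x):
    """Kozachenko-Leonenko (k=1) nearest-neighbor entropy estimate, 1-D."""
    x = np.sort(np.asarray(x, dtype=float))
    n = len(x)
    # Distance from each point to its nearest neighbor (left or right).
    left = np.empty(n)
    right = np.empty(n)
    left[0] = np.inf
    left[1:] = x[1:] - x[:-1]
    right[-1] = np.inf
    right[:-1] = x[1:] - x[:-1]
    rho = np.minimum(left, right)
    return digamma(n) - digamma(1) + np.mean(np.log(2.0 * rho))

rng = np.random.default_rng(0)
h = knn_entropy_1d(rng.normal(size=20000))
print(h)   # close to 0.5 * log(2*pi*e) ~ 1.419 for a standard Gaussian
```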
Entropy regularizer on 8-GMM toy dataset | SN-GAN | SN-GAN + Entropy |
---|---|---|
Jupyter notebooks in the `plots` folder are used to reproduce the figure results in the paper.

Note that, without modification, the notebooks use our extracted results (already copied into them), and the scripts will output the figures in the paper. If you have run your own training and wish to plot your results, you will need to organize them in the same format.
Please drop us (Chunyuan, Ke, Jianqiao or Guoyin) a line if you have any questions.