TensorFlow code for Generative Adversarial Training for Markov Chains (ICLR 2017 Workshop Track).
Work by Jiaming Song, Shengjia Zhao and Stefano Ermon.
Running the code requires some preprocessing: we convert the data to TFRecords files to maximize speed (as suggested by TensorFlow).
The MNIST data used for training is here.
Download and place the directory in ~/data/mnist_tfrecords.
(This can easily be done with a symlink, or you can change the path in the file models/mnist/__init__.py.)
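The symlink option mentioned above can be sketched in Python as follows. This is a minimal sketch, assuming the downloaded directory lives somewhere else on disk; the helper name `link_dataset` and the source path in the usage comment are hypothetical, and only `~/data/mnist_tfrecords` comes from the instructions.

```python
import os
from pathlib import Path


def link_dataset(downloaded_dir, expected_dir):
    """Symlink expected_dir -> downloaded_dir (hypothetical helper)."""
    src = Path(downloaded_dir).expanduser().resolve()
    dst = Path(expected_dir).expanduser()
    if not src.is_dir():
        raise FileNotFoundError(f"no such directory: {src}")
    # Create the parent (e.g. ~/data) if it is missing.
    dst.parent.mkdir(parents=True, exist_ok=True)
    if dst.is_symlink() or dst.exists():
        # Leave an existing directory or link untouched rather than overwrite it.
        return dst
    os.symlink(src, dst, target_is_directory=True)
    return dst


# Example (the source path is a placeholder, not a real location):
# link_dataset("/path/to/downloaded/mnist_tfrecords", "~/data/mnist_tfrecords")
```

The same call with `~/data/celeba_tfrecords` as the second argument would cover the CelebA directory.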
The CelebA data used for training is here.
Download and place the directory in ~/data/celeba_tfrecords.
python mgan.py [data] [model] -b [B] -m [M] -d [critic iterations] --gpus [gpus]
where B defines the steps from noise to data, M defines the steps from data to data, and [gpus] defines the CUDA_VISIBLE_DEVICES environment variable.
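To make the command-line options concrete, here is a minimal sketch of how such arguments could be parsed and how the `--gpus` value could be forwarded to `CUDA_VISIBLE_DEVICES`. It is not the actual contents of mgan.py: the function names `parse_args` and `select_gpus`, the help strings, and all default values are assumptions made for illustration.

```python
import argparse
import os


def parse_args(argv=None):
    parser = argparse.ArgumentParser(prog="mgan.py")
    parser.add_argument("data", help="dataset name, e.g. mnist or celeba")
    parser.add_argument("model", help="model name, e.g. mlp, conv or conv_res")
    # Defaults below are placeholders chosen for this sketch.
    parser.add_argument("-b", type=int, default=4, help="steps from noise to data")
    parser.add_argument("-m", type=int, default=3, help="steps from data to data")
    parser.add_argument("-d", type=int, default=7, help="critic iterations")
    parser.add_argument("--gpus", type=str, default="", help="value for CUDA_VISIBLE_DEVICES")
    return parser.parse_args(argv)


def select_gpus(gpus):
    # Must run before the deep learning framework initializes CUDA,
    # otherwise the setting has no effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus


args = parse_args(["mnist", "mlp", "-b", "4", "-m", "3", "-d", "7", "--gpus", "0"])
select_gpus(args.gpus)
```

With `--gpus 0,1` the process would see only the first two devices; an empty string hides all GPUs.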
python mgan.py mnist mlp -b 4 -m 3 -d 7 --gpus [gpus]
Without shortcut connections:
python mgan.py celeba conv -b 4 -m 3 -d 7 --gpus [gpus]
With shortcut connections (you will observe a much slower transition):
python mgan.py celeba conv_res -b 4 -m 3 -d 7 --gpus [gpus]
It is easy to define your own problem and run experiments.
- Create a folder data under the models directory, and define data_sampler and noise_sampler in __init__.py.
- Create a file model.py under the models/data directory, and define the following:
  - class TransitionFunction(TransitionBase) (Generator)
  - class Discriminator(DiscriminatorBase) (Discriminator)
  - def visualizer(model, name) (If you need to generate figures)
  - epoch_size and logging_freq
- That's it!
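The layout described in the steps above might look roughly like the skeleton below. This is only a sketch: the two base classes are empty stand-ins so the snippet runs on its own (the project's real `TransitionBase` and `DiscriminatorBase` interfaces are not reproduced here), and the sampler signatures, the toy Gaussian data, and the values of `epoch_size` and `logging_freq` are all assumptions.

```python
import random


# Stand-ins so this sketch runs on its own. In the project these base classes
# already exist; their real interfaces are not reproduced here.
class TransitionBase:
    pass


class DiscriminatorBase:
    pass


# ---- models/data/__init__.py ----
def data_sampler(batch_size=64):
    """Return a batch of real samples (toy 1-D Gaussian; signature assumed)."""
    return [random.gauss(2.0, 0.5) for _ in range(batch_size)]


def noise_sampler(batch_size=64):
    """Return a batch of initial noise states (signature assumed)."""
    return [random.gauss(0.0, 1.0) for _ in range(batch_size)]


# ---- models/data/model.py ----
class TransitionFunction(TransitionBase):
    """Generator: the Markov chain transition operator."""


class Discriminator(DiscriminatorBase):
    """Discriminator: scores samples as real or generated."""


def visualizer(model, name):
    """Optional: generate figures for `model` (meaning of `name` assumed to be a label)."""


epoch_size = 1000    # placeholder value
logging_freq = 100   # placeholder value
```

The existing models/mnist directory is the best reference for what the base classes actually expect.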
Each row is from a single chain, where we sample for 50 time steps.
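The row-per-chain layout can be illustrated with a toy example. The sketch below assumes a scalar state and uses a simple autoregressive step as a stand-in for the learned transition operator; `transition`, `sample_chain`, and the numeric constants other than the 50 time steps are hypothetical and are not taken from the project.

```python
import random


def transition(x, rng):
    # Toy stand-in for a learned transition operator: an AR(1) step.
    return 0.9 * x + rng.gauss(0.0, 0.1)


def sample_chain(x0, num_steps, rng):
    """Return the num_steps states visited after x0, in order."""
    states = []
    x = x0
    for _ in range(num_steps):
        x = transition(x, rng)
        states.append(x)
    return states


rng = random.Random(0)
num_rows, num_steps = 4, 50
# One row per chain, each started from an independent noise draw.
grid = [sample_chain(rng.gauss(0.0, 1.0), num_steps, rng) for _ in range(num_rows)]
```

Here `grid[i][t]` is the state of chain `i` after `t + 1` transitions, which mirrors how each row of the figure follows one chain through time.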
a-nice-mc: adversarial training for efficient MCMC kernels, which is based on this project.
If you use this code for your research, please cite our paper:
@article{song2017generative,
title={Generative Adversarial Training for Markov Chains},
author={Song, Jiaming and Zhao, Shengjia and Ermon, Stefano},
journal={ICLR 2017 (Workshop Track)},
year={2017}
}
Code for the Pairwise Discriminator is not available at this moment; I will add that when I have the time.