
Code for "Generative Adversarial Training for Markov Chains" (ICLR 2017 Workshop)

ermongroup/markov-chain-gan


Markov Chain GAN (MGAN)

TensorFlow code for Generative Adversarial Training for Markov Chains (ICLR 2017 Workshop Track).

Work by Jiaming Song, Shengjia Zhao and Stefano Ermon.


Preprocessing

Running the code requires some preprocessing: the data is converted to TensorFlow Records (TFRecord) files to maximize input-pipeline speed, as recommended by TensorFlow.
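TensorFlow's own writer (tf.io.TFRecordWriter in current releases) produces these files, so you normally never handle the format by hand. To show what the preprocessing actually writes to disk, here is a pure-Python sketch of the TFRecord framing: each record is a little-endian uint64 length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. The function names are hypothetical, and in the real files each payload would presumably be a serialized tf.train.Example protobuf (an assumption about this repository's data, not confirmed by it).

```python
import struct

# CRC-32C (Castagnoli), reflected polynomial 0x82F63B78, the checksum TFRecord uses.
_CRC32C_TABLE = []
for _i in range(256):
    _c = _i
    for _ in range(8):
        _c = (_c >> 1) ^ 0x82F63B78 if _c & 1 else _c >> 1
    _CRC32C_TABLE.append(_c)


def crc32c(data: bytes) -> int:
    """Table-driven CRC-32C over a byte string."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord stores a 'masked' CRC: rotate right by 15 bits, then add a constant."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def encode_record(payload: bytes) -> bytes:
    """Frame one payload: length, CRC of length, payload, CRC of payload."""
    length = struct.pack("<Q", len(payload))
    return (length
            + struct.pack("<I", masked_crc(length))
            + payload
            + struct.pack("<I", masked_crc(payload)))


def decode_records(buf: bytes):
    """Yield payloads from concatenated framed records, verifying both CRCs."""
    offset = 0
    while offset < len(buf):
        length_bytes = buf[offset:offset + 8]
        (length,) = struct.unpack("<Q", length_bytes)
        (len_crc,) = struct.unpack("<I", buf[offset + 8:offset + 12])
        if len_crc != masked_crc(length_bytes):
            raise ValueError("corrupt length CRC")
        start = offset + 12
        payload = buf[start:start + length]
        (data_crc,) = struct.unpack("<I", buf[start + length:start + length + 4])
        if data_crc != masked_crc(payload):
            raise ValueError("corrupt data CRC")
        yield payload
        offset = start + length + 4
```

For example, list(decode_records(encode_record(b"hello") + encode_record(b"world"))) recovers both payloads in order; the per-record length and checksums are what let TensorFlow stream large files quickly and detect corruption.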

MNIST

The training data is available here. Download it and place the directory at ~/data/mnist_tfrecords.

(A symlink is the easiest way to do this; alternatively, change the path in models/mnist/__init__.py.)
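If the download lives elsewhere, a symlink keeps the expected path without moving files. A minimal sketch of that step in Python; link_dataset is a hypothetical helper, and the default ~/data root simply mirrors the paths this README expects.

```python
import os


def link_dataset(src_dir, name, data_root="~/data"):
    """Symlink a downloaded dataset directory to <data_root>/<name>.

    Hypothetical helper: creates data_root if needed and refuses to
    overwrite an existing file, directory, or link at the target path.
    """
    data_root = os.path.expanduser(data_root)
    os.makedirs(data_root, exist_ok=True)
    target = os.path.join(data_root, name)
    if os.path.lexists(target):
        raise FileExistsError(target)
    # Link to an absolute path so the symlink works from any directory.
    os.symlink(os.path.abspath(os.path.expanduser(src_dir)), target)
    return target
```

Usage: link_dataset("/path/to/download/mnist_tfrecords", "mnist_tfrecords") creates ~/data/mnist_tfrecords; the same call with "celeba_tfrecords" covers the CelebA directory.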

CelebA

The training data is available here. Download it and place the directory at ~/data/celeba_tfrecords.


Running Experiments

python mgan.py [data] [model] -b [B] -m [M] -d [critic iterations] --gpus [gpus]

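where B is the number of transition steps from noise to data, M is the number of transition steps from data to data, [critic iterations] is the number of critic (discriminator) iterations, and [gpus] sets the CUDA_VISIBLE_DEVICES environment variable.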
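The command-line interface above can be mirrored with argparse to make the argument roles explicit. This is a hypothetical re-creation using only the documented arguments; mgan.py's real parser may differ in defaults, types, and help text.

```python
import argparse
import os


def build_parser():
    # Hypothetical sketch of the documented interface, not mgan.py's own parser.
    p = argparse.ArgumentParser(description="Markov Chain GAN (sketch)")
    p.add_argument("data", help="dataset folder under models/, e.g. mnist or celeba")
    p.add_argument("model", help="model name, e.g. mlp, conv, conv_res")
    p.add_argument("-b", type=int, help="transition steps from noise to data (B)")
    p.add_argument("-m", type=int, help="transition steps from data to data (M)")
    p.add_argument("-d", type=int, help="critic iterations")
    p.add_argument("--gpus", default=None, help="value for CUDA_VISIBLE_DEVICES")
    return p


def parse_and_select_gpus(argv):
    args = build_parser().parse_args(argv)
    if args.gpus is not None:
        # Only takes effect if set before TensorFlow initializes CUDA.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    return args
```

For instance, parse_and_select_gpus(["mnist", "mlp", "-b", "4", "-m", "3", "-d", "7", "--gpus", "0"]) yields b=4, m=3, d=7 and restricts the process to GPU 0.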

MNIST

python mgan.py mnist mlp -b 4 -m 3 -d 7 --gpus [gpus]

CelebA

Without shortcut connections:

python mgan.py celeba conv -b 4 -m 3 -d 7 --gpus [gpus]

With shortcut connections (the chain transitions between samples much more slowly):

python mgan.py celeba conv_res -b 4 -m 3 -d 7 --gpus [gpus]
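To make -b and -m concrete: both count applications of the same transition operator, starting either from noise (B steps) or from a real data point (M steps). A minimal NumPy sketch of unrolling such a chain, assuming a generic stochastic transition(x); unroll_chain and toy_transition are hypothetical stand-ins for the learned TransitionFunction, and how the repository feeds these states to the discriminator is not shown.

```python
import numpy as np


def unroll_chain(transition, x0, steps):
    """Apply `transition` `steps` times, returning every state.

    Returns an array of shape (steps + 1, *x0.shape) whose first entry is x0.
    """
    states = [x0]
    x = x0
    for _ in range(steps):
        x = transition(x)
        states.append(x)
    return np.stack(states)


rng = np.random.default_rng(0)


def toy_transition(x):
    # Hypothetical stochastic transition: contract toward 1.0 and add noise.
    return 0.5 * x + 0.5 + 0.1 * rng.standard_normal(x.shape)


B, M = 4, 3
z = rng.standard_normal((16, 2))                          # batch of noise samples
noise_chain = unroll_chain(toy_transition, z, B)          # noise -> data, B steps
data_batch = np.ones((16, 2))                             # stand-in for real data
data_chain = unroll_chain(toy_transition, data_batch, M)  # data -> data, M steps
```

Keeping every intermediate state, rather than only the last one, matches how the figures display chains step by step.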

Custom Experiments

It is easy to define your own problem and run experiments.

  • Create a folder named after your dataset (called data below) under the models directory, and define data_sampler and noise_sampler in its __init__.py.
  • Create a file model.py under the models/data directory, and define the following:
    • class TransitionFunction(TransitionBase) (Generator)
    • class Discriminator(DiscriminatorBase) (Discriminator)
    • def visualizer(model, name) (If you need to generate figures)
    • epoch_size and logging_freq
  • That's it!

Figures

Each row shows a single chain sampled for 50 time steps.
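A figure in this layout can be assembled by tiling an array of chain samples so that each chain becomes one row of frames. A minimal NumPy sketch, assuming grayscale samples stored as (num_chains, num_steps, height, width); chains_to_grid is hypothetical and not the repository's visualizer.

```python
import numpy as np


def chains_to_grid(samples):
    """Tile chain samples into one image: one row per chain, one column per step.

    samples: array of shape (num_chains, num_steps, height, width).
    Returns an array of shape (num_chains * height, num_steps * width).
    """
    n, t, h, w = samples.shape
    # (n, t, h, w) -> (n, h, t, w) so each chain's frames sit side by side.
    return samples.transpose(0, 2, 1, 3).reshape(n * h, t * w)
```

For 10 MNIST chains of 50 steps at 28x28 pixels, the result is a 280 x 1400 image that can be saved with any image library.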

MNIST

[Figure: MNIST, MLP model]

CelebA

Without shortcut connections: [Figure: CelebA, 1-layer conv]

With shortcut connections: [Figure: CelebA, 1-layer conv with shortcuts]

Related Projects

a-nice-mc: adversarial training for efficient MCMC kernels, which is based on this project.

Citation

If you use this code for your research, please cite our paper:

@article{song2017generative,
  title={Generative Adversarial Training for Markov Chains},
  author={Song, Jiaming and Zhao, Shengjia and Ermon, Stefano},
  journal={ICLR 2017 (Workshop Track)},
  year={2017}
}

Contact

tsong@cs.stanford.edu

Code for the Pairwise Discriminator is not available at the moment; I will add it when I have time.
