Markov Chain GAN (MGAN)
TensorFlow code for Generative Adversarial Training for Markov Chains (ICLR 2017 Workshop Track).
Work by Jiaming Song, Shengjia Zhao and Stefano Ermon.
Preprocessing
Running the code requires some preprocessing: the data is first converted into TensorFlow Records (TFRecord) files, which TensorFlow recommends for maximizing input-pipeline speed.
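For readers unfamiliar with the format, a TFRecord file is just a sequence of byte strings, each framed as a little-endian uint64 length, a masked CRC-32C of that length, the payload, and a masked CRC-32C of the payload. The pure-Python sketch below writes and reads that framing so the layout is concrete. It is an illustration only: it stores raw bytes rather than serialized tf.train.Example protobufs, the helper names are hypothetical, and in practice you would use TensorFlow's own tf.io.TFRecordWriter (tf.python_io.TFRecordWriter in TF 1.x).

```python
import struct

# CRC-32C (Castagnoli), reflected polynomial 0x82F63B78, table-driven.
_POLY = 0x82F63B78
_TABLE = []
for i in range(256):
    crc = i
    for _ in range(8):
        crc = (crc >> 1) ^ _POLY if crc & 1 else crc >> 1
    _TABLE.append(crc)


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for b in data:
        crc = _TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    # TFRecord masking: rotate the CRC right by 15 bits, then add a constant.
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_records(path, payloads):
    """Write each byte string in `payloads` as one framed record."""
    with open(path, "wb") as f:
        for data in payloads:
            length = struct.pack("<Q", len(data))
            f.write(length)
            f.write(struct.pack("<I", masked_crc(length)))
            f.write(data)
            f.write(struct.pack("<I", masked_crc(data)))


def read_records(path):
    """Read all records back, verifying both checksums of every record."""
    records = []
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                break
            (length,) = struct.unpack("<Q", header)
            (len_crc,) = struct.unpack("<I", f.read(4))
            if len_crc != masked_crc(header):
                raise ValueError("corrupt length header")
            data = f.read(length)
            (data_crc,) = struct.unpack("<I", f.read(4))
            if data_crc != masked_crc(data):
                raise ValueError("corrupt record payload")
            records.append(data)
    return records
```

Because this is the standard framing, tf.data.TFRecordDataset should be able to read back the raw payloads written by write_records; the repository's actual conversion step is what decides how images are serialized inside each record.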
MNIST
The data used for training is here.
Download it and place the directory at ~/data/mnist_tfrecords. (You can do this with a symlink, or change the path in models/mnist/__init__.py.)
CelebA
The data used for training is here.
Download it and place the directory at ~/data/celeba_tfrecords.
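If you keep datasets elsewhere, a symlink into ~/data is the least invasive option. The helper below is a hypothetical convenience (link_dataset is not part of this repository) that creates ~/data/<name> as a symlink to a downloaded directory and refuses to overwrite anything already there.

```python
from pathlib import Path


def link_dataset(src, name, data_root="~/data"):
    """Symlink a downloaded dataset directory to data_root/name.

    Hypothetical helper: the default layout mirrors the paths used in this
    README (~/data/mnist_tfrecords, ~/data/celeba_tfrecords).
    """
    src = Path(src).expanduser().resolve()
    if not src.is_dir():
        raise FileNotFoundError(f"dataset directory not found: {src}")
    dst = Path(data_root).expanduser() / name
    dst.parent.mkdir(parents=True, exist_ok=True)
    # is_symlink() also catches dangling links, which exists() would miss.
    if dst.is_symlink() or dst.exists():
        raise FileExistsError(f"{dst} already exists")
    dst.symlink_to(src, target_is_directory=True)
    return dst


# Example (paths are illustrative):
# link_dataset("~/Downloads/mnist_tfrecords", "mnist_tfrecords")
# link_dataset("~/Downloads/celeba_tfrecords", "celeba_tfrecords")
```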
Running Experiments
python mgan.py [data] [model] -b [B] -m [M] -d [critic iterations] --gpus [gpus]
where B sets the number of steps from noise to data, M sets the number of steps from data to data, [critic iterations] sets the number of critic (discriminator) iterations, and [gpus] sets the CUDA_VISIBLE_DEVICES environment variable.
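The usage line maps naturally onto argparse. The parser below is a hypothetical re-creation of the documented interface, not mgan.py's actual source: option names follow the usage line, while the types, help strings, and the decision to make -b, -m and -d required are assumptions. It also shows why --gpus is handled early: CUDA_VISIBLE_DEVICES only takes effect if it is set before TensorFlow initializes CUDA.

```python
import argparse
import os


def parse_args(argv=None):
    # Hypothetical sketch of the documented CLI; mgan.py may differ.
    parser = argparse.ArgumentParser(description="Markov Chain GAN (sketch)")
    parser.add_argument("data", help="dataset name, e.g. mnist or celeba")
    parser.add_argument("model", help="model name, e.g. mlp, conv, conv_res")
    parser.add_argument("-b", type=int, required=True, help="steps from noise to data")
    parser.add_argument("-m", type=int, required=True, help="steps from data to data")
    parser.add_argument("-d", type=int, required=True, help="critic iterations")
    parser.add_argument("--gpus", type=str, default=None,
                        help="value for CUDA_VISIBLE_DEVICES, e.g. '0' or '0,1'")
    args = parser.parse_args(argv)
    if args.gpus is not None:
        # Must happen before TensorFlow touches the GPU to have any effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    return args


# Equivalent of: python mgan.py mnist mlp -b 4 -m 3 -d 7 --gpus 0
# args = parse_args(["mnist", "mlp", "-b", "4", "-m", "3", "-d", "7", "--gpus", "0"])
```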
MNIST
python mgan.py mnist mlp -b 4 -m 3 -d 7 --gpus [gpus]
CelebA
Without shortcut connections:
python mgan.py celeba conv -b 4 -m 3 -d 7 --gpus [gpus]
With shortcut connections (the chain changes much more slowly from step to step):
python mgan.py celeba conv_res -b 4 -m 3 -d 7 --gpus [gpus]
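To make the roles of -b and -m concrete: the same transition operator is rolled out B times starting from noise and M times starting from real data. The NumPy sketch below illustrates only those two rollout lengths; the toy transition (a noisy contraction standing in for the learned generator) and the batch shapes are assumptions, and it does not reproduce how MGAN feeds these samples to the discriminator or weights them in its objective.

```python
import numpy as np


def toy_transition(x, rng):
    # Hypothetical stand-in for the learned transition T(x' | x):
    # a noisy contraction toward 1.0.
    return 0.5 * x + 0.5 + 0.1 * rng.standard_normal(x.shape)


def rollout(x0, transition, steps, rng):
    """Apply `transition` `steps` times and return the final state."""
    x = x0
    for _ in range(steps):
        x = transition(x, rng)
    return x


rng = np.random.default_rng(0)
B, M, batch, dim = 4, 3, 8, 2            # -b 4 -m 3, as in the commands above
noise = rng.standard_normal((batch, dim))  # chains started from noise
data = np.ones((batch, dim))               # placeholder for a real data batch
from_noise = rollout(noise, toy_transition, B, rng)  # noise -> data, B steps
from_data = rollout(data, toy_transition, M, rng)    # data -> data, M steps
```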
Custom Experiments
It is easy to define your own problem and run experiments.
- Create a folder data under the models directory, and define data_sampler and noise_sampler in its __init__.py.
- Create a file model.py under the models/data directory, and define the following:
  - class TransitionFunction(TransitionBase) (generator)
  - class Discriminator(DiscriminatorBase) (discriminator)
  - def visualizer(model, name) (if you need to generate figures)
  - epoch_size and logging_freq
- That's it!
Figures
Each row is from a single chain, where we sample for 50 time steps.
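The figure layout can be reproduced with a small tiling routine: run every chain for a fixed number of steps and place chain i in row block i and its state after t + 1 steps in column block t. This is a sketch under assumptions (grayscale images, a generic transition callable, the hypothetical name chain_grid); it is not the repository's visualizer, whose starting distribution and layout details are not documented here.

```python
import numpy as np


def chain_grid(x0, transition, steps, rng):
    """Run each chain in x0 for `steps` steps and tile the states into a grid.

    x0 has shape (num_chains, height, width). The result has shape
    (num_chains * height, steps * width): row block i is chain i, and
    column block t is its state after t + 1 transitions.
    """
    n, h, w = x0.shape
    frames = []
    x = x0
    for _ in range(steps):
        x = transition(x, rng)
        frames.append(x)
    stacked = np.stack(frames)                      # (steps, n, h, w)
    tiled = stacked.transpose(1, 2, 0, 3)           # (n, h, steps, w)
    return tiled.reshape(n * h, steps * w)


# Usage with a toy transition on 28x28 "images" (illustrative only).
rng = np.random.default_rng(0)
x0 = rng.random((4, 28, 28))
toy = lambda x, r: np.clip(0.9 * x + 0.1 * r.random(x.shape), 0.0, 1.0)
grid = chain_grid(x0, toy, 50, rng)  # 4 chains x 50 steps -> shape (112, 1400)
```

The resulting array can be saved as an image with any plotting library, e.g. matplotlib's plt.imsave with a gray colormap.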
MNIST
CelebA
Without shortcut connections:
With shortcut connections:
Related Projects
a-nice-mc: adversarial training for efficient MCMC kernels, which is based on this project.
Citation
If you use this code for your research, please cite our paper:
@article{song2017generative,
title={Generative Adversarial Training for Markov Chains},
author={Song, Jiaming and Zhao, Shengjia and Ermon, Stefano},
journal={ICLR 2017 (Workshop Track)},
year={2017}
}
Contact
Code for the Pairwise Discriminator is not available at the moment; I will add it when I have time.