WGAN-TensorFlow
This repository is a TensorFlow implementation of Martin Arjovsky's Wasserstein GAN, arXiv:1701.07875v3.
<p align='center'> <img src="https://user-images.githubusercontent.com/37034031/43870865-b795a83e-9bb4-11e8-8005-461951b3d7b7.png" width=700> </p>
Requirements
- tensorflow 1.9.0
- python 3.5.3
- numpy 1.14.2
- pillow 5.0.0
- scipy 0.19.0
- matplotlib 2.2.2
Applied GAN Structure
- Generator (DCGAN)
- Critic (DCGAN)
Generated Images
- MNIST
- CelebA
Documentation
Download Dataset
The MNIST dataset is downloaded automatically if it is not already present in its data folder. To get CelebA, run the following command and copy the `celebA` dataset into the folder shown in the Directory Hierarchy section.
python download2.py celebA
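The automatic MNIST download described above amounts to checking whether the dataset folder is missing or empty. A minimal sketch of that check, assuming the folder layout from the Directory Hierarchy section; `needs_download` is a hypothetical helper, not a function from this repository:

```python
import os


def needs_download(data_dir):
    """Return True if data_dir does not exist or contains no entries.

    Hypothetical helper illustrating the "download only when the folder
    holds no dataset" rule; the repository's own check may differ.
    """
    if not os.path.isdir(data_dir):
        return True
    return len(os.listdir(data_dir)) == 0


if __name__ == "__main__":
    # Path is an assumption, read relative to src/ from the directory tree.
    print(needs_download(os.path.join("..", "..", "Data", "mnist")))
```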
Directory Hierarchy
.
│ WGAN
│ ├── src
│ │ ├── dataset.py
│ │ ├── download2.py
│ │ ├── main.py
│ │ ├── solver.py
│ │ ├── tensorflow_utils.py
│ │ ├── utils.py
│ │ └── wgan.py
│ Data
│ ├── celebA
│ └── mnist
src: source code of the WGAN
Implementation Details
The implementation uses TensorFlow to train the WGAN. The generator and critic use the same DCGAN networks described in Alec Radford's paper. Unlike a standard GAN, WGAN uses neither a sigmoid function in the last layer of the critic nor a log-likelihood in the cost function. RMSProp is used as the optimizer instead of Adam.
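The differences listed above can be illustrated with a small NumPy sketch: the losses are plain averages of raw critic scores (no sigmoid, no log), and parameters are updated with RMSProp. This is an illustration of the WGAN objective, not code from `wgan.py`; the `decay` and `eps` values are assumptions that mirror the TensorFlow 1.x `RMSPropOptimizer` defaults, while `lr` matches this README's default `learning_rate`.

```python
import numpy as np


def wgan_losses(critic_real, critic_fake):
    """WGAN losses computed from raw, un-squashed critic scores.

    No sigmoid is applied to the critic output and there is no
    log-likelihood term: the scores are averaged directly.
    """
    # The critic maximizes E[D(x)] - E[D(G(z))]; we minimize its negative.
    critic_loss = np.mean(critic_fake) - np.mean(critic_real)
    # The generator tries to raise the critic's score on generated samples.
    gen_loss = -np.mean(critic_fake)
    return critic_loss, gen_loss


def rmsprop_step(param, grad, mean_square, lr=5e-5, decay=0.9, eps=1e-10):
    """One plain RMSProp update (no momentum).

    decay and eps are assumptions (TensorFlow 1.x RMSPropOptimizer defaults).
    """
    mean_square = decay * mean_square + (1.0 - decay) * grad ** 2
    param = param - lr * grad / np.sqrt(mean_square + eps)
    return param, mean_square


if __name__ == "__main__":
    real = np.array([1.0, 2.0])   # critic scores on real images
    fake = np.array([-1.0, 0.0])  # critic scores on generated images
    print(wgan_losses(real, fake))  # critic_loss = -2.0, gen_loss = 0.5
```

Because the critic output is unbounded, the critic loss can become arbitrarily negative; the original paper constrains the critic (via weight clipping) so that this quantity stays meaningful.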
Training WGAN
Use main.py to train a WGAN network. Example usage:
python main.py --is_train=true --dataset=[celebA|mnist]
- gpu_index: GPU index, default: 0
- batch_size: batch size for one feed-forward pass, default: 64
- dataset: dataset name, one of [celebA|mnist], default: celebA
- is_train: training or inference mode, default: False
- learning_rate: initial learning rate, default: 0.00005
- num_critic: number of critic iterations per generator iteration, default: 5
- z_dim: dimension of the z vector, default: 100
- iters: number of iterations, default: 100000
- print_freq: loss printing frequency, default: 50
- save_freq: model saving frequency, default: 10000
- sample_freq: frequency of saving generated sample images, default: 200
- sample_size: number of samples used to check generated image quality, default: 64
- load_model: folder of the saved model that you wish to test (e.g. 20180704-1736), default: None
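To see how the training flags interact, the sketch below counts the updates and events they imply. `plan_training` is a hypothetical helper, and it assumes that one iteration means `num_critic` critic updates followed by one generator update, and that printing, sampling, and saving fire when the 1-based iteration count is a multiple of the corresponding frequency; the actual loop in `solver.py` may count differently.

```python
def plan_training(iters=100000, num_critic=5, print_freq=50,
                  save_freq=10000, sample_freq=200):
    """Count the updates and events implied by the training flags.

    Hypothetical helper; the iteration semantics are assumptions.
    """
    counts = {"critic_updates": 0, "generator_updates": 0,
              "prints": 0, "samples": 0, "saves": 0}
    for it in range(iters):
        counts["critic_updates"] += num_critic  # train the critic first
        counts["generator_updates"] += 1        # then one generator step
        step = it + 1
        if step % print_freq == 0:
            counts["prints"] += 1   # log the losses
        if step % sample_freq == 0:
            counts["samples"] += 1  # save sample_size generated images
        if step % save_freq == 0:
            counts["saves"] += 1    # write a checkpoint
    return counts


if __name__ == "__main__":
    print(plan_training())
```

With the defaults this gives 500000 critic updates, 100000 generator updates, 2000 loss printouts, 500 sample images, and 10 checkpoints.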
Wasserstein Distance During Training
- MNIST
- CelebA
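Assuming the plotted curves follow the paper's Kantorovich-Rubinstein formulation, the quantity tracked is the critic's mini-batch estimate below, i.e. the negative of the critic loss, with $m$ the batch size (`batch_size`). Since the critic $f_w$ is only constrained to be $K$-Lipschitz, the estimate approximates $K \cdot W$ rather than $W$ itself:

```latex
W(\mathbb{P}_r, \mathbb{P}_g)
  = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g}[f(x)],
\qquad
K \cdot W(\mathbb{P}_r, \mathbb{P}_g)
  \approx \frac{1}{m}\sum_{i=1}^{m} f_w\big(x^{(i)}\big) - \frac{1}{m}\sum_{i=1}^{m} f_w\big(G(z^{(i)})\big)
```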
Evaluate WGAN
Use main.py to evaluate a WGAN network. Example usage:
python main.py --is_train=false --load_model=folder/you/wish/to/test/e.g./20180704-1746
Please refer to the above arguments.
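The `load_model` examples (20180704-1736, 20180704-1746) look like `YYYYMMDD-HHMM` timestamps. If that holds, the newest saved model can be picked programmatically. A minimal sketch; `latest_model_folder` is a hypothetical helper, and the root directory where models are saved is an assumption you must supply:

```python
import os
from datetime import datetime

FOLDER_FORMAT = "%Y%m%d-%H%M"  # assumed format, e.g. 20180704-1736


def latest_model_folder(model_root):
    """Return the newest timestamp-named folder under model_root, or None.

    Hypothetical helper; entries whose names do not parse as timestamps
    (and plain files) are ignored.
    """
    stamps = []
    for name in os.listdir(model_root):
        if not os.path.isdir(os.path.join(model_root, name)):
            continue
        try:
            stamps.append((datetime.strptime(name, FOLDER_FORMAT), name))
        except ValueError:
            continue  # not a timestamp folder
    return max(stamps)[1] if stamps else None
```

The returned name can then be passed as `--load_model=<name>` together with `--is_train=false`.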
Citation
@misc{chengbinjin2018wgan,
author = {Cheng-Bin Jin},
title = {WGAN-tensorflow},
year = {2018},
howpublished = {\url{https://github.com/ChengBinJin/WGAN-TensorFlow}},
note = {commit xxxxxxx}
}
Attributions/Thanks
- This project borrowed some code from wiseodd
- Some readme formatting was borrowed from Logan Engstrom
License
Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: sbkim0407@gmail.com). Free for research use, as long as proper attribution is given and this copyright notice is retained.