Deep Convolutional Generative Adversarial Networks in TensorFlow

Descriptions

This is my TensorFlow implementation of Deep Convolutional Generative Adversarial Networks (DCGAN), proposed in the paper Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. The paper's main contribution is a set of tricks that stabilize the training of Generative Adversarial Networks. The proposed architecture is as follows.

<img src="figure/dcgan.png" height="300"/>
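To make the figure concrete, the sketch below traces tensor shapes through a DCGAN-style generator and discriminator. The generator follows Figure 1 of the paper (a 100-dimensional z projected to 4x4x1024, then four stride-2 transposed convolutions up to 64x64x3) together with the paper's guidelines: batchnorm in both networks except the generator output and discriminator input, ReLU in the generator with tanh at its output, leaky ReLU in the discriminator, and strided convolutions instead of pooling. The function names and the discriminator channel widths are illustrative assumptions, not this repository's code, whose layer sizes may differ for 28x28 and 32x32 inputs.

```python
import math

# Illustrative layer tables following the DCGAN paper, not this repo's code.
# Generator widths follow the paper's Figure 1; discriminator widths are assumed.


def deconv_same_output(size, stride):
    """Spatial output size of a transposed convolution with 'SAME' padding."""
    return size * stride


def conv_same_output(size, stride):
    """Spatial output size of a strided convolution with 'SAME' padding."""
    return math.ceil(size / stride)


def generator_shapes(z_dim=100, base=4, channels=(1024, 512, 256, 128), out_channels=3):
    """List (layer, output shape, activation) for a DCGAN-style generator."""
    layers = [("z", (z_dim,), None)]
    size = base
    # Project and reshape z into the first feature map.
    layers.append(("project", (size, size, channels[0]), "batchnorm+relu"))
    # Each stride-2 transposed convolution doubles height and width.
    for i, c in enumerate(channels[1:], start=1):
        size = deconv_same_output(size, 2)
        layers.append((f"deconv{i}", (size, size, c), "batchnorm+relu"))
    # Output layer: no batchnorm, tanh keeps pixels in [-1, 1].
    size = deconv_same_output(size, 2)
    layers.append((f"deconv{len(channels)}", (size, size, out_channels), "tanh"))
    return layers


def discriminator_shapes(image_size=64, in_channels=3, channels=(64, 128, 256, 512)):
    """List (layer, output shape, activation) for a DCGAN-style discriminator."""
    layers = [("image", (image_size, image_size, in_channels), None)]
    size = image_size
    # Strided convolutions replace pooling; each one halves height and width.
    for i, c in enumerate(channels, start=1):
        size = conv_same_output(size, 2)
        # No batchnorm on the first layer, as the paper recommends.
        act = "leaky_relu" if i == 1 else "batchnorm+leaky_relu"
        layers.append((f"conv{i}", (size, size, c), act))
    # The last feature map is flattened into a single logit; sigmoid gives P(real).
    layers.append(("logit", (1,), "sigmoid"))
    return layers


if __name__ == "__main__":
    for name, shape, act in generator_shapes() + discriminator_shapes():
        print(f"{name:10s} {str(shape):16s} {act or ''}")
```

Calling discriminator_shapes(image_size=32) traces a 32x32 SVHN or CIFAR-10 image down to 2x2 after four stride-2 convolutions.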

The implemented model is trained and tested on four publicly available datasets: MNIST, Fashion MNIST, SVHN, and CIFAR-10.
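These datasets come in two native formats: MNIST and Fashion MNIST are 28x28 grayscale, while SVHN and CIFAR-10 are 32x32 RGB. The DCGAN paper's only preprocessing is scaling pixels to [-1, 1], the range of the generator's tanh output. The helpers below are a minimal pure-Python sketch of that scaling and its inverse (useful when saving samples as images); they are not taken from this repository's data loaders, and the dictionary keys simply reuse the --dataset names from the commands under Usage.

```python
# Native image shapes (height, width, channels), keyed by the --dataset names.
DATASET_SHAPES = {
    "MNIST": (28, 28, 1),
    "Fashion": (28, 28, 1),
    "SVHN": (32, 32, 3),
    "CIFAR10": (32, 32, 3),
}


def to_tanh_range(pixels):
    """Map uint8 pixel values in [0, 255] to floats in [-1, 1]."""
    return [p / 127.5 - 1.0 for p in pixels]


def from_tanh_range(values):
    """Invert to_tanh_range: rescale, round, and clip back to [0, 255]."""
    return [min(255, max(0, round((v + 1.0) * 127.5))) for v in values]
```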

Note that this implementation follows only the main architecture of the original paper and differs substantially in implementation details such as the hyperparameters and the optimizer. Some useful training tricks applied in this implementation are listed in the Training tricks section at the end of this README.

This code is still under development and subject to change.

Prerequisites

Usage

Download datasets:

$ python download.py --dataset MNIST Fashion SVHN CIFAR10

Train models with downloaded datasets:

$ python trainer.py --dataset Fashion
$ python trainer.py --dataset MNIST
$ python trainer.py --dataset SVHN
$ python trainer.py --dataset CIFAR10

Test models with saved checkpoints:

$ python evaler.py --dataset Fashion --checkpoint ckpt_dir
$ python evaler.py --dataset MNIST --checkpoint ckpt_dir
$ python evaler.py --dataset SVHN --checkpoint ckpt_dir
$ python evaler.py --dataset CIFAR10 --checkpoint ckpt_dir

The ckpt_dir should be a checkpoint path like: train_dir/default-MNIST_lr_0.0001-20170101-123456/model-1001
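If a run has saved several checkpoints, one way to pick the newest is to read the step numbers from the file names. The sketch below assumes TensorFlow's default saver layout, in which a checkpoint with prefix model-1001 is stored as files such as model-1001.index and model-1001.data-00000-of-00001; the function name latest_checkpoint_prefix is hypothetical. Inside a TensorFlow program, tf.train.latest_checkpoint(run_dir) does the same job by reading the checkpoint state file that the saver maintains.

```python
import os
import re

# Assumes TensorFlow-style checkpoint files such as "model-1001.index".
_CKPT_INDEX = re.compile(r"^model-(\d+)\.index$")


def latest_checkpoint_prefix(run_dir):
    """Return '<run_dir>/model-<step>' for the highest saved step, or None."""
    best_step = None
    for name in os.listdir(run_dir):
        match = _CKPT_INDEX.match(name)
        if match:
            step = int(match.group(1))
            # Compare steps numerically, so model-1001 beats model-201.
            if best_step is None or step > best_step:
                best_step = step
    if best_step is None:
        return None
    return os.path.join(run_dir, f"model-{best_step}")
```

The returned prefix has the same form as the example path above, so it can be passed directly as the --checkpoint argument.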

Train and test your own datasets:

$ mkdir datasets/YOUR_DATASET
$ python trainer.py --dataset YOUR_DATASET
$ python evaler.py --dataset YOUR_DATASET --train_dir dir

The dir should look like: train_dir/default-MNIST_lr_0.0001-20170101-123456/model-1001
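The run directory name appears to encode the run configuration: a leading prefix (default), the dataset name, the learning rate after _lr_, and a date and time stamp. The parser below is a hypothetical helper built only from the example name in this README, so treat the field names, and especially the meaning of the leading prefix, as assumptions. It checks every component of a path, so it also accepts a full checkpoint path ending in model-1001.

```python
import os
import re

# Pattern inferred from the single example name in this README.
RUN_DIR = re.compile(
    r"^(?P<prefix>[^-]+)-(?P<dataset>.+)_lr_(?P<lr>[0-9.]+(?:[eE]-?\d+)?)"
    r"-(?P<date>\d{8})-(?P<time>\d{6})$"
)


def parse_run_dir(path):
    """Return the fields of the last run-directory name found in path, or None."""
    # Walk components from the end so '.../<run>/model-1001' also works.
    for part in reversed(os.path.normpath(path).split(os.sep)):
        match = RUN_DIR.match(part)
        if match:
            fields = match.groupdict()
            fields["lr"] = float(fields["lr"])
            return fields
    return None
```

Because the dataset field matches everything up to _lr_, names containing underscores, such as YOUR_DATASET, are kept whole.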

Results

Fashion MNIST

<img src="figure/result/fashion/samples.png" height="250"/> <img src="figure/result/fashion/training.gif" height="250"/>

MNIST

<img src="figure/result/mnist/samples.png" height="250"/> <img src="figure/result/mnist/training.gif" height="250"/>

SVHN

<img src="figure/result/svhn/samples.png" height="250"/> <img src="figure/result/svhn/training.gif" height="250"/>

CIFAR-10

<img src="figure/result/cifar10/samples.png" height="250"/> <img src="figure/result/cifar10/training.gif" height="250"/>

Training details

MNIST

D_loss_real

<img src="figure/result/mnist/d_loss_real.png" height="200"/>

D_loss_fake

<img src="figure/result/mnist/d_loss_fake.png" height="200"/>

D_loss (total loss)

<img src="figure/result/mnist/d_loss.png" height="200"/>

G_loss

<img src="figure/result/mnist/g_loss.png" height="200"/>

SVHN

D_loss_real

<img src="figure/result/svhn/d_loss_real.png" height="200"/>

D_loss_fake

<img src="figure/result/svhn/d_loss_fake.png" height="200"/>

D_loss (total loss)

<img src="figure/result/svhn/d_loss.png" height="200"/>

G_loss

<img src="figure/result/svhn/g_loss.png" height="200"/>

CIFAR-10

D_loss_real

<img src="figure/result/cifar10/d_loss_real.png" height="200"/>

D_loss_fake

<img src="figure/result/cifar10/d_loss_fake.png" height="200"/>

D_loss (total loss)

<img src="figure/result/cifar10/d_loss.png" height="200"/>

G_loss

<img src="figure/result/cifar10/g_loss.png" height="200"/>
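The four curves above correspond to the usual GAN objectives. The sketch below assumes the standard DCGAN setup with sigmoid cross-entropy on discriminator logits: D_loss_real pushes D toward 1 on real images, D_loss_fake pushes it toward 0 on generated images, D_loss is their sum, and G_loss is the non-saturating generator loss that pushes D toward 1 on generated images. The helper sigmoid_xent uses the numerically stable formula documented for tf.nn.sigmoid_cross_entropy_with_logits. Whether this repository averages per batch, applies label smoothing, or weights the terms differently is an assumption here, not something stated in this README.

```python
import math


def sigmoid_xent(logit, label):
    """Stable sigmoid cross-entropy: max(x, 0) - x * z + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean(values):
    return sum(values) / len(values)


def gan_losses(real_logits, fake_logits):
    """Return (d_loss_real, d_loss_fake, d_loss, g_loss) for one batch of D logits."""
    # D should output 1 on real images.
    d_loss_real = mean([sigmoid_xent(x, 1.0) for x in real_logits])
    # D should output 0 on generated images.
    d_loss_fake = mean([sigmoid_xent(x, 0.0) for x in fake_logits])
    # Total discriminator loss, as plotted under "D_loss (total loss)".
    d_loss = d_loss_real + d_loss_fake
    # Non-saturating generator loss: G wants D to output 1 on its samples.
    g_loss = mean([sigmoid_xent(x, 1.0) for x in fake_logits])
    return d_loss_real, d_loss_fake, d_loss, g_loss
```

When D cannot tell real from fake (all logits 0), each individual term equals log 2 ≈ 0.693 and D_loss equals 2 log 2 ≈ 1.386, a handy reference level when reading the plots.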

Training tricks

Related works

Author

Shao-Hua Sun / @shaohua0116