Semi-supervised learning GAN in Tensorflow

As part of the implementation series of Joseph Lim's group at USC, our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers and share them publicly with concise reports. Please visit our group GitHub site for other projects.

This project is implemented by Shao-Hua Sun, and the code was reviewed by Jiayuan Mao before being published.

Descriptions

This project is a Tensorflow implementation of the semi-supervised learning GAN proposed in the paper Improved Techniques for Training GANs. The intuition is to exploit the samples produced by the GAN generator to improve generalization and thereby boost performance on image classification tasks.

In short, the main idea is to train a single network that plays two roles: a classifier performing the image classification task, and a discriminator trained to distinguish samples produced by a generator from real data. More specifically, the discriminator/classifier takes an image as input and classifies it into n+1 classes, where n is the number of classes in the classification task. Real samples are classified into the first n classes, and generated samples are classified into the (n+1)-th class, as shown in the figure below.

<img src="figure/ssgan.png" height="300"/>

The loss of this multi-task learning framework can be decomposed into the supervised loss

<img src="figure/s_loss.png" height="25"/>,

and the GAN loss of the discriminator

<img src="figure/gan_loss.png" height="25"/>.

During the training phase, we jointly minimize the total loss obtained by summing the two losses.
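
As a hedged illustration of how this combination might be wired up, the sketch below computes losses whose names mirror the curves in the Training details section. The function name ssgan_losses and the particular generator loss are assumptions for illustration, not necessarily what this repository implements (the paper itself also suggests feature matching for the generator).

```python
import tensorflow as tf

def ssgan_losses(logits_real, logits_fake, labels, num_classes):
    """Sketch of the combined SSGAN objective.

    logits_real: (n+1)-way logits for real images; labels in [0, n).
    logits_fake: (n+1)-way logits for generated images.
    """
    fake_class = num_classes  # 0-based index of the (n+1)-th class
    batch_size = tf.shape(logits_fake)[0]

    # Supervised loss: cross-entropy on the labeled real samples.
    s_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits_real))

    # GAN loss of the discriminator: real samples should avoid the
    # fake class; generated samples should be assigned to it.
    p_fake_real = tf.nn.softmax(logits_real)[:, fake_class]
    d_loss_real = -tf.reduce_mean(tf.log(1.0 - p_fake_real + 1e-8))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.fill([batch_size], fake_class), logits=logits_fake))

    # Total discriminator loss: the two losses combined.
    d_loss = s_loss + d_loss_real + d_loss_fake

    # Generator loss (illustrative non-saturating choice): push
    # generated samples out of the fake class.
    p_fake_fake = tf.nn.softmax(logits_fake)[:, fake_class]
    g_loss = -tf.reduce_mean(tf.log(1.0 - p_fake_fake + 1e-8))
    return s_loss, d_loss_real, d_loss_fake, d_loss, g_loss
```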

The implemented model is trained and tested on three publicly available datasets: MNIST, SVHN, and CIFAR-10.

Note that this implementation follows only the main idea of the original paper and differs considerably in implementation details such as model architecture, hyperparameters, and the choice of optimizer. Some useful training tricks applied in this implementation are listed at the end of this README.

*This code is still being developed and is subject to change.*

Prerequisites

Usage

Download datasets with:

$ python download.py --dataset MNIST SVHN CIFAR10

Train models with downloaded datasets:

$ python trainer.py --dataset MNIST
$ python trainer.py --dataset SVHN
$ python trainer.py --dataset CIFAR10

Test models with saved checkpoints:

$ python evaler.py --dataset MNIST --checkpoint ckpt_dir
$ python evaler.py --dataset SVHN --checkpoint ckpt_dir
$ python evaler.py --dataset CIFAR10 --checkpoint ckpt_dir

The ckpt_dir should look like: train_dir/default-MNIST_lr_0.0001_update_G5_D1-20170101-194957/model-1001

Train and test your own datasets:

$ mkdir datasets/YOUR_DATASET
$ python trainer.py --dataset YOUR_DATASET
$ python evaler.py --dataset YOUR_DATASET

Results

MNIST

<img src="figure/result/mnist/samples.png" height="250"/> <img src="figure/result/mnist/training.gif" height="250"/>

SVHN

<img src="figure/result/svhn/samples.png" height="250"/> <img src="figure/result/svhn/training.gif" height="250"/>

CIFAR-10

<img src="figure/result/cifar10/samples.png" height="250"/> <img src="figure/result/cifar10/training.gif" height="250"/>

Training details

MNIST

<img src="figure/result/mnist/s_loss.png" height="200"/>

D_loss_real

<img src="figure/result/mnist/d_loss_real.png" height="200"/>

D_loss_fake

<img src="figure/result/mnist/d_loss_fake.png" height="200"/>

D_loss (total loss)

<img src="figure/result/mnist/d_loss.png" height="200"/>

G_loss

<img src="figure/result/mnist/g_loss.png" height="200"/> <img src="figure/result/mnist/accuracy.png" height="200"/>

SVHN

<img src="figure/result/svhn/s_loss.png" height="200"/>

D_loss_real

<img src="figure/result/svhn/d_loss_real.png" height="200"/>

D_loss_fake

<img src="figure/result/svhn/d_loss_fake.png" height="200"/>

D_loss (total loss)

<img src="figure/result/svhn/d_loss.png" height="200"/>

G_loss

<img src="figure/result/svhn/g_loss.png" height="200"/> <img src="figure/result/svhn/accuracy.png" height="200"/>

CIFAR-10

<img src="figure/result/cifar10/s_loss.png" height="200"/>

D_loss_real

<img src="figure/result/cifar10/d_loss_real.png" height="200"/>

D_loss_fake

<img src="figure/result/cifar10/d_loss_fake.png" height="200"/>

D_loss (total loss)

<img src="figure/result/cifar10/d_loss.png" height="200"/>

G_loss

<img src="figure/result/cifar10/g_loss.png" height="200"/> <img src="figure/result/cifar10/accuracy.png" height="200"/>

Training tricks

Related works

Acknowledgement

Part of the code is from an unpublished project with Jongwook Choi.