

Generative Latent Optimization in Tensorflow

As part of the implementation series of Joseph Lim's group at USC, our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers and publicly share them with concise reports. Please visit our group GitHub site for other projects.

This project is implemented by Shao-Hua Sun, and the code was reviewed by Jiayuan Mao before being published.

Descriptions

This project is a TensorFlow implementation of Generative Latent Optimization (GLO), proposed in the paper Optimizing the Latent Space of Generative Networks. GLO is a generative model that shares many desirable properties of GANs, including modeling data distributions, generating realistic samples, and learning an interpretable latent space. More importantly, it does not rely on adversarial training and therefore avoids its unstable training dynamics.

GLO learns to map learnable latent vectors to samples in a target dataset by minimizing a reconstruction loss. During training, the generator parameters and the corresponding latent vectors are optimized alternately. Once training has converged, the model can generate novel samples from latent vectors sampled from the latent distribution. The GLO framework is illustrated below.

<img src="figure/glo.png" height="300"/>
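The alternating procedure can be illustrated with a toy, self-contained NumPy sketch. It is not this repository's TensorFlow code. The generator here is a hypothetical linear map G(z) = zW rather than a convolutional network, and the loss is a squared L2 reconstruction error. Following the GLO paper, each latent vector is projected back onto the unit sphere after its update. The function names and learning rates are illustrative assumptions.

```python
import numpy as np


def project_to_unit_sphere(z):
    """Rescale each latent vector (row of z) to unit L2 norm."""
    norms = np.linalg.norm(z, axis=1, keepdims=True)
    return z / np.maximum(norms, 1e-12)


def train_glo(x, latent_dim=2, steps=500, lr_g=0.05, lr_z=0.05, seed=0):
    """Toy GLO: fit a hypothetical linear generator G(z) = z @ w together with
    one learnable latent vector per sample by minimizing squared L2 error."""
    rng = np.random.default_rng(seed)
    n, d = x.shape
    z = project_to_unit_sphere(rng.normal(size=(n, latent_dim)))
    w = 0.1 * rng.normal(size=(latent_dim, d))
    losses = []
    for _ in range(steps):
        # 1) Generator step: gradient of mean ||z_i w - x_i||^2 w.r.t. w,
        #    with the latent vectors held fixed.
        residual = z @ w - x
        losses.append(float(np.mean(np.sum(residual ** 2, axis=1))))
        w -= lr_g * (2.0 * z.T @ residual / n)
        # 2) Latent step: gradient of each sample's own loss w.r.t. its z_i,
        #    with the freshly updated generator held fixed, then re-projection.
        residual = z @ w - x
        z = project_to_unit_sphere(z - lr_z * (2.0 * residual @ w.T))
    return w, z, losses
```

Replacing the linear map with a deep generator and the hand-written gradients with automatic differentiation gives the general shape of GLO training. The actual model's loss and update schedule may differ.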

The implemented model is trained and tested on three publicly available datasets: MNIST, SVHN, and CIFAR-10. This implementation evaluates the model's ability to reconstruct samples, generate new samples, and interpolate in the latent space.

Note that this implementation follows only the main idea of the original paper. It differs substantially in implementation details such as model architectures, hyperparameters, and the optimizer. In particular, the procedure used to update the latent vectors is largely based on my own conjectures.

*This code is still being developed and subject to change.

Prerequisites

Usage

Datasets

Download datasets with the specified settings. For example:

$ python download.py --datasets MNIST --distribution PCA --dimension 10
$ python download.py --datasets SVHN --distribution Uniform --dimension 25
$ python download.py --datasets CIFAR10 --distribution Gaussian --dimension 35

Note that --distribution specifies the initial distribution of the latent vectors, and --dimension specifies the dimensionality of each latent vector.
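As a rough illustration of what the three --distribution options could mean, the NumPy sketch below initializes one latent vector of size --dimension per image. It uses either a PCA projection of the data or uniform or Gaussian noise. The exact procedure in download.py is not reproduced here. The function name, the uniform range [-1, 1], and the final projection onto the unit sphere (taken from the GLO paper) are all assumptions.

```python
import numpy as np


def init_latents(x, distribution="Gaussian", dimension=10, seed=0):
    """Hypothetical initializer: one latent vector per sample in x,
    returned as an array of shape (num_samples, dimension)."""
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    flat = x.reshape(n, -1).astype(np.float64)
    if distribution == "PCA":
        if dimension > min(flat.shape):
            raise ValueError("dimension exceeds the number of principal components")
        centered = flat - flat.mean(axis=0)
        # Rows of vt are principal directions, sorted by decreasing singular value.
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        z = centered @ vt[:dimension].T
    elif distribution == "Uniform":
        z = rng.uniform(-1.0, 1.0, size=(n, dimension))
    elif distribution == "Gaussian":
        z = rng.normal(0.0, 1.0, size=(n, dimension))
    else:
        raise ValueError(f"unknown distribution: {distribution}")
    # Assumption (from the GLO paper): place every latent vector on the unit sphere.
    norms = np.linalg.norm(z, axis=1, keepdims=True)
    return z / np.maximum(norms, 1e-12)
```

The PCA option gives images that look alike similar starting codes. The random options rely entirely on training to organize the latent space.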

Train the models

Train models with downloaded datasets. For example:

$ python trainer.py --dataset MNIST --alpha 5 --dump_result --batch_size 32
$ python trainer.py --dataset SVHN --alpha 10 --lr_weight_decay
$ python trainer.py --dataset CIFAR10 --alpha 10 --learning_rate 1e-5

Note that --alpha sets the weight of the latent-vector update at each iteration.
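The latent-update rule in this repository is, by the author's account, largely conjectural. The following is therefore only a hypothetical sketch of how a weight such as alpha could enter it: a gradient step on the latent vectors scaled by alpha, followed by a projection back onto the unit sphere as in the GLO paper. The function name and the exact form of the step are assumptions.

```python
import numpy as np


def update_latents(z, grad_z, alpha):
    """Hypothetical latent update: step against the reconstruction-loss
    gradient with weight alpha, then re-project each row to unit L2 norm."""
    z_new = z - alpha * grad_z
    norms = np.linalg.norm(z_new, axis=1, keepdims=True)
    return z_new / np.maximum(norms, 1e-12)


# Same gradient, two weights (the values used in the training commands above).
z = np.array([[1.0, 0.0]])
grad = np.array([[0.0, -1.0]])
for alpha in (5, 10):
    moved = np.linalg.norm(update_latents(z, grad, alpha) - z)
    print(alpha, round(float(moved), 3))
```

With the same gradient, a larger alpha moves each latent vector further along the sphere. The latents therefore adapt faster, at the usual risk of overshooting that comes with larger step sizes.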

Test the models

Test models with saved checkpoints:

$ python evaler.py --dataset MNIST --checkpoint ckpt_dir --prefix mnist --reconstruct --generate
$ python evaler.py --dataset SVHN --checkpoint ckpt_dir --prefix svhn --interpolate
$ python evaler.py --dataset CIFAR10 --checkpoint ckpt_dir --prefix cifar

There are three task options: reconstruction (--reconstruct), sample generation (--generate), and sample interpolation (--interpolate).
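On the latent side, the three tasks differ only in which latent vectors are fed to the generator:

* reconstruction uses each sample's learned latent vector;
* generation uses newly sampled latent vectors;
* interpolation uses points between two latent vectors.

The NumPy sketch below covers generation and interpolation under stated assumptions. Following the GLO paper, new latents are drawn from a full-covariance Gaussian fitted to the learned latents and then projected to the unit sphere. Interpolation is linear. Whether evaler.py does exactly this is not confirmed, and the function names are hypothetical.

```python
import numpy as np


def sample_latents(learned_z, num_samples, seed=0):
    """Fit a full-covariance Gaussian to the learned latent vectors, draw
    new ones from it, and project them onto the unit sphere."""
    rng = np.random.default_rng(seed)
    mean = learned_z.mean(axis=0)
    cov = np.cov(learned_z, rowvar=False)
    z = rng.multivariate_normal(mean, cov, size=num_samples)
    return z / np.maximum(np.linalg.norm(z, axis=1, keepdims=True), 1e-12)


def interpolate_latents(z_start, z_end, num_steps):
    """Linearly interpolate between two latent vectors; the result has
    num_steps rows and includes both endpoints."""
    t = np.linspace(0.0, 1.0, num_steps)[:, None]
    return (1.0 - t) * z_start + t * z_end
```

Passing each row of these arrays through the trained generator would produce the generated samples and the interpolation strips, respectively.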

The ckpt_dir should look like: train_dir/default-MNIST_lr_0.0001-20170101-123456/model-1001
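As the example path shows, the value passed to --checkpoint is a checkpoint prefix (a run directory plus model-<step>), not a directory. This assumes the TensorFlow 1.x tf.train.Saver file layout, in which each saved step produces files such as model-1001.index and model-1001.meta. Under that assumption, a small standard-library helper can pick the most recent prefix in a run directory. The helper name is hypothetical. TensorFlow's own tf.train.latest_checkpoint(run_dir) serves the same purpose by reading the checkpoint file that Saver writes.

```python
import os
import re

# Assumes Saver-style file names: model-<step>.index, model-<step>.meta, ...
_CKPT_RE = re.compile(r"^model-(\d+)\.index$")


def latest_checkpoint(run_dir):
    """Return the prefix of the highest-step checkpoint in run_dir
    (e.g. 'run_dir/model-1001'), or None if no checkpoint is found."""
    steps = []
    for name in os.listdir(run_dir):
        match = _CKPT_RE.match(name)
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return os.path.join(run_dir, f"model-{max(steps)}")
```

Steps are compared as integers, so model-1001 correctly outranks model-201, which a plain string sort would get wrong.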

Train and test on your own datasets:

$ mkdir datasets/YOUR_DATASET
$ python evaler.py --dataset YOUR_DATASET

Results

SVHN

<img src="figure/result/reconstruct_svhn.png" height="250"/> <img src="figure/result/generate_svhn.png" height="250"/> <img src="figure/result/interpolate_svhn.png" height="250"/> <img src="figure/result/training_svhn.gif" height="250"/>

MNIST

<img src="figure/result/reconstruct_mnist.png" height="250"/> <img src="figure/result/generate_mnist.png" height="250"/> <img src="figure/result/interpolate_mnist.png" height="250"/> <img src="figure/result/training_mnist.gif" height="250"/>

CIFAR-10

<img src="figure/result/reconstruct_cifar.png" height="250"/> <img src="figure/result/generate_cifar.png" height="250"/> <img src="figure/result/interpolate_cifar.png" height="250"/> <img src="figure/result/training_cifar.gif" height="250"/>

Training details

SVHN

<img src="figure/result/loss_svhn.png" height="200"/>

MNIST

<img src="figure/result/loss_mnist.png" height="200"/>

CIFAR-10

<img src="figure/result/loss_cifar.png" height="200"/>

Related works

Author

Shao-Hua Sun / @shaohua0116 @ Joseph Lim's research lab @ USC