

Adversarial Auto-Encoder for MNIST

An implementation of the adversarial auto-encoder (AAE) for MNIST described in the paper:

Implementation Details

The paper suggests various ways of using an AAE.

Only results on 'Incorporating Label Information in the Adversarial Regularization' are given here.

Target Distributions

Three types of prior distribution are considered.

The following graphs can be obtained with the command:

python test_prior_type.py --prior_type <type>
<table align='center'> <tr align='center'> <td> mixGaussian </td> <td> swiss_roll </td> <td> normal </td> </tr> <tr> <td><img src = 'samples/target_prior_distribution_mixture_of_gaussian.png' height = '250px'> <td><img src = 'samples/target_prior_distribution_swiss_roll.png' height = '250px'> <td><img src = 'samples/target_prior_distribution_normal.png' height = '250px'> </tr> </table>
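The three prior types can be sketched with the standard library alone. This is a minimal illustration under stated assumptions, not the code in `test_prior_type.py`: the helper name `sample_prior`, the ring radius, the per-component spreads and the swiss-roll size are all assumed values, chosen only to give shapes of the same kind as the plots above.

```python
import math
import random


def sample_prior(prior_type, n, n_components=10, rng=None):
    """Draw n 2-D points from one of the three prior types.

    Hypothetical helper: the ring radius (2.0), the component spreads
    (0.5 radial, 0.1 tangential) and the swiss-roll size (radius 3, two
    turns) are illustrative assumptions, not this repository's settings.
    """
    rng = rng if rng is not None else random.Random()
    points = []
    for _ in range(n):
        if prior_type == "mixGaussian":
            # Pick one of n_components elongated Gaussians spaced evenly on a ring.
            k = rng.randrange(n_components)
            angle = 2.0 * math.pi * k / n_components
            radial = 2.0 + rng.gauss(0.0, 0.5)   # wide spread along the radius
            tangential = rng.gauss(0.0, 0.1)     # narrow spread across it
            # Rotate the (radial, tangential) offset into x/y coordinates.
            x = radial * math.cos(angle) - tangential * math.sin(angle)
            y = radial * math.sin(angle) + tangential * math.cos(angle)
        elif prior_type == "swiss_roll":
            # s = sqrt(t) with t uniform in [0, 1); radius and angle both grow with s.
            s = math.sqrt(rng.random())
            r, theta = 3.0 * s, 4.0 * math.pi * s
            x, y = r * math.cos(theta), r * math.sin(theta)
        elif prior_type == "normal":
            # Isotropic standard normal in 2-D.
            x, y = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        else:
            raise ValueError(f"unknown prior_type: {prior_type!r}")
        points.append((x, y))
    return points
```

For example, `sample_prior("swiss_roll", 1000, rng=random.Random(0))` returns 1,000 `(x, y)` tuples that could be scatter-plotted with matplotlib. Passing a seeded `random.Random` keeps the samples reproducible.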


The results below correspond to Figure 4 of the paper, 'Leveraging label information to better regularize the hidden code'.
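In the paper, label information enters the adversarial regularization through a one-hot encoding of the label that is fed to the discriminator alongside the code, so each class is tied to one mode of the prior; the one-hot vector has an extra slot for unlabeled examples, which selects the full mixture. The sketch below illustrates that idea for the mixture-of-Gaussians prior. The helper names `one_hot`, `sample_labeled_mixture` and `discriminator_input`, and the mixture geometry, are hypothetical and are not taken from this repository.

```python
import math
import random

N_CLASSES = 10          # MNIST digits 0-9
UNLABELED = N_CLASSES   # index of the extra one-hot slot for unlabeled examples


def one_hot(label, size=N_CLASSES + 1):
    """One-hot vector of length 11: slots 0-9 for digits, slot 10 for 'unlabeled'."""
    vec = [0.0] * size
    vec[label] = 1.0
    return vec


def sample_labeled_mixture(label, rng, radius=2.0, spread=(0.5, 0.1)):
    """Positive (prior) sample paired with `label` for the discriminator.

    A labeled example uses the mixture component of its own class; an
    unlabeled example draws from the full mixture. The ring radius and the
    (radial, tangential) spreads are illustrative assumptions.
    """
    k = rng.randrange(N_CLASSES) if label == UNLABELED else label
    angle = 2.0 * math.pi * k / N_CLASSES
    radial = radius + rng.gauss(0.0, spread[0])
    tangential = rng.gauss(0.0, spread[1])
    # Rotate the (radial, tangential) offset onto component k of the ring.
    return (radial * math.cos(angle) - tangential * math.sin(angle),
            radial * math.sin(angle) + tangential * math.cos(angle))


def discriminator_input(z, label):
    """Concatenate a 2-D code z with the one-hot label (2 + 11 = 13 values)."""
    return list(z) + one_hot(label)
```

In the positive phase the discriminator would see `discriminator_input(sample_labeled_mixture(y, rng), y)`; in the negative phase the encoder's code for a training image would be paired with `one_hot` of that image's label, or of `UNLABELED` if the image has no label.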

Prior distribution type : a mixture of 10 2-D Gaussians

The following results can be reproduced with the command:

python run_main.py --prior_type mixGaussian
<table align='center'> <tr align='center'> <td> Learned MNIST manifold (20 Epochs) </td> <td> Distribution of labeled data (20 Epochs) </td> </tr> <tr> <td><img src = 'samples/mixGaussian/PMLR_epoch_19.jpg' height = '400px'> <td><img src = 'samples/mixGaussian/PMLR_map_epoch_19.jpg' height = '400px'> </tr> </table>

Prior distribution type : a swiss roll distribution

The following results can be reproduced with the command:

python run_main.py --prior_type swiss_roll
<table align='center'> <tr align='center'> <td> Learned MNIST manifold (20 Epochs) </td> <td> Distribution of labeled data (20 Epochs) </td> </tr> <tr> <td><img src = 'samples/swiss_roll/PMLR_epoch_19.jpg' height = '400px'> <td><img src = 'samples/swiss_roll/PMLR_map_epoch_19.jpg' height = '400px'> </tr> </table>
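For the swiss roll, one plausible way to tie each digit to its own region of the prior is to split the roll lengthwise into 10 contiguous segments, one per class. This is an assumption for illustration, not a description of this repository's sampler; the function name `sample_swiss_roll_segment` and the `max_radius`/`turns` values are hypothetical.

```python
import math
import random


def sample_swiss_roll_segment(label, rng, n_labels=10, max_radius=3.0, turns=2.0):
    """Draw one 2-D point from the segment of a swiss roll assigned to `label`.

    The roll parameter t in [0, 1) is split into n_labels equal slices, and
    class `label` only samples inside [label/n_labels, (label+1)/n_labels).
    max_radius and turns are illustrative assumptions.
    """
    if not 0 <= label < n_labels:
        raise ValueError("label out of range")
    t = (label + rng.random()) / n_labels
    # sqrt spacing keeps point density roughly even along the spiral arm.
    s = math.sqrt(t)
    r = max_radius * s
    theta = 2.0 * math.pi * turns * s
    return r * math.cos(theta), r * math.sin(theta)
```

Because the radius grows monotonically with `t`, each class occupies its own band of the spiral, so the segments never overlap.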

Prior distribution type : a normal distribution (not suggested in the paper)

The following results can be reproduced with the command:

python run_main.py --prior_type normal
<table align='center'> <tr align='center'> <td> Learned MNIST manifold (20 Epochs) </td> <td> Distribution of labeled data (20 Epochs) </td> </tr> <tr> <td><img src = 'samples/normal/PMLR_epoch_19.jpg' height = '400px'> <td><img src = 'samples/normal/PMLR_map_epoch_19.jpg' height = '400px'> </tr> </table>

Dependencies

  1. TensorFlow
  2. Python packages : numpy, scipy, PIL (or Pillow), matplotlib
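A quick way to confirm that the listed packages are installed is to look them up with the standard library before running anything. This is a minimal sketch: the helper name `missing_dependencies` is hypothetical, Pillow is checked under its import name `PIL`, and the check only confirms that each package can be found, not which version is installed.

```python
import importlib.util

# Import names for the dependencies listed above; Pillow is imported as "PIL".
REQUIRED_MODULES = ["tensorflow", "numpy", "scipy", "PIL", "matplotlib"]


def missing_dependencies(modules=REQUIRED_MODULES):
    """Return the modules that cannot be found on the current Python path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_dependencies()
    print("missing:", ", ".join(missing) if missing else "none")
```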

Usage

python run_main.py --prior_type <type>
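Reproducing all three result sets means running this command once per prior type. The driver below is a hypothetical convenience sketch, not part of this repository: the names `build_commands` and `run_all` are assumptions, it passes only the `--prior_type` flag documented here, and it assumes it is started from the directory that contains `run_main.py`.

```python
import subprocess
import sys

# The three values of --prior_type used in the commands above.
PRIOR_TYPES = ["mixGaussian", "swiss_roll", "normal"]


def build_commands(python=sys.executable, prior_types=PRIOR_TYPES):
    """Build one `run_main.py --prior_type <type>` command per prior type."""
    return [[python, "run_main.py", "--prior_type", p] for p in prior_types]


def run_all(dry_run=True):
    """Print each command; execute them in sequence only when dry_run is False."""
    for cmd in build_commands():
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError on a non-zero exit status.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

Calling `run_all(dry_run=False)` would launch the three runs one after another and stop at the first one that fails.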

Arguments

Required :

Optional :


This implementation has been tested with TensorFlow 1.2.1 on Windows 10.