

License: MIT


Mimicry is a lightweight PyTorch library aimed at the reproducibility of GAN research.

Comparing GANs is often difficult: mild differences in implementations and evaluation methodologies can result in large differences in performance. Mimicry aims to resolve this by providing: (a) standardized implementations of popular GANs that closely reproduce reported scores; (b) baseline scores of GANs trained and evaluated under the same conditions; and (c) a framework that lets researchers focus on implementing GANs without rewriting most of the GAN training boilerplate, with support for multiple GAN evaluation metrics.

We provide a model zoo and a set of baselines to benchmark different GANs of the same model size, trained under the same conditions and evaluated with multiple metrics. To ensure reproducibility, we verify the scores of our implemented models against scores reported in the literature.


Installation

The library can be installed with:

pip install git+https://github.com/kwotsin/mimicry.git

See also the setup information for more details.

Example Usage

Training a popular GAN such as SNGAN, with results that reproduce reported scores, is as simple as:

import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan

# Data handling objects
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, num_workers=4)

# Define models and optimizers
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))

# Start training
trainer = mmc.training.Trainer(
    netD=netD,
    netG=netG,
    optD=optD,
    optG=optG,
    n_dis=5,
    num_steps=100000,
    lr_decay='linear',
    dataloader=dataloader,
    log_dir='./log/example',
    device=device)
trainer.train()

# Evaluate FID
mmc.metrics.evaluate(
    metric='fid',
    log_dir='./log/example',
    netG=netG,
    dataset='cifar10',
    num_real_samples=50000,
    num_fake_samples=50000,
    evaluate_step=100000,
    device=device)

Example outputs:

>>> INFO: [Epoch 1/127][Global Step: 10/100000]
| D(G(z)): 0.5941
| D(x): 0.9303
| errD: 1.4052
| errG: -0.6671
| lr_D: 0.0002
| lr_G: 0.0002
| (0.4550 sec/idx)
^CINFO: Saving checkpoints from keyboard interrupt...
INFO: Training Ended
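The epoch total of 127 in the log header can be cross-checked with simple arithmetic. CIFAR-10's training split has 50,000 images, so a DataLoader with batch_size=64 and the default drop_last=False yields ceil(50000 / 64) = 782 batches per epoch, and 100,000 // 782 = 127. This suggests the logged total is num_steps divided by the number of batches per epoch; that relationship is inferred from the numbers above, not taken from Mimicry's source, and the helper nominal_epochs is hypothetical.

```python
import math


def nominal_epochs(num_steps, dataset_size, batch_size, drop_last=False):
    """Hypothetical helper: return (batches per epoch, num_steps // batches).

    Mirrors len(DataLoader): floor division when drop_last=True,
    ceiling division otherwise. Assumes the logged epoch total is
    num_steps divided by the number of batches per epoch.
    """
    if drop_last:
        batches = dataset_size // batch_size
    else:
        batches = math.ceil(dataset_size / batch_size)
    return batches, num_steps // batches


# CIFAR-10 train split, batch size 64, 100K training steps
print(nominal_epochs(100_000, 50_000, 64))  # (782, 127)
```

With drop_last=True the last partial batch is discarded, giving 781 batches and a nominal total of 128 epochs instead.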

TensorBoard visualizations:

tensorboard --logdir=./log/example

See further details in the example script, as well as a detailed tutorial on implementing a custom GAN from scratch.

Further Guides

<div id="baselines"></div>

Baselines | Model Zoo

For a fair comparison, we train all models under the same training conditions for each dataset, and implement each with ResNet backbones of the same architectural capacity. We train our models using the Adam optimizer with the commonly used hyperparameters (β<sub>1</sub>, β<sub>2</sub>) = (0.0, 0.9). n<sub>dis</sub> denotes the number of discriminator update steps per generator update step, and n<sub>iter</sub> denotes the number of training iterations.
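The interplay of n<sub>dis</sub> and n<sub>iter</sub> can be sketched as a plain-Python event schedule, assuming n<sub>iter</sub> counts generator updates. The sketch also assumes that the n<sub>dis</sub> discriminator updates of each iteration run before its single generator update; the exact ordering inside Mimicry's Trainer may differ, and update_schedule is a hypothetical name.

```python
from collections import Counter


def update_schedule(n_iter, n_dis):
    """Yield ('D', step) / ('G', step) events for an alternating GAN schedule.

    Assumption: for each of the n_iter iterations, n_dis discriminator
    updates are performed before the single generator update.
    """
    for step in range(n_iter):
        for _ in range(n_dis):
            yield ("D", step)
        yield ("G", step)


# Total update counts for a run with n_dis = 5 and n_iter = 100K
counts = Counter(kind for kind, _ in update_schedule(100_000, 5))
print(counts["D"], counts["G"])  # 500000 100000
```

Under this reading, the settings with n<sub>dis</sub> = 5 perform five times as many discriminator updates as generator updates, while the n<sub>dis</sub> = 2 settings perform twice as many.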

Models

| Abbrev. | Name | Type* |
|---|---|---|
| DCGAN | Deep Convolutional GAN | Unconditional |
| WGAN-GP | Wasserstein GAN with Gradient Penalty | Unconditional |
| SNGAN | Spectral Normalization GAN | Unconditional |
| cGAN-PD | Conditional GAN with Projection Discriminator | Conditional |
| SSGAN | Self-supervised GAN | Unconditional |
| InfoMax-GAN | Infomax-GAN | Unconditional |

*Conditional GAN scores are only reported for labelled datasets.

Metrics

| Metric | Method |
|---|---|
| Inception Score (IS)* | 50K samples, 10 splits |
| Fréchet Inception Distance (FID) | 50K real/generated samples |
| Kernel Inception Distance (KID) | 50K real/generated samples, averaged over 10 splits |

*Inception Score can be a poor indicator of GAN performance, as it does not capture diversity within a class and is not domain agnostic. This is why datasets with only a single class (e.g. CelebA and LSUN-Bedroom) obtain low scores under this metric.
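The Inception Score is exp(E<sub>x</sub>[KL(p(y|x) ‖ p(y))]), where p(y|x) is the classifier's predicted class distribution for a sample and p(y) is the marginal class distribution over samples. The following is a minimal pure-Python sketch of computing it over splits. It takes precomputed probability vectors in place of Inception-v3 outputs, and both the function name inception_score and the use of the population standard deviation across splits are illustrative assumptions rather than Mimicry's API.

```python
import math


def inception_score(probs, n_splits=10):
    """Inception Score from per-sample class probabilities p(y|x).

    probs: list of probability vectors (each summing to 1), standing in
    for classifier softmax outputs. Requires len(probs) >= n_splits.
    Returns (mean, population std) of the per-split scores.
    """
    n = len(probs)
    split_scores = []
    for k in range(n_splits):
        part = probs[k * n // n_splits:(k + 1) * n // n_splits]
        num_classes = len(part[0])
        # Marginal class distribution p(y) within this split
        p_y = [sum(p[c] for p in part) / len(part) for c in range(num_classes)]
        # Mean KL(p(y|x) || p(y)); classes with p(y|x) = 0 contribute 0
        kl = sum(
            sum(pc * math.log(pc / p_y[c]) for c, pc in enumerate(p) if pc > 0)
            for p in part
        ) / len(part)
        split_scores.append(math.exp(kl))
    mean = sum(split_scores) / n_splits
    std = math.sqrt(sum((s - mean) ** 2 for s in split_scores) / n_splits)
    return mean, std


# Every sample confidently assigned to the same class
single_class = [[1.0, 0.0, 0.0]] * 100
# Confident predictions spread evenly over 4 classes
four_classes = [[1.0 if c == i % 4 else 0.0 for c in range(4)] for i in range(400)]
print(inception_score(single_class))
print(inception_score(four_classes))
```

When every sample is confidently assigned to the same class, p(y|x) equals p(y) and the score takes its minimum value of 1 regardless of image quality; confident predictions spread evenly over k classes give k. This illustrates, in an extreme form, why single-class datasets tend to score low.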

Datasets

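| Dataset | Split | Resolution |
|---|---|---|
| CIFAR-10 | Train | 32 x 32 |
| CIFAR-100 | Train | 32 x 32 |
| ImageNet | Train | 32 x 32 |
| STL-10 | Unlabeled | 48 x 48 |
| CelebA | All | 64 x 64 |
| CelebA | All | 128 x 128 |
| LSUN-Bedroom | Train | 128 x 128 |
| ImageNet | Train | 128 x 128 |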

CelebA

Paper | Dataset

Training Parameters

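| Resolution | Batch Size | Learning Rate | β<sub>1</sub> | β<sub>2</sub> | Decay Policy | n<sub>dis</sub> | n<sub>iter</sub> |
|---|---|---|---|---|---|---|---|
| 128 x 128 | 64 | 2e-4 | 0.0 | 0.9 | None | 2 | 100K |
| 64 x 64 | 64 | 2e-4 | 0.0 | 0.9 | Linear | 5 | 100K |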
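The Decay Policy column distinguishes a constant learning rate (None) from linear decay. The following is a minimal sketch of one common form of linear decay, assumed here to go from the base rate of 2e-4 at step 0 to zero at step n<sub>iter</sub>; whether Mimicry applies exactly this formula, and at every step, is an assumption, and the helper learning_rate is hypothetical.

```python
def learning_rate(step, num_steps, base_lr=2e-4, policy="linear"):
    """Hypothetical helper: learning rate at `step` under a decay policy.

    policy=None keeps the rate constant; policy="linear" decays it
    linearly from base_lr at step 0 to 0 at step num_steps.
    """
    if policy is None:
        return base_lr
    if policy == "linear":
        return base_lr * (1.0 - step / num_steps)
    raise ValueError(f"unknown decay policy: {policy!r}")


for step in (0, 10, 50_000, 100_000):
    print(step, learning_rate(step, 100_000))
```

At step 10 of 100K, this gives 1.9998e-4, which prints as 0.0002 at four decimal places, consistent with the lr_D and lr_G values in the example training output.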

Results

| Resolution | Model | IS | FID | KID | Checkpoint | Code |
|---|---|---|---|---|---|---|
| 128 x 128 | SNGAN | 2.72 ± 0.01 | 12.93 ± 0.04 | 0.0076 ± 0.0001 | netG.pth | sngan_128.py |
| 128 x 128 | SSGAN | 2.63 ± 0.01 | 15.18 ± 0.10 | 0.0101 ± 0.0001 | netG.pth | ssgan_128.py |
| 128 x 128 | InfoMax-GAN | 2.84 ± 0.01 | 9.50 ± 0.04 | 0.0063 ± 0.0001 | netG.pth | infomax_gan_128.py |
| 64 x 64 | SNGAN | 2.68 ± 0.01 | 5.71 ± 0.02 | 0.0033 ± 0.0001 | netG.pth | sngan_64.py |
| 64 x 64 | SSGAN | 2.67 ± 0.01 | 6.03 ± 0.04 | 0.0036 ± 0.0001 | netG.pth | ssgan_64.py |
| 64 x 64 | InfoMax-GAN | 2.68 ± 0.01 | 5.71 ± 0.06 | 0.0033 ± 0.0001 | netG.pth | infomax_gan_64.py |

LSUN-Bedroom

Paper | Dataset

Training Parameters

| Resolution | Batch Size | Learning Rate | β<sub>1</sub> | β<sub>2</sub> | Decay Policy | n<sub>dis</sub> | n<sub>iter</sub> |
|---|---|---|---|---|---|---|---|
| 128 x 128 | 64 | 2e-4 | 0.0 | 0.9 | Linear | 2 | 100K |

Results

| Resolution | Model | IS | FID | KID | Checkpoint | Code |
|---|---|---|---|---|---|---|
| 128 x 128 | SNGAN | 2.30 ± 0.01 | 25.87 ± 0.03 | 0.0141 ± 0.0001 | netG.pth | sngan_128.py |
| 128 x 128 | SSGAN | 2.12 ± 0.01 | 12.02 ± 0.07 | 0.0077 ± 0.0001 | netG.pth | ssgan_128.py |
| 128 x 128 | InfoMax-GAN | 2.22 ± 0.01 | 12.13 ± 0.16 | 0.0080 ± 0.0001 | netG.pth | infomax_gan_128.py |

STL-10

Paper | Dataset

Training Parameters

| Resolution | Batch Size | Learning Rate | β<sub>1</sub> | β<sub>2</sub> | Decay Policy | n<sub>dis</sub> | n<sub>iter</sub> |
|---|---|---|---|---|---|---|---|
| 48 x 48 | 64 | 2e-4 | 0.0 | 0.9 | Linear | 5 | 100K |

Results

| Resolution | Model | IS | FID | KID | Checkpoint | Code |
|---|---|---|---|---|---|---|
| 48 x 48 | WGAN-GP | 8.55 ± 0.02 | 43.01 ± 0.19 | 0.0440 ± 0.0003 | netG.pth | wgan_gp_48.py |
| 48 x 48 | SNGAN | 8.04 ± 0.07 | 39.56 ± 0.10 | 0.0369 ± 0.0002 | netG.pth | sngan_48.py |
| 48 x 48 | SSGAN | 8.25 ± 0.06 | 37.06 ± 0.19 | 0.0332 ± 0.0004 | netG.pth | ssgan_48.py |
| 48 x 48 | InfoMax-GAN | 8.54 ± 0.12 | 35.52 ± 0.10 | 0.0326 ± 0.0002 | netG.pth | infomax_gan_48.py |

ImageNet

Paper | Dataset

Training Parameters

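| Resolution | Batch Size | Learning Rate | β<sub>1</sub> | β<sub>2</sub> | Decay Policy | n<sub>dis</sub> | n<sub>iter</sub> |
|---|---|---|---|---|---|---|---|
| 32 x 32 | 64 | 2e-4 | 0.0 | 0.9 | Linear | 5 | 100K |
| 128 x 128 | 64 | 2e-4 | 0.0 | 0.9 | None | 5 | 450K |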

Results

| Resolution | Model | IS | FID | KID | Checkpoint | Code |
|---|---|---|---|---|---|---|
| 128 x 128 | SNGAN | 13.05 ± 0.05 | 65.74 ± 0.31 | 0.0663 ± 0.0004 | netG.pth | sngan_128.py |
| 128 x 128 | SSGAN | 13.30 ± 0.03 | 62.48 ± 0.31 | 0.0616 ± 0.0004 | netG.pth | ssgan_128.py |
| 128 x 128 | InfoMax-GAN | 13.68 ± 0.06 | 58.91 ± 0.14 | 0.0579 ± 0.0004 | netG.pth | infomax_gan_128.py |
| 32 x 32 | SNGAN | 8.97 ± 0.12 | 23.04 ± 0.06 | 0.0157 ± 0.0002 | netG.pth | sngan_32.py |
| 32 x 32 | cGAN-PD | 9.08 ± 0.17 | 21.17 ± 0.05 | 0.0145 ± 0.0002 | netG.pth | cgan_pd_32.py |
| 32 x 32 | SSGAN | 9.11 ± 0.12 | 21.79 ± 0.09 | 0.0152 ± 0.0002 | netG.pth | ssgan_32.py |
| 32 x 32 | InfoMax-GAN | 9.04 ± 0.10 | 20.68 ± 0.02 | 0.0149 ± 0.0001 | netG.pth | infomax_gan_32.py |

CIFAR-10

Paper | Dataset

Training Parameters

| Resolution | Batch Size | Learning Rate | β<sub>1</sub> | β<sub>2</sub> | Decay Policy | n<sub>dis</sub> | n<sub>iter</sub> |
|---|---|---|---|---|---|---|---|
| 32 x 32 | 64 | 2e-4 | 0.0 | 0.9 | Linear | 5 | 100K |

Results

| Resolution | Model | IS | FID | KID | Checkpoint | Code |
|---|---|---|---|---|---|---|
| 32 x 32 | WGAN-GP | 7.33 ± 0.02 | 22.29 ± 0.06 | 0.0204 ± 0.0004 | netG.pth | wgan_gp_32.py |
| 32 x 32 | SNGAN | 7.97 ± 0.06 | 16.77 ± 0.04 | 0.0125 ± 0.0001 | netG.pth | sngan_32.py |
| 32 x 32 | cGAN-PD | 8.25 ± 0.13 | 10.84 ± 0.03 | 0.0070 ± 0.0001 | netG.pth | cgan_pd_32.py |
| 32 x 32 | SSGAN | 8.17 ± 0.06 | 14.65 ± 0.04 | 0.0101 ± 0.0002 | netG.pth | ssgan_32.py |
| 32 x 32 | InfoMax-GAN | 8.08 ± 0.08 | 15.12 ± 0.10 | 0.0112 ± 0.0001 | netG.pth | infomax_gan_32.py |

CIFAR-100

Paper | Dataset

Training Parameters

| Resolution | Batch Size | Learning Rate | β<sub>1</sub> | β<sub>2</sub> | Decay Policy | n<sub>dis</sub> | n<sub>iter</sub> |
|---|---|---|---|---|---|---|---|
| 32 x 32 | 64 | 2e-4 | 0.0 | 0.9 | Linear | 5 | 100K |

Results

| Resolution | Model | IS | FID | KID | Checkpoint | Code |
|---|---|---|---|---|---|---|
| 32 x 32 | SNGAN | 7.57 ± 0.11 | 22.61 ± 0.06 | 0.0156 ± 0.0003 | netG.pth | sngan_32.py |
| 32 x 32 | cGAN-PD | 8.92 ± 0.07 | 14.16 ± 0.01 | 0.0085 ± 0.0002 | netG.pth | cgan_pd_32.py |
| 32 x 32 | SSGAN | 7.56 ± 0.07 | 22.18 ± 0.10 | 0.0161 ± 0.0002 | netG.pth | ssgan_32.py |
| 32 x 32 | InfoMax-GAN | 7.86 ± 0.10 | 18.94 ± 0.13 | 0.0135 ± 0.0004 | netG.pth | infomax_gan_32.py |

<div id="reproducibility"></div>

Reproducibility

To verify our implementations, we reproduce scores reported in the literature by re-implementing each model with the same architecture, training it under the same conditions, and evaluating it on CIFAR-10 with exactly the same methodology for computing FID.

As FID is a highly biased estimator (larger sample sizes lead to lower scores), we reproduce each score using the same sample sizes as the reported result, where n<sub>real</sub> and n<sub>fake</sub> refer to the number of real and fake images, respectively, used to compute FID.
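FID fits Gaussians to Inception features of real and generated images and computes ‖μ<sub>r</sub> − μ<sub>g</sub>‖² + Tr(Σ<sub>r</sub> + Σ<sub>g</sub> − 2(Σ<sub>r</sub>Σ<sub>g</sub>)<sup>1/2</sup>). The sample-size dependence can be illustrated with a deliberately simplified one-dimensional analogue in pure Python, where the distance reduces to (μ<sub>x</sub> − μ<sub>y</sub>)² + (σ<sub>x</sub> − σ<sub>y</sub>)². It uses no Inception network; scalar "features" are drawn from a standard normal, and the helper names frechet_distance_1d and mean_fd are hypothetical. Because the means and standard deviations are estimated from finite samples, two samples from the same distribution yield a positive distance that shrinks as n grows.

```python
import random
import statistics


def frechet_distance_1d(xs, ys):
    """Fréchet distance between 1-D Gaussians fitted to two samples.

    1-D analogue of FID: (mu_x - mu_y)^2 + (sigma_x - sigma_y)^2.
    """
    mu_x, mu_y = statistics.fmean(xs), statistics.fmean(ys)
    sd_x, sd_y = statistics.pstdev(xs), statistics.pstdev(ys)
    return (mu_x - mu_y) ** 2 + (sd_x - sd_y) ** 2


def mean_fd(n, trials=200, seed=0):
    """Average distance between two size-n samples of the SAME distribution."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
        ys = [rng.gauss(0.0, 1.0) for _ in range(n)]
        total += frechet_distance_1d(xs, ys)
    return total / trials


# Identical distributions, yet the estimate depends on the sample size
for n in (10, 100, 1000):
    print(n, round(mean_fd(n), 4))
```

Since the true distance here is zero, any positive value is pure estimation bias, which is why scores are only comparable when n<sub>real</sub> and n<sub>fake</sub> match.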

| Metric | Model | Score | Reported Score | n<sub>real</sub> | n<sub>fake</sub> | Checkpoint | Code |
|---|---|---|---|---|---|---|---|
| FID | DCGAN | 28.95 ± 0.42 | 28.12 [4] | 10K | 10K | netG.pth | dcgan_cifar.py |
| FID | WGAN-GP | 26.08 ± 0.12 | 29.3<sup>†</sup> [6] | 50K | 50K | netG.pth | wgan_gp_32.py |
| FID | SNGAN | 23.90 ± 0.20 | 21.7 ± 0.21 [1] | 10K | 5K | netG.pth | sngan_32.py |
| FID | cGAN-PD | 17.84 ± 0.17 | 17.5 [2] | 10K | 5K | netG.pth | cgan_pd_32.py |
| FID | SSGAN | 17.61 ± 0.14 | 17.88 ± 0.64 [3] | 10K | 10K | netG.pth | ssgan_32.py |
| FID | InfoMax-GAN | 17.14 ± 0.20 | 17.14 ± 0.20 [5] | 50K | 10K | netG.pth | infomax_gan_32.py |

<sup>†</sup> The reported best FID was obtained at 53K steps, but we find that our score continues to improve until 100K steps, reaching 23.13 ± 0.13.

Citation

If you have found this work useful, please consider citing our work:

@inproceedings{lee2020mimicry,
    title={Mimicry: Towards the Reproducibility of GAN Research},
    author={Kwot Sin Lee and Christopher Town},
    booktitle={CVPR Workshop on AI for Content Creation},
    year={2020},
}

For citing InfoMax-GAN:

@InProceedings{Lee_2021_WACV,
    author    = {Lee, Kwot Sin and Tran, Ngoc-Trung and Cheung, Ngai-Man},
    title     = {InfoMax-GAN: Improved Adversarial Image Generation via Information Maximization and Contrastive Learning},
    booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
    month     = {January},
    year      = {2021},
    pages     = {3942-3952}
}

References

[1] Spectral Normalization for Generative Adversarial Networks

[2] cGANs with Projection Discriminator

[3] Self-Supervised GANs via Auxiliary Rotation Loss

[4] A Large-Scale Study on Regularization and Normalization in GANs

[5] InfoMax-GAN: Improved Adversarial Image Generation via Information Maximization and Contrastive Learning

[6] GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium