
Gradient Normalization for Generative Adversarial Networks

Yi-Lun Wu, Hong-Han Shuai, Zhi-Rui Tam, Hong-Yu Chiu

Paper: https://arxiv.org/abs/2109.02235

This is the official implementation of Gradient Normalized GAN (GN-GAN).

Requirements

Datasets

Preprocessing Datasets for FID

Pre-calculated statistics for FID can be downloaded here:

Folder structure:

./stats
├── celebahq.3k.128.npz
├── celebahq.all.256.npz
├── church.train.256.npz
├── cifar10.test.npz
├── cifar10.train.npz
└── stl10.unlabeled.48.npz
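The .npz files above hold reference statistics for FID. The sketch below shows how statistics like these are commonly turned into an FID value: it is the Fréchet distance between two Gaussians fitted to Inception features. It assumes each file stores a feature mean and covariance under the keys mu and sigma; that layout is not confirmed here. The helper names gaussian_stats and frechet_distance are hypothetical. This is an illustration only, not the implementation used for the reported numbers.

```python
import numpy as np


def gaussian_stats(features: np.ndarray):
    """Mean vector and covariance matrix of an (N, D) array of features."""
    return features.mean(axis=0), np.cov(features, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2))."""
    diff = mu1 - mu2
    # For PSD covariances, sigma1 @ sigma2 has real, non-negative eigenvalues, and
    # Tr((sigma1 sigma2)^(1/2)) is the sum of their square roots; taking the real
    # part and clipping at zero only removes floating-point noise
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


# Hypothetical usage; the "mu"/"sigma" key names are an assumption about the .npz layout:
# ref = np.load("./stats/cifar10.test.npz")
# mu_gen, sigma_gen = gaussian_stats(inception_features_of_generated_images)
# print(frechet_distance(mu_gen, sigma_gen, ref["mu"], ref["sigma"]))
```

Summing square roots of eigenvalues avoids a dependency on a matrix square-root routine. Production FID code usually relies on a dedicated sqrtm implementation and extra numerical safeguards.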

NOTE

All reported values (Inception Score and FID) in our paper were calculated with the official implementations of these metrics, not with our implementation.

Training

How to integrate Gradient Normalization into your work?

The function normalize_gradient is implemented with the torch.autograd module. It wraps the discriminator's forward pass, so you can apply gradient normalization by changing a single line: each call net_D(x) becomes normalize_gradient(net_D, x).

import torch
from torch.nn import BCEWithLogitsLoss
from models.gradnorm import normalize_gradient

net_D = ...     # discriminator
net_G = ...     # generator
loss_fn = BCEWithLogitsLoss()

# Update discriminator
x_real = ...                                    # real data
x_fake = net_G(torch.randn(64, 3, 32, 32)).detach()  # fake data, no gradients into net_G
pred_real = normalize_gradient(net_D, x_real)   # instead of net_D(x_real)
pred_fake = normalize_gradient(net_D, x_fake)   # instead of net_D(x_fake)
loss_real = loss_fn(pred_real, torch.ones_like(pred_real))
loss_fake = loss_fn(pred_fake, torch.zeros_like(pred_fake))
(loss_real + loss_fake).backward()              # backward propagation
...

# Update generator
x_fake = net_G(torch.randn(64, 3, 32, 32))      # fake data (input shape depends on net_G)
pred_fake = normalize_gradient(net_D, x_fake)   # instead of net_D(x_fake)
loss_fake = loss_fn(pred_fake, torch.ones_like(pred_fake))
loss_fake.backward()                            # backward propagation
...
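To show how such a wrapper can be built on torch.autograd, here is a minimal sketch based on the paper's definition of the normalized discriminator, f_hat(x) = f(x) / (||∇_x f(x)||_2 + |f(x)|). It is not the code in models/gradnorm.py. The function body and the toy discriminator in the demo are illustrative assumptions.

```python
import torch
import torch.nn as nn


def normalize_gradient(net_D: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Sketch of gradient normalization: f(x) / (||grad_x f(x)||_2 + |f(x)|).

    Per-sample gradient norms are only correct if net_D treats every sample
    in the batch independently (e.g. no BatchNorm in the discriminator).
    """
    if not x.requires_grad:
        # Real or detached fake data: use a leaf copy that tracks gradients
        x = x.detach().requires_grad_(True)
    f = net_D(x)
    # Gradient of each score w.r.t. its own input; create_graph=True keeps the
    # normalizer differentiable so backward() also reaches net_D's parameters
    grad = torch.autograd.grad(
        outputs=f, inputs=x, grad_outputs=torch.ones_like(f), create_graph=True
    )[0]
    # L2 norm over all non-batch dimensions, reshaped to broadcast against f
    grad_norm = grad.flatten(start_dim=1).norm(p=2, dim=1)
    grad_norm = grad_norm.view(-1, *([1] * (f.dim() - 1)))
    return f / (grad_norm + f.abs())


if __name__ == "__main__":
    torch.manual_seed(0)
    net_D = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))  # toy discriminator
    x = torch.randn(4, 3, 8, 8)
    pred = normalize_gradient(net_D, x)
    print(pred.shape)            # torch.Size([4, 1])
    print(pred.abs().max() < 1)  # |f_hat| < 1 whenever the input gradient is nonzero
    pred.mean().backward()       # gradients flow into net_D's parameters
```

Because the denominator includes |f(x)|, the normalized output stays inside (-1, 1) whenever the input gradient is nonzero. This makes it a natural fit for logit-based losses such as BCEWithLogitsLoss.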

Citation

If you find our work relevant to your research, please cite:

@InProceedings{GNGAN_2021_ICCV,
    author = {Yi-Lun Wu and Hong-Han Shuai and Zhi-Rui Tam and Hong-Yu Chiu},
    title = {Gradient Normalization for Generative Adversarial Networks},
    booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
    month = {Oct},
    year = {2021}
}