

Vision-aided GAN


video | website | paper

[NEW!] Vision-aided GAN training with BigGAN and StyleGAN3

[NEW!] Use the vision-aided discriminator in your own GAN training with pip install vision-aided-loss

<img src='docs/code.gif' align="center" width=800> <br> <div class="gif"> <p align="center"> <img src='docs/vision-aided-gan.gif' align="center" width=800> </p> </div>

Can the collective knowledge from a large bank of pretrained vision models be leveraged to improve GAN training? If so, with so many models to choose from, which one(s) should be selected, and in what manner are they most effective?

We find that pretrained computer vision models can significantly improve performance when used in an ensemble of discriminators. We propose an effective selection mechanism that probes the linear separability between real and fake samples in pretrained model embeddings, chooses the most accurate model, and progressively adds it to the discriminator ensemble. Our method improves GAN training in both limited-data and large-scale settings.
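The selection step can be illustrated with a small NumPy sketch. This is not the authors' implementation: it assumes embeddings of real and generated images have already been extracted from each candidate model (the feats dictionary and the model names are hypothetical), fits a logistic-regression linear probe per model, and picks the model whose probe best separates real from fake on held-out samples. Only the selection of a single model is shown; the progressive addition to the ensemble is not.

```python
import numpy as np

def probe_accuracy(real, fake, steps=500, lr=0.1, seed=0):
    """Fit a logistic-regression probe on real (label 1) vs. fake (label 0)
    embeddings and return its accuracy on a held-out half of the data."""
    rng = np.random.default_rng(seed)
    x = np.concatenate([real, fake]).astype(np.float64)
    y = np.concatenate([np.ones(len(real)), np.zeros(len(fake))])
    idx = rng.permutation(len(x))
    x, y = x[idx], y[idx]
    half = len(x) // 2
    # Standardize features using statistics of the training half only
    mu, sd = x[:half].mean(0), x[:half].std(0) + 1e-8
    x = (x - mu) / sd
    xtr, ytr, xte, yte = x[:half], y[:half], x[half:], y[half:]
    w, b = np.zeros(x.shape[1]), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(xtr @ w + b)))  # predicted P(real)
        grad = p - ytr                             # d(BCE)/d(logit) per sample
        w -= lr * xtr.T @ grad / len(xtr)
        b -= lr * grad.mean()
    pred = (xte @ w + b) > 0
    return float((pred == yte.astype(bool)).mean())

def select_model(feats):
    """feats: {model_name: (real_embeddings, fake_embeddings)}.
    Returns the name with the highest held-out probe accuracy and all scores."""
    scores = {name: probe_accuracy(r, f) for name, (r, f) in feats.items()}
    return max(scores, key=scores.get), scores

# Toy usage with synthetic embeddings (hypothetical model names)
rng = np.random.default_rng(1)
feats = {
    "model_a": (rng.normal(0, 1, (200, 16)), rng.normal(0, 1, (200, 16))),   # real/fake indistinguishable
    "model_b": (rng.normal(1, 1, (200, 16)), rng.normal(-1, 1, (200, 16))),  # real/fake linearly separable
}
best, scores = select_model(feats)  # model_b's embeddings separate real from fake, so it should win
```

The intuition behind using a linear probe is that a backbone whose features already separate real from generated images with a simple classifier is likely to give the discriminator a useful, hard-to-fool signal.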

Ensembling Off-the-shelf Models for GAN Training <br> Nupur Kumari, Richard Zhang, Eli Shechtman, Jun-Yan Zhu<br> In CVPR 2022

Quantitative Comparison

<p align="center"> <img src="docs/lsun_eval.jpg" width="800px"/><br> </p>

Our method outperforms recent GAN training methods by a large margin, especially in the limited-sample setting. For LSUN Cat, using only 0.7% of the dataset, we achieve an FID similar to that of StyleGAN2 trained on the full dataset. On the full dataset, our method improves FID by 1.5x to 2x on the cat, church, and horse categories of LSUN.
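FID (Fréchet Inception Distance; lower is better) compares Gaussian fits to Inception features of real and generated images: FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2)). The NumPy sketch below computes this from precomputed (N, D) feature arrays. Inception feature extraction is omitted, and this is an illustration of the formula, not the evaluation code behind the numbers above.

```python
import numpy as np

def _sqrt_psd(m):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(m)
    return (vecs * np.sqrt(np.clip(vals, 0, None))) @ vecs.T

def fid(feat_real, feat_fake):
    """Frechet distance between Gaussians fitted to two (N, D) feature matrices."""
    mu_r, mu_f = feat_real.mean(0), feat_fake.mean(0)
    cov_r = np.cov(feat_real, rowvar=False)
    cov_f = np.cov(feat_fake, rowvar=False)
    # Tr((cov_r cov_f)^(1/2)) equals the sum of square roots of the eigenvalues of the
    # symmetric PSD matrix cov_r^(1/2) cov_f cov_r^(1/2), avoiding a non-symmetric sqrtm.
    s = _sqrt_psd(cov_r)
    eig = np.linalg.eigvalsh(s @ cov_f @ s)
    tr_sqrt = np.sqrt(np.clip(eig, 0, None)).sum()
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(cov_r) + np.trace(cov_f) - 2.0 * tr_sqrt)
```

Under this definition, "improves FID by 2x" means the distance to the real feature distribution is halved.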

Example Results

Below, we show visual comparisons between the baseline StyleGAN2-ADA and our model (Vision-aided GAN) for the same randomly sampled latent code on the 100-shot Bridge-of-Sighs and AnimalFace Dog datasets.

<img src="docs/bridge.gif" width="400px"/><img src="docs/animalface_dog.gif" width="400px"/>

Interpolation Videos

Latent interpolation results of models trained with our method on AnimalFace Cat (160 images), Dog (389 images), and Bridge-of-Sighs (100 photos).
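Interpolation videos like these are generated by decoding a sequence of latent codes along a path between two endpoints. The README does not state which interpolation scheme was used, so the sketch below is an assumption: it shows linear interpolation and, as a common alternative for Gaussian latents, spherical interpolation (slerp). The 512-dimensional z matches StyleGAN2's default latent size; the generator call itself is left out.

```python
import numpy as np

def lerp(z0, z1, t):
    """Linear interpolation between latent codes z0 and z1 at fraction t in [0, 1]."""
    return (1.0 - t) * z0 + t * z1

def slerp(z0, z1, t, eps=1e-7):
    """Spherical interpolation along the great circle through z0 and z1."""
    n0, n1 = z0 / np.linalg.norm(z0), z1 / np.linalg.norm(z1)
    omega = np.arccos(np.clip(np.dot(n0, n1), -1.0, 1.0))
    if omega < eps:  # nearly parallel codes: fall back to linear interpolation
        return lerp(z0, z1, t)
    return (np.sin((1.0 - t) * omega) * z0 + np.sin(t * omega) * z1) / np.sin(omega)

def interpolation_path(z0, z1, num_frames=30, fn=lerp):
    """Stack num_frames latent codes from z0 to z1; each row would be fed to the generator."""
    return np.stack([fn(z0, z1, t) for t in np.linspace(0.0, 1.0, num_frames)])

rng = np.random.default_rng(0)
z0, z1 = rng.standard_normal(512), rng.standard_normal(512)
frames = interpolation_path(z0, z1, num_frames=30)  # shape (30, 512)
```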

<p align="center"> <img src="docs/interp.gif" width="800px"/> </p>

Worst sample visualization

We randomly sample 5k images and sort them by their Mahalanobis distance to the real data, using the mean and covariance of real samples computed in Inception feature space. The visualization below shows the 30 worst images (those with the largest distance) for StyleGAN2-ADA (left) and our model (right).
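The ranking can be sketched in NumPy as follows. This is an illustrative sketch, not the authors' script: it assumes Inception features for real and generated images are already available as (N, D) arrays, fits the mean and covariance of the real features, and returns the indices of the k generated samples farthest from that distribution. The small ridge term added to the covariance is our assumption, included to keep it invertible.

```python
import numpy as np

def mahalanobis_worst(real_feats, fake_feats, k=30, ridge=1e-6):
    """Return indices of the k fake samples with the largest Mahalanobis distance
    to the real-feature Gaussian (worst first), plus all distances."""
    mu = real_feats.mean(axis=0)
    cov = np.cov(real_feats, rowvar=False) + ridge * np.eye(real_feats.shape[1])
    prec = np.linalg.inv(cov)
    diff = fake_feats - mu
    # Squared distance for each row i: diff_i^T @ prec @ diff_i
    d2 = np.einsum("ij,jk,ik->i", diff, prec, diff)
    dist = np.sqrt(np.clip(d2, 0, None))
    worst = np.argsort(dist)[::-1][:k]
    return worst, dist
```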

<details open><summary>AFHQ Dog</summary> <p> <div class="images"> <table width=500> <tr> <td valign="top"><img src="docs/afhqdog_worst_baseline.jpg"/></td> <td valign="top"><img src="docs/afhqdog_worst_ours.jpg"/></td> </tr> </table> </div> </p> </details> <details><summary>AFHQ Cat</summary> <p> <div class="images"> <table> <tr> <td valign="top"><img src="docs/afhqcat_worst_baseline.jpg"/></td> <td valign="top"><img src="docs/afhqcat_worst_ours.jpg"/></td> </tr> </table> </div> </p> </details> <details><summary>AFHQ Wild</summary> <p> <div class="images"> <table> <tr> <td valign="top"><img src="docs/afhqwild_worst_baseline.jpg"/></td> <td valign="top"><img src="docs/afhqwild_worst_ours.jpg"/></td> </tr> </table> </div> </p> </details>

Pretrained Models

StyleGAN2 models

BigGAN models

All pretrained models can also be downloaded at this link.

Vision-aided StyleGAN2 training

Please see the stylegan2 README for training StyleGAN2 models with our method. This code reproduces all StyleGAN2-based results from our paper.

Vision-aided Discriminator in a custom GAN model

Install the library via pip install git+https://github.com/nupurkmr9/vision-aided-gan.git or pip install vision-aided-loss

For details on off-the-shelf models please see MODELS.md


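import vision_aided_loss

device = 'cuda'
discr = vision_aided_loss.Discriminator(cv_type='clip', loss_type='multilevel_sigmoid_s', device=device).to(device)
discr.cv_ensemble.requires_grad_(False)  # Freeze the pretrained feature extractor

# Sample images (sample_real_image, G, and z stand in for your own data loader, generator, and latent codes)
real = sample_real_image()
fake = G.forward(z)

# Update discriminator discr; detach fake so this step neither backpropagates into G
# nor frees the generator graph that lossG.backward() needs below
lossD = discr(real, for_real=True) + discr(fake.detach(), for_real=False)
lossD.backward()
# ...then call your discriminator optimizer's step() and zero_grad()

# Update generator G
lossG = discr(fake, for_G=True)
lossG.backward()
# ...then call your generator optimizer's step() and zero_grad()

# We recommend first training the GAN with the standard loss for a few warmup iterations (warmup_iter),
# and only then adding the vision-aided adversarial loss.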
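One way to follow the warmup recommendation is to gate the vision-aided term on the iteration count. The plain-Python sketch below is illustrative only: warmup_iter is named after the comment above, while the helper combined_loss, the weight lam, and the choice to add the term to an existing standard loss are our assumptions, not options of vision_aided_loss. Passing the vision-aided loss as a callable means the pretrained backbone is not run at all during warmup.

```python
def combined_loss(standard_loss, vision_loss_fn, iteration, warmup_iter, lam=1.0):
    """Hypothetical helper: return the standard GAN loss during warmup, then add
    lam times the vision-aided adversarial loss.

    vision_loss_fn is a zero-argument callable, so the (expensive) pretrained
    backbone forward pass is skipped entirely while iteration < warmup_iter.
    Works with Python floats or framework tensors that support + and *.
    """
    if iteration < warmup_iter:
        return standard_loss
    return standard_loss + lam * vision_loss_fn()

# Example (pseudocode names): lossD = combined_loss(lossD_std,
#     lambda: discr(real, for_real=True) + discr(fake.detach(), for_real=False),
#     iteration, warmup_iter)
```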

Arg details:

Vision-aided StyleGAN3 training

Please see stylegan3 README for training StyleGAN3 models with our method.

Vision-aided BigGAN training

Please see biggan README for training BigGAN models with our method.

To add your own pretrained model

Create a class file that extracts pretrained features at vision_module/<custom_model>.py. Add its class path to class_name_dict in the vision_module.cvmodel.CVBackbone class. Update the architecture of the trainable classifier head on top of the pretrained features in vision_module.cv_discriminator. Finally, reinstall the library manually via pip install .

References

@InProceedings{kumari2021ensembling,
  title     = {Ensembling Off-the-shelf Models for GAN Training},
  author    = {Kumari, Nupur and Zhang, Richard and Shechtman, Eli and Zhu, Jun-Yan},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month     = {June},
  year      = {2022}
}

Acknowledgments

We thank Muyang Li, Sheng-Yu Wang, Chonghyuk (Andrew) Song for proofreading the draft. We are also grateful to Alexei A. Efros, Sheng-Yu Wang, Taesung Park, and William Peebles for helpful comments and discussion. Our codebase is built on stylegan2-ada-pytorch and DiffAugment.