Few-shot Image Generation via Cross-domain Correspondence

Utkarsh Ojha, Yijun Li, Jingwan Lu, Alexei A. Efros, Yong Jae Lee, Eli Shechtman, Richard Zhang

Adobe Research, UC Davis, UC Berkeley


PyTorch implementation of adapting a source GAN (trained on a large dataset) to a target domain using very few images.

Project page | Paper

Overview

<img src='imgs/method_diagram.png' width="840px"/>

Our method adapts the source GAN to the target domain while preserving a one-to-one correspondence between the source image G<sub>s</sub>(z) and the target image G<sub>s→t</sub>(z).

Requirements

Note: The base model is taken from the StyleGAN2 implementation by @rosinality.

Testing

We provide pre-trained models for several source and adapted (target) GANs.

| Source GAN: G<sub>s</sub> | Target GAN: G<sub>s→t</sub> |
| --- | --- |
| FFHQ | [Sketches] [Caricatures] [Amedeo Modigliani] [Babies] [Sunglasses] [Rafael] [Otto Dix] |
| LSUN Church | [Haunted houses] [Van Gogh houses] [Landscapes] [Caricatures] |
| LSUN Cars | [Wrecked cars] [Landscapes] [Haunted houses] [Caricatures] |
| LSUN Horses | [Landscapes] [Caricatures] [Haunted houses] |
| Hand gestures | [Google Maps] [Landscapes] |

For now, we have only included the pre-trained models using FFHQ as the source domain, i.e. all the models in the first row. We will add the remaining ones soon.

Download the pre-trained model(s) and store them in the ./checkpoints/ directory.

Sample images from a model

To generate images from a pre-trained GAN, run the following command:

CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_target /path/to/model/

Here, the checkpoint name follows the source_target notation, e.g. ffhq_sketches.pt. Use the --load_noise option to load the noise vectors used for some of the figures in the paper (Figures 1-4). For example:

CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_target ./checkpoints/ffhq_sketches.pt --load_noise noise.pt

This will save the images in the test_samples/ directory.
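
If you prefer to sample from a checkpoint programmatically rather than through generate.py, the sketch below shows one way to do it with the Generator class from the @rosinality StyleGAN2 codebase that this repo builds on. The resolution (256), latent size (512), the 'g_ema' checkpoint key, and the output filename are assumptions; see generate.py for the options this repo actually exposes.

```python
# Minimal sampling sketch -- assumptions: 256px models, a 'g_ema' key inside the
# checkpoint, and the Generator class from the rosinality StyleGAN2 codebase.
import os

import torch
from torchvision import utils

from model import Generator  # rosinality's StyleGAN2 generator

device = "cuda"
ckpt = torch.load("./checkpoints/ffhq_sketches.pt", map_location=device)

g_target = Generator(size=256, style_dim=512, n_mlp=8).to(device)
g_target.load_state_dict(ckpt["g_ema"], strict=False)
g_target.eval()

os.makedirs("test_samples", exist_ok=True)
with torch.no_grad():
    z = torch.randn(25, 512, device=device)      # 25 latent codes -> 5x5 grid
    samples, _ = g_target([z])                   # forward pass returns (images, latents)
    utils.save_image(samples, "test_samples/ffhq_sketches_grid.png",
                     nrow=5, normalize=True)
```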

Visualizing correspondence results

To visualize the outputs of the source and adapted models for the same noise vectors, i.e. G<sub>s</sub>(z) and G<sub>s→t</sub>(z), run the following command(s):

# generate two image grids of 5x5 for source and target
CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_source /path/to/source --ckpt_target /path/to/target --load_noise noise.pt

# visualize the interpolations in the source and target models
CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_source /path/to/source --ckpt_target /path/to/target --load_noise noise.pt --mode interpolate
python traversal_gif.py 10
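
For reference, the following sketch illustrates what the correspondence visualization boils down to: the same (interpolated) latent codes are decoded by both the source and the adapted generator. It reuses the assumptions from the sampling sketch above (256px models, 'g_ema' key, rosinality's Generator) and is only an illustration, not a substitute for generate.py --mode interpolate.

```python
# Conceptual sketch of the correspondence visualization: identical latent codes are
# decoded by G_s and G_s->t. Assumptions as above (256px, 'g_ema', rosinality Generator).
import torch
from torchvision import utils

from model import Generator

device = "cuda"

def load_generator(path):
    g = Generator(size=256, style_dim=512, n_mlp=8).to(device)
    g.load_state_dict(torch.load(path, map_location=device)["g_ema"], strict=False)
    return g.eval()

g_source = load_generator("./checkpoints/source_ffhq.pt")
g_target = load_generator("./checkpoints/ffhq_sketches.pt")

# Linearly interpolate between two latent codes in z-space.
z_a = torch.randn(1, 512, device=device)
z_b = torch.randn(1, 512, device=device)
steps = torch.linspace(0, 1, 8, device=device).view(-1, 1)
z_path = (1 - steps) * z_a + steps * z_b

with torch.no_grad():
    src_imgs, _ = g_source([z_path])   # G_s(z): source-domain images
    tgt_imgs, _ = g_target([z_path])   # G_s->t(z): adapted-domain images

utils.save_image(torch.cat([src_imgs, tgt_imgs]), "interp_correspondence.png",
                 nrow=8, normalize=True)
```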

Hand gesture experiments

We collected ~18k images of random hand gestures performed on a plain surface and used them to train a source model from scratch. We then adapted it to two different target domains: landscape images and Google Maps. The goal was to see whether, at inference time, interpolating between hand gestures results in meaningful variations in the target images. Run the following command to see the results:

CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_source /path/to/source --ckpt_target /path/to/maps(landscapes) --load_noise noise.pt --mode interpolate

Evaluating FID

The following table provides a link to the test set of domains used in Table 1:

Download and unzip the set of images into your desired directory, then compute the FID score (using pytorch-fid) between the real (R<sub>test</sub>) and generated (F) images by running the following command:

python -m pytorch_fid /path/to/real/images /path/to/fake/images
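
If you would rather compute FID from Python, pytorch-fid also exposes a helper for this. The call below matches recent releases of the package, but the signature may differ in older versions, so treat it as a sketch and fall back to the CLI above if needed.

```python
# Sketch: FID via the pytorch-fid Python API (signature as in recent releases).
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

fid = calculate_fid_given_paths(
    ["/path/to/real/images", "/path/to/fake/images"],
    batch_size=50,
    device="cuda" if torch.cuda.is_available() else "cpu",
    dims=2048,   # pool3 features of InceptionV3, the standard setting
)
print(f"FID: {fid:.2f}")
```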

Evaluating intra-cluster distance

Download the entire set of images from this link (1.1 GB), which are used for the results in Table 2. The organization of this collection is as follows:

cluster_centers
└── amedeo                      # target domain -- will be one of [amedeo, sketches]
    └── ours                    # baseline -- will be one of [tgan, tgan_ada, freezeD, ewc, ours]
        └── c0                  # cluster id -- there are 10 clusters [c0, c1, ..., c9]
            ├── center.png      # cluster center -- one of the 10 training images; each cluster has its own center
            ├── img0.png        # generated images matched to this cluster's center, according to the LPIPS metric
            ├── img1.png
            │       ...


Unzip the file, and then run the following command to compute the results for a baseline on a dataset:

CUDA_VISIBLE_DEVICES=0 python feat_cluster.py --baseline <baseline> --dataset <target_domain> --mode intra_cluster_dist

E.g.
CUDA_VISIBLE_DEVICES=0 python feat_cluster.py --baseline tgan --dataset sketches --mode intra_cluster_dist
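
For intuition, here is a minimal sketch of an average pairwise intra-cluster LPIPS distance for a single cluster folder, using @richzhang's lpips package. The network choice, pairing, and averaging details are assumptions for illustration only; feat_cluster.py is the reference implementation for the numbers in Table 2.

```python
# Sketch: average pairwise LPIPS distance among the members of one cluster folder.
# Assumptions: the lpips pip package, and the directory layout shown above.
import glob
import itertools

import lpips
import torch

device = "cuda"
loss_fn = lpips.LPIPS(net="vgg").to(device)   # net="alex" is the faster alternative

def load(path):
    # lpips.load_image gives an HxWx3 RGB array; im2tensor maps it to [-1, 1]
    return lpips.im2tensor(lpips.load_image(path)).to(device)

members = [load(p) for p in sorted(glob.glob("cluster_centers/sketches/tgan/c0/img*.png"))]

with torch.no_grad():
    dists = [loss_fn(a, b).item() for a, b in itertools.combinations(members, 2)]

print(f"average pairwise LPIPS over {len(members)} members: "
      f"{sum(dists) / max(len(dists), 1):.4f}")
```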

We also provide a utility to visualize the closest and farthest members of a cluster, as in Figure 14 of the paper (reproduced below), using the following command:

CUDA_VISIBLE_DEVICES=0 python feat_cluster.py --baseline tgan --dataset sketches --mode visualize_members

The command will save the generated images that are closest to and farthest from a cluster center as closest.png and farthest.png, respectively.

<img src='imgs/cluster_members.png' width="840px"/>

Training (adapting) your own GAN

Choose the source domain

Download (or train) a source model checkpoint for the domain you want to adapt from, e.g. source_ffhq.pt for FFHQ, and place it in the ./checkpoints/ directory.

Choose the target domain

| Sketches | Amedeo Modigliani | Babies | Sunglasses | Rafael | Otto Dix | Haunted houses | Van Gogh houses | Landscapes | Wrecked cars | Maps |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| images | images | images | images | images | images | images | images | images | images | images |
| processed | processed | processed | processed | processed | processed | processed | processed | processed | processed | processed |

Note: We cannot share the images for the caricature domain due to licensing issues.

Once the source model and target data are in place, run the adaptation:

CUDA_VISIBLE_DEVICES=0 python train.py --ckpt_source /path/to/source_model --data_path /path/to/target_data --exp <exp_name>

# sample run
CUDA_VISIBLE_DEVICES=0 python train.py --ckpt_source ./checkpoints/source_ffhq.pt --data_path ./processed_data/sketches --exp ffhq_to_sketches    

This will create directories named ffhq_to_sketches in ./checkpoints/ (for the intermediate models) and in ./samples/ (for the intermediate generated images).

Running the above code with the default configuration, i.e. batch size = 4, uses ~20 GB of GPU memory.

Bibtex

If you find our code useful, please cite our paper:

@inproceedings{ojha2021few-shot-gan,
  title={Few-shot Image Generation via Cross-domain Correspondence},
  author={Ojha, Utkarsh and Li, Yijun and Lu, Jingwan and Efros, Alexei A. and Lee, Yong Jae and Shechtman, Eli and Zhang, Richard},
  booktitle={CVPR},
  year={2021}
}

Acknowledgment

As mentioned above, the StyleGAN2 model is borrowed from this wonderful PyTorch implementation by @rosinality. We are also thankful to @mseitzer and @richzhang for their user-friendly implementations of the FID score and the LPIPS metric, respectively.