SemanticStyleGAN: Learning Compositional Generative Priors for Controllable Image Synthesis and Editing (CVPR 2022)
<a href="https://arxiv.org/abs/2112.02236"><img src="https://img.shields.io/badge/arXiv-2112.02236-b31b1b" height=22.5></a> <a href="https://creativecommons.org/licenses/by/4.0"><img src="https://img.shields.io/badge/LICENSE-CC--BY--4.0-yellow" height=22.5></a> <a href="https://www.youtube.com/watch?v=nfKiVX4pFlw"><img src="https://img.shields.io/static/v1?label=CVPR 2022&message=5 Minute Video&color=red" height=22.5></a>
Yichun Shi, Xiao Yang, Yangyue Wan, Xiaohui Shen
Recent studies have shown that StyleGANs provide promising prior models for downstream tasks on image synthesis and editing. However, since the latent codes of StyleGANs are designed to control global styles, it is hard to achieve fine-grained control over synthesized images. We present SemanticStyleGAN, where a generator is trained to model local semantic parts separately and synthesizes images in a compositional way. The structure and texture of different local parts are controlled by corresponding latent codes. Experimental results demonstrate that our model provides a strong disentanglement between different spatial areas. When combined with editing methods designed for StyleGANs, it can achieve more fine-grained control when editing synthesized or real images. The model can also be extended to other domains via transfer learning. Thus, as a generic prior model with built-in disentanglement, it could facilitate the development of GAN-based applications and enable more potential downstream tasks.
<a href="https://www.youtube.com/watch?v=nfKiVX4pFlw"><img src="docs/abstract_idea.png" width="550px"/></a>
Description
Official Implementation of our SemanticStyleGAN paper for training and inference.
Installation
- Python 3
- PyTorch 1.8+
- Run `pip install -r requirements.txt` to install additional dependencies.
Pretrained Models
In this repository, we provide pretrained models for various domains.
| Path | Description |
|---|---|
| CelebAMask-HQ | Trained on the CelebAMask-HQ dataset. |
| BitMoji | Fine-tuned on the re-cropped BitMoji dataset. |
| MetFaces | Fine-tuned on the MetFaces dataset. |
| Toonify | Fine-tuned on the Toonify dataset. |
Inference
Synthesis
Random Synthesis
In `visualize/generate.py`, we provide a script for sampling random images and their corresponding segmentation masks with SemanticStyleGAN. An example command is provided below:
```shell
python visualize/generate.py \
pretrained/CelebAMask-HQ-512x512.pt \
--outdir results/samples \
--sample 20 \
--save_latent
```
The `--save_latent` flag saves the w latent code of each synthesized image in a separate `.npy` file.
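If you prefer to drive sampling from Python (for example, inside a larger experiment script), the command above can be assembled with the standard library. This is a minimal sketch: `build_generate_cmd` and `find_saved_latents` are hypothetical helpers, not part of this repository, and the `*_latent.npy` pattern is an assumption based on the example file name `000000_latent.npy` used in the interpolation commands.

```python
import sys
from pathlib import Path


def build_generate_cmd(ckpt, outdir, n_samples, save_latent=True):
    """Assemble the documented visualize/generate.py invocation as an argv list."""
    cmd = [sys.executable, "visualize/generate.py", str(ckpt),
           "--outdir", str(outdir), "--sample", str(n_samples)]
    if save_latent:
        cmd.append("--save_latent")
    return cmd


def find_saved_latents(outdir):
    """Return the saved latent files in outdir, sorted by file name.

    Assumes the *_latent.npy naming seen in the examples (e.g. 000000_latent.npy).
    """
    return sorted(Path(outdir).glob("*_latent.npy"))
```

For example, `subprocess.run(build_generate_cmd("pretrained/CelebAMask-HQ-512x512.pt", "results/samples", 20), check=True)` reproduces the command above, after which `find_saved_latents("results/samples")` lists candidate files for the `--latent` argument.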
Local Latent Interpolation
<img src="https://semanticstylegan.github.io/images/traverse.gif" width="450"/>

In `visualize/generate_video.py`, we provide a script for visualizing local latent interpolation with SemanticStyleGAN. An example command is provided below:
```shell
python visualize/generate_video.py \
pretrained/CelebAMask-HQ-512x512.pt \
--outdir results/interpolation \
--latent results/samples/000000_latent.npy
```
Here, `results/samples/000000_latent.npy` is a latent code generated either by `visualize/generate.py` or by `visualize/invert.py`. You can also omit the `--latent` argument to generate a video from a random latent code. The script will create several mp4 files under the output folder, each showing the interpolation animation in a specific latent subspace.
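Conceptually, a local interpolation changes only the latent entries that belong to one semantic part while every other entry stays fixed. The NumPy sketch below illustrates that idea under stated assumptions: the latent is treated as a 2-D array (e.g. one row per code), and which rows belong to a given part is a hypothetical `part_rows` argument; the actual layout used by `visualize/generate_video.py` may differ.

```python
import numpy as np


def interpolate_part(latent, part_rows, target, steps=8):
    """Linearly interpolate only the rows in `part_rows` from `latent` to `target`.

    latent, target: arrays of identical shape, e.g. (num_codes, dim).
    part_rows: a slice or index array selecting the rows owned by one part
               (hypothetical; depends on how the generator lays out its codes).
    Returns an array of shape (steps, *latent.shape).
    """
    latent = np.asarray(latent, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if latent.shape != target.shape:
        raise ValueError("latent and target must have the same shape")
    frames = np.repeat(latent[None], steps, axis=0)
    for i, t in enumerate(np.linspace(0.0, 1.0, steps)):
        # Only the selected part's rows move; all other rows keep their values.
        frames[i, part_rows] = (1.0 - t) * latent[part_rows] + t * target[part_rows]
    return frames
```

Each frame would then be decoded by the generator; because only one part's codes change, the edit stays confined to that part's region if the model is spatially disentangled.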
Synthesizing Components
<img src="docs/components.jpg" width="800"/>

In `visualize/generate_components.py`, we provide a script for visualizing the components synthesized by SemanticStyleGAN, gradually adding more local generators to the synthesis procedure. An example command is provided below:
```shell
python visualize/generate_components.py \
pretrained/CelebAMask-HQ-512x512.pt \
--outdir results/components \
--latent results/samples/000000_latent.npy
```
You can also omit the `--latent` argument to generate components for a random latent code.
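Adding local generators one at a time can be illustrated with softmax-based compositing: each part produces a feature map and a per-pixel score (for example a pseudo-depth map), the scores are normalized across the active parts with a softmax to give soft masks, and the fused feature is the mask-weighted sum. The toy NumPy sketch below assumes exactly that setup with random arrays; it illustrates the compositional principle and is not the repository's generator code.

```python
import numpy as np


def fuse_parts(features, logits, num_active=None):
    """Composite per-part outputs with a softmax over parts.

    features: (K, C, H, W) feature map from each of K parts.
    logits:   (K, H, W) per-pixel score for each part (e.g. a pseudo-depth map).
    num_active: use only the first `num_active` parts, mimicking
                "gradually adding more local generators".
    Returns (fused, masks) with shapes (C, H, W) and (k, H, W).
    """
    k = features.shape[0] if num_active is None else num_active
    feats, scores = features[:k], logits[:k]
    # Numerically stable softmax across the part axis.
    scores = scores - scores.max(axis=0, keepdims=True)
    weights = np.exp(scores)
    masks = weights / weights.sum(axis=0, keepdims=True)
    # Broadcast each (H, W) mask over that part's C channels, then sum over parts.
    fused = (masks[:, None] * feats).sum(axis=0)
    return fused, masks
```

With `num_active=1` the single part receives a mask of 1 everywhere; each additional part then competes for pixels through the softmax, which is why the component visualizations grow region by region.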
Inversion
Optimization-based
You can use `visualize/invert.py` to invert real images into the latent space of SemanticStyleGAN via optimization:
```shell
python visualize/invert.py \
--ckpt pretrained/CelebAMask-HQ-512x512.pt \
--imgdir data/examples \
--outdir results/inversion \
--size 512
```
This script saves the reconstructed images and their corresponding w-plus latent codes in separate sub-directories under `outdir`. Additionally, you can set `--finetune_step` to a non-zero integer (e.g. 300) for pivotal tuning inversion (PTI), which outputs a new fine-tuned generator for each image.
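The two stages described above can be illustrated on a toy problem. In the sketch below, a hypothetical linear "generator" `A @ w` stands in for SemanticStyleGAN and plain squared error stands in for the reconstruction loss (the real script may use other losses, such as perceptual ones). Stage 1 optimizes the latent with the generator fixed; stage 2, in the spirit of pivotal tuning, keeps that latent fixed as the pivot and fine-tunes the generator weights so that the reconstruction improves further.

```python
import numpy as np


def invert(A, x, steps=1000):
    """Stage 1: gradient descent on w to minimize 0.5 * ||A @ w - x||^2 (A fixed)."""
    w = np.zeros(A.shape[1])
    lr = 1.0 / np.linalg.norm(A, 2) ** 2  # 1 / Lipschitz constant keeps GD stable
    for _ in range(steps):
        w -= lr * A.T @ (A @ w - x)
    return w


def pivotal_tune(A, w, x, steps=50):
    """Stage 2: keep the pivot latent w fixed and fine-tune a copy of A."""
    A = A.copy()
    lr = 0.5 / (w @ w)  # for this linear model, each step halves the residual
    for _ in range(steps):
        A -= lr * np.outer(A @ w - x, w)
    return A


def loss(A, w, x):
    return 0.5 * np.sum((A @ w - x) ** 2)
```

When the target lies outside what the frozen generator can produce, stage 1 stops at the best achievable latent, and only stage 2 can close the remaining gap; this is also why PTI produces a separate generator per image.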
You can manipulate the reconstructed faces using the saved latent codes. You can also edit a face with a generator fine-tuned by PTI or by domain adaptation. An example command is provided below:
```shell
python visualize/generate_video.py \
pretrained/BitMoji-512x512.pt \
--outdir results/interpolation_inversion \
--latent results/inversion/latent/1.npy
```
Here is an example result of changing the inverted latent code of eyes using the BitMoji generator:
<img src="docs/inversion_bitmoji.gif" width="500"/>

Computing Metrics
Given a trained generator and a prepared inception file, we can compute the metrics with the following command:
```shell
python calc_fid.py \
--ckpt /path/to/checkpoint \
--inception /path/to/inception/file
```
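For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images: `||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2))`. The NumPy sketch below evaluates that formula from precomputed means and covariances. It uses an eigendecomposition instead of `scipy.linalg.sqrtm`, and it makes no assumption about how `calc_fid.py` or the inception `.pkl` file store these statistics.

```python
import numpy as np


def _sqrt_psd(mat):
    """Symmetric square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2)).

    Uses Tr((S1 S2)^(1/2)) = Tr((S1^(1/2) S2 S1^(1/2))^(1/2)), which holds for
    symmetric PSD S1 and S2 and avoids a non-symmetric matrix square root.
    """
    diff = np.asarray(mu1) - np.asarray(mu2)
    root1 = _sqrt_psd(sigma1)
    middle = root1 @ sigma2 @ root1
    middle = (middle + middle.T) / 2  # remove round-off asymmetry
    tr_covmean = np.sum(np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

Identical statistics give a distance of (numerically) zero, and a pure mean shift with equal covariances contributes exactly the squared norm of the shift.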
<br>
Training
Data Preparation
- In our work, we use re-mapped segmentation labels of CelebAMask-HQ. To reproduce this dataset, first download the original CelebAMask-HQ dataset and decompress it to `data/CelebAMask-HQ`. Then, run the following command to create the images and labels used for training:
  ```shell
  python data/preprocess_celeba.py data/CelebAMask-HQ
  ```
  The script will create four folders under `data/CelebAMask-HQ` that contain the images and labels for training and testing, respectively.
- Similar to rosinality's implementation of StyleGAN2, we use LMDB datasets for training. An example command is provided below:
  ```shell
  python prepare_mask_data.py \
  data/CelebAMask-HQ/image_train \
  data/CelebAMask-HQ/label_train \
  --out data/lmdb_celebamaskhq_512 \
  --size 512
  ```
  You can also use your own dataset for this step. Note that mask labels and image files are matched by file name. Files may be placed in sub-directories, but make sure their base names are unique.
- Prepare the inception file for calculating FID:
  ```shell
  python prepare_inception.py \
  data/lmdb_celebamaskhq_512 \
  --output data/inception_celebamaskhq_512.pkl \
  --size 512 \
  --dataset_type mask
  ```
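Because masks and images are paired by file name when building the LMDB, a quick pre-flight check on your own dataset can catch duplicates or missing pairs early. The standard-library sketch below assumes that "base name" means the file name without its extension (so `00001.jpg` pairs with `00001.png`); `pair_by_basename` is a hypothetical helper, not part of `prepare_mask_data.py`.

```python
from pathlib import Path


def index_by_stem(root):
    """Map each file's base name (stem) to its path, searching sub-directories."""
    index = {}
    for path in sorted(Path(root).rglob("*")):
        if not path.is_file():
            continue
        if path.stem in index:
            raise ValueError(f"duplicate base name {path.stem!r}: {index[path.stem]} and {path}")
        index[path.stem] = path
    return index


def pair_by_basename(image_dir, label_dir):
    """Return (image, label) pairs sorted by base name; raise if any lacks a partner."""
    images, labels = index_by_stem(image_dir), index_by_stem(label_dir)
    unpaired = sorted(set(images) ^ set(labels))
    if unpaired:
        raise ValueError(f"unpaired base names: {unpaired}")
    return [(images[k], labels[k]) for k in sorted(images)]
```

Running `pair_by_basename("data/CelebAMask-HQ/image_train", "data/CelebAMask-HQ/label_train")` before `prepare_mask_data.py` is a cheap way to confirm the pairing assumption holds for a custom dataset.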
Training SemanticStyleGAN
The main training script can be found in `train.py`. Here, we provide an example for training on the CelebAMask-HQ dataset prepared above:
```shell
python train.py \
--dataset data/lmdb_celebamaskhq_512 \
--inception data/inception_celebamaskhq_512.pkl \
--checkpoint_dir checkpoint/celebamaskhq_512 \
--seg_dim 13 \
--size 512 \
--transparent_dims 10 12 \
--residual_refine \
--batch 16
```
Alternatively, you can use the following command for multi-GPU training (assuming 8 GPUs are available):
```shell
python -m torch.distributed.launch --nproc_per_node=8 \
train.py \
--dataset data/lmdb_celebamaskhq_512 \
--inception data/inception_celebamaskhq_512.pkl \
--checkpoint_dir checkpoint/celebamaskhq_512 \
--seg_dim 13 \
--size 512 \
--transparent_dims 10 12 \
--residual_refine \
--batch 4
```
Here, `--seg_dim` refers to the number of segmentation classes (including background), and `--transparent_dims` specifies the classes that are treated as possibly transparent.
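To make the two flags concrete: `--seg_dim` is a single integer, while `--transparent_dims` takes a space-separated list of class indices. The sketch below is a hypothetical mirror of just these two flags using `argparse` (it is not the parser in `train.py`), plus a sanity check that every transparent class index falls inside `range(seg_dim)`, which assumes classes are indexed from 0.

```python
import argparse


def parse_seg_args(argv):
    """Hypothetical mirror of --seg_dim / --transparent_dims with a range check."""
    parser = argparse.ArgumentParser(description="Mirror of two train.py flags (illustrative)")
    parser.add_argument("--seg_dim", type=int, required=True,
                        help="number of segmentation classes, including background")
    parser.add_argument("--transparent_dims", type=int, nargs="*", default=[],
                        help="class indices treated as possibly transparent")
    args = parser.parse_args(argv)
    # Assumes 0-based class indices, so valid ids are 0 .. seg_dim - 1.
    bad = [d for d in args.transparent_dims if not 0 <= d < args.seg_dim]
    if bad:
        parser.error(f"--transparent_dims {bad} outside range(0, {args.seg_dim})")
    return args
```

With the values used above, `parse_seg_args(["--seg_dim", "13", "--transparent_dims", "10", "12"])` yields `seg_dim=13` and `transparent_dims=[10, 12]`; an index such as 13 would be rejected.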
If you want to resume from an intermediate checkpoint, simply add the argument `--ckpt /path/to/checkpoint/file`, where the checkpoint file is a `.pt` file saved by our training script.
Additionally, if you have TensorBoard installed, you can visualize the TensorBoard logs in the `checkpoint_dir`.
Domain Adaptation
In `train_adaptation.py`, we provide a script for performing domain adaptation on image-only datasets. To do this, you first need to create an LMDB for the target image dataset. An example command is provided below:
```shell
python prepare_image_data.py \
data/metfaces/images \
--size 512 \
--out data/lmdb_metfaces_512
```
Then, you can run the following command for fine-tuning on the target dataset:
```shell
python -m torch.distributed.launch --nproc_per_node=8 \
train_adaptation.py \
--ckpt pretrained/CelebAMask-HQ-512x512.pt \
--dataset data/lmdb_metfaces_512 \
--checkpoint_dir checkpoint/metfaces \
--seg_dim 13 \
--size 512 \
--transparent_dims 10 12 \
--residual_refine \
--batch 4 \
--freeze_local
```
The `--freeze_local` flag freezes the local generators during training, which preserves the spatial disentanglement. However, for datasets that have a large geometric difference from real faces (e.g. BitMoji), you may want to remove this argument. In fact, we found that our model is still able to preserve the disentanglement within a few thousand steps of fine-tuning all modules.
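In PyTorch, freezing a sub-network usually means turning off `requires_grad` on its parameters and giving the optimizer only the parameters that remain trainable. The toy sketch below shows that general pattern; `ToyGenerator` and its `local_nets` and `render_net` attributes are hypothetical stand-ins, not the module names used in `train_adaptation.py`.

```python
import torch
from torch import nn


class ToyGenerator(nn.Module):
    """Hypothetical stand-in: several 'local' branches plus a shared renderer."""

    def __init__(self, num_parts=3, dim=8):
        super().__init__()
        self.local_nets = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_parts)])
        self.render_net = nn.Linear(dim, 3)

    def forward(self, x):
        fused = sum(net(x) for net in self.local_nets)
        return self.render_net(fused)


def freeze_local(generator):
    """Disable gradients for the local branches; return the params still trainable."""
    for param in generator.local_nets.parameters():
        param.requires_grad_(False)
    return [p for p in generator.parameters() if p.requires_grad]
```

For example, `optimizer = torch.optim.Adam(freeze_local(g), lr=1e-3)` updates only the shared renderer, and passing just the trainable parameters also avoids keeping optimizer state for frozen weights.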
Note that the dataloader for adaptation is compatible with rosinality's implementation, so you can use the same LMDB datasets for fine-tuning SemanticStyleGAN. By default, we fine-tune the model for 2000 steps, but you may want to check the visualization samples for early stopping.
<br>

Credits
StyleGAN2 model and implementation:
https://github.com/rosinality/stylegan2-pytorch
Copyright (c) 2019 Kim Seonghyeon
License (MIT) https://github.com/rosinality/stylegan2-pytorch/blob/master/LICENSE
LPIPS model and implementation:
https://github.com/S-aiueo32/lpips-pytorch
Copyright (c) 2020, Sou Uchida
License (BSD 2-Clause) https://github.com/S-aiueo32/lpips-pytorch/blob/master/LICENSE
ReStyle model and implementation:
https://github.com/yuval-alaluf/restyle-encoder
Copyright (c) 2021 Yuval Alaluf
License (MIT) https://github.com/yuval-alaluf/restyle-encoder/blob/main/LICENSE
Please Note: The CUDA files are made available under the Nvidia Source Code License-NC
Acknowledgments
This code is initially built from SemanticGAN.
Citation
If you use this code for your research, please cite the following work:
```bibtex
@inproceedings{shi2021SemanticStyleGAN,
author = {Shi, Yichun and Yang, Xiao and Wan, Yangyue and Shen, Xiaohui},
title = {SemanticStyleGAN: Learning Compositional Generative Priors for Controllable Image Synthesis and Editing},
booktitle = {CVPR},
year = {2022},
}
```