# Diverse Structure Inpainting
Paper | Supplementary Material | ArXiv | BibTex
This repository is for the CVPR 2021 paper, "Generating Diverse Structure for Image Inpainting With Hierarchical VQ-VAE".
## Introduction
<div align=center> <img src="./intro.png" width="50%" height="50%"> </div>(Top) Input incomplete image, where the missing region is depicted in gray. (Middle) Visualization of the generated diverse structures. (Bottom) Output images of our method.
## Places2 Results
<div align=center> <img src="./places2_center.png">Results on the Places2 validation set using the center-mask Places2 model.
</div>CelebA-HQ Results
<div align=center> <img src="./celebahq_random.png">Results on one CelebA-HQ test image with different holes using the random-mask CelebA-HQ model.
</div>Installation
This code was tested with TensorFlow 1.12.0 (later 1.x versions may work, but not 2.x), CUDA 9.0, Python 3.6 and Ubuntu 16.04.
Clone this repository:

```
git clone https://github.com/USTC-JialunPeng/Diverse-Structure-Inpainting.git
```
## Datasets
- CelebA-HQ: the high-resolution face images from Growing GANs. 24183 images for training, 2993 images for validation and 2824 images for testing.
- Places2: the challenge data from 365 scene categories. 8 Million images for training, 36K images for validation and 328K images for testing.
- ImageNet: the data from 1000 natural categories. 1 Million images for training and 50K images for validation.
## Training
- Collect the dataset. For CelebA-HQ, we collect the 1024x1024 version. For Places2 and ImageNet, we collect the original version.
- Prepare the file list. Collect the path of each image and make a file in which each line is one image path (every line ends with a line break except the last one); a minimal sketch is given after this list.
- Modify the `checkpoints_dir`, `dataset`, `train_flist` and `valid_flist` arguments in `train_vqvae.py`, `train_structure_generator.py` and `train_texture_generator.py`.
- Modify `data/data_loader.py` according to the dataset. For CelebA-HQ, we resize each image to 266x266 and randomly crop a 256x256 patch. For Places2 and ImageNet, we randomly crop a 256x256 patch (see the preprocessing sketch after this list).
- Run `python train_vqvae.py` to train the VQ-VAE.
- Modify the `vqvae_network_dir` argument in `train_structure_generator.py` and `train_texture_generator.py` based on the path of the pre-trained VQ-VAE.
- Modify the mask setting arguments in `train_structure_generator.py` and `train_texture_generator.py` to choose center mask or random mask.
- Run `python train_structure_generator.py` to train the structure generator.
- Run `python train_texture_generator.py` to train the texture generator.
- Modify the `structure_generator_dir` and `texture_generator_dir` arguments in `save_full_model.py` based on the paths of the pre-trained structure generator and texture generator.
- Run `python save_full_model.py` to save the whole model.
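A minimal sketch for preparing a file list (not a script shipped with this repository; the directory paths and output file names are placeholders):

```python
# Write the path of every image under a directory, one per line.
import os

def write_flist(image_dir, flist_path, exts=('.png', '.jpg', '.jpeg')):
    paths = sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(image_dir)
        for name in names
        if name.lower().endswith(exts)
    )
    with open(flist_path, 'w') as f:
        f.write('\n'.join(paths))  # one path per line, no trailing line break after the last path

write_flist('/path/to/dataset/train', 'train_flist.txt')
write_flist('/path/to/dataset/valid', 'valid_flist.txt')
```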
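A rough illustration of the CelebA-HQ preprocessing described above (resize to 266x266, then a random 256x256 crop). This is not the repository's actual `data/data_loader.py`; the PNG input format and the normalization to [-1, 1] are assumptions.

```python
import tensorflow as tf  # written against TF 1.x

def load_and_crop_celebahq(image_path):
    raw = tf.read_file(image_path)
    img = tf.image.decode_png(raw, channels=3)     # assumes PNG inputs
    img = tf.image.resize_images(img, [266, 266])  # resize to 266x266 (returns float32)
    img = tf.random_crop(img, [256, 256, 3])       # random 256x256 crop
    img = img / 127.5 - 1.0                        # assumed normalization to [-1, 1]
    return img
```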
## Testing
- Collect the testing set. For CelebA-HQ, we resize each image to 256x256. For Places2 and ImageNet, we crop a center 256x256 patch.
- Collect the corresponding mask set (2D grayscale, where 0 indicates the known region and 255 indicates the missing region); see the mask sketch after this list.
- Prepare the image file list and the mask file list as in training. An example can be seen here.
- Modify the `checkpoints_dir`, `dataset`, `img_flist` and `mask_flist` arguments in `test.py`.
- Download the pre-trained model and put `model.ckpt.meta`, `model.ckpt.index`, `model.ckpt.data-00000-of-00001` and `checkpoint` under the `model_logs/` directory.
- Run `python test.py`.
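A minimal sketch (not part of this repository) for creating one test mask in the expected format: 2D grayscale, 0 for known pixels, 255 for missing pixels. Here it produces a center 128x128 hole in a 256x256 image, matching the center_mask setting; the output file name is a placeholder.

```python
import numpy as np
from PIL import Image

mask = np.zeros((256, 256), dtype=np.uint8)  # 0 = known region
mask[64:192, 64:192] = 255                   # 255 = missing region (center 128x128 hole)
Image.fromarray(mask, mode='L').save('center_mask.png')
```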
## Pre-trained Models
Download the pre-trained models using the following links and put them under the `model_logs/` directory.

- center_mask model: CelebA-HQ_center | Places2_center | ImageNet_center
- random_mask model: CelebA-HQ_random | Places2_random | ImageNet_random
The center_mask models are trained with images of 256x256 resolution with center 128x128 holes. The random_mask models are trained with random regular and irregular holes.
## Inference Time
One advantage of GAN-based and VAE-based methods is their fast inference speed. We measure that Mutual Encoder-Decoder with Feature Equalizations runs at 0.2 seconds per image on a single NVIDIA 1080 Ti GPU for images of resolution 256×256. In contrast, our model runs at 45 seconds per image. Naively sampling our autoregressive network accounts for most of this computation time. Fortunately, this time can be reduced by an order of magnitude using an incremental sampling technique which caches and reuses intermediate states of the network. Consider using this technique for faster inference.
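Below is a toy numpy sketch of the cache-and-reuse idea, not the repository's sampler: it uses a simple recurrent state instead of the convolutional activation caches used for PixelCNN-style networks. The naive sampler re-runs the network over the whole prefix for every new position (quadratic work), while the incremental sampler carries the cached state forward (linear work) and produces identical samples.

```python
import numpy as np

rng = np.random.RandomState(0)
W_h, W_x = rng.randn(8, 8) * 0.1, rng.randn(8, 1) * 0.1  # toy network weights

def step(h, x):
    """One network step: update the cached state from the previous state and input."""
    return np.tanh(W_h @ h + W_x @ x)

def naive_sample(T):
    xs = [np.zeros((1, 1))]
    for _ in range(T):
        h = np.zeros((8, 1))
        for x in xs:             # recompute every intermediate state from scratch
            h = step(h, x)
        xs.append(h[:1])         # next output from the current state
    return np.concatenate(xs[1:])

def incremental_sample(T):
    xs, h = [np.zeros((1, 1))], np.zeros((8, 1))
    for _ in range(T):
        h = step(h, xs[-1])      # reuse the cached state; only one new step per output
        xs.append(h[:1])
    return np.concatenate(xs[1:])

assert np.allclose(naive_sample(16), incremental_sample(16))
```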
## Citing
If our method is useful for your research, please consider citing:
```
@inproceedings{peng2021generating,
  title={Generating Diverse Structure for Image Inpainting With Hierarchical VQ-VAE},
  author={Peng, Jialun and Liu, Dong and Xu, Songcen and Li, Houqiang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  pages={10775-10784},
  year={2021}
}
```