Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding
by Qiaole Dong*, Chenjie Cao*, Yanwei Fu
Paper and Supplemental Material (arXiv)
Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.
🔥🔥🔥 News: Our extended version, ZITS++, has been accepted by TPAMI; the code and dataset have been released here.
Pipeline
The overview of our ZITS. First, the TSR model restores structures at low resolution. Then a simple CNN-based upsampler upsamples the edge and line maps. Finally, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore the textures.
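At a high level, the inference flow can be sketched as below. Note that the module interfaces (tsr, upsampler, ftr and their arguments) are illustrative placeholders for this sketch, not the actual classes or signatures in this repository:
```
import torch
import torch.nn.functional as F

def zits_inpaint(image, mask, tsr, upsampler, ftr):
    # 1) Restore edge and line structures at low resolution with the transformer (TSR).
    lr_image = F.interpolate(image, size=(256, 256), mode="bilinear", align_corners=False)
    lr_mask = F.interpolate(mask, size=(256, 256), mode="nearest")
    edge_lr, line_lr = tsr(lr_image, lr_mask)

    # 2) Upsample the recovered edge/line maps with the CNN-based structure upsampler (SSU).
    edge_hr = upsampler(edge_lr, size=image.shape[-2:])
    line_hr = upsampler(line_lr, size=image.shape[-2:])

    # 3) Encode the upsampled sketch space and add it into the texture restorer (FTR)
    #    through zero-initialized residual addition (ZeroRA).
    return ftr(image, mask, sketch=(edge_hr, line_hr))
```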
TO DO
- Releasing inference codes.
- Releasing pre-trained model.
- Releasing training codes.
Preparation
1. Prepare the environment:
Since there are some bugs when using the GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions under torch 1.9.0 for multi-GPU training:
```
conda create -n train_env python=3.6
conda activate train_env
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirement.txt
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --no-build-isolation ./
```
2. For training, MST provides irregular and segmentation masks (download) with different masking rates, and you should define the mask file lists before training, as in MST (a sketch for building such a list is given after this list).
The training masks we used are listed in coco_mask_list.txt and irregular_mask_list.txt; in addition, test_mask.zip contains 1000 test masks.
3. Download the pretrained masked wireframe detection model to the './ckpt' folder: LSM-HAWP (MST ICCV 2021, retrained from HAWP CVPR 2020).
4. Prepare the wireframes:
<!-- As MST trains the LSM-HAWP in PyTorch 1.3.1, which causes a problem ([link](https://github.com/cherubicXN/hawp/issues/31)) when tested in PyTorch 1.9, we originally recommended inferring the lines (wireframes) with torch==1.3.1. If the line detection is not based on torch 1.3.1, the performance may drop a little.
```
conda create -n wireframes_inference_env python=3.6
conda activate wireframes_inference_env
pip install torch==1.3.1 torchvision==0.4.2
pip install -r requirement.txt
```
-->
Update: there is no need to prepare another environment anymore; just extract the wireframes with the following command:
```
conda activate train_env
python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
```
5. If you need to train the model, please download the pretrained models for perceptual loss, provided by LaMa:
```
mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
```
6. Indoor dataset and the test set of Places2 (optional):
To download the full Indoor dataset: BaiduDrive (password: hfok); Google Drive (link).
The training and validation splits of Indoor can be found in indoor_train_list.txt and indoor_val_list.txt.
The test set of our Places2 can be found in places2_test_list.txt.
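Regarding step 2, the mask file lists are plain text files listing one mask image path per line (check the provided coco_mask_list.txt and irregular_mask_list.txt for the exact format). A minimal sketch for generating such a list, assuming your masks are PNG files; write_mask_list and the directory name are only illustrative:
```
# Hypothetical helper: collect mask image paths into a text file, one path per line.
import glob
import os

def write_mask_list(mask_dir, out_txt):
    paths = sorted(glob.glob(os.path.join(mask_dir, "**", "*.png"), recursive=True))
    with open(out_txt, "w") as f:
        f.write("\n".join(paths))

# Example usage (adjust the directory to wherever you unpacked the MST masks).
write_mask_list("./irregular_masks", "irregular_mask_list.txt")
```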
Eval
Download pretrained models on Places2 here.
Link for BaiduDrive, password: qnm5
Batch Test
For the batch test, you need to complete steps 3 and 4 above.
Put the pretrained models in the './ckpt' folder, then modify the config file according to your image, mask, and wireframe paths.
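To double-check which entries need editing, you can load the config and print the path-like keys. This is only a convenience sketch; it assumes PyYAML is installed and uses a keyword heuristic rather than the exact key names, so refer to config_ZITS_places2.yml itself for the authoritative fields:
```
# Print config entries that look like image/mask/wireframe path settings.
import yaml

with open("./config_list/config_ZITS_places2.yml") as f:
    cfg = yaml.safe_load(f)

for key, value in cfg.items():
    if any(word in str(key).lower() for word in ("flist", "mask", "line", "path")):
        print(f"{key}: {value}")
```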
Test on 256×256 images:
```
conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'
```
Test on 512×512 images:
```
conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'
```
Single Image Test
This script only supports square images; non-square inputs will be center cropped.
```
conda activate train_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
    --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./
```
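If your input image is not square, you can pre-crop it before running the script. A minimal sketch using Pillow; the 512×512 target size and file names are only examples:
```
from PIL import Image

def center_crop_square(path_in, path_out, size=512):
    # Center-crop to the shortest edge, then resize to size x size.
    img = Image.open(path_in).convert("RGB")
    side = min(img.size)
    left = (img.width - side) // 2
    top = (img.height - side) // 2
    img = img.crop((left, top, left + side, top + side)).resize((size, size), Image.BICUBIC)
    img.save(path_out)

center_crop_square("./image.png", "./image_square.png")
```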
Training
:warning: Warning: the training code has not been fully tested yet after refactoring.
Training TSR
```
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
    --train_line_path [training_wireframes_path] \
    --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
    --train_epoch 12 --validation_path [validation_data_path] \
    --val_line_path [validation_wireframes_path] \
    --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
```
Then continue training with the masking positional encoding enabled (--MaP):
```
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
    --train_line_path [training_wireframes_path] \
    --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
    --train_epoch 15 --validation_path [validation_data_path] \
    --val_line_path [validation_wireframes_path] \
    --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP
```
Train SSU
We recommend using the pretrained SSU. You can also train your own SSU by referring to https://github.com/ewrfcas/StructureUpsampling.
Training LaMa First
```
python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
    --config_file ./config_list/config_LAMA.yml --lama
```
Training FTR
256×256:
```
python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
    --config_file ./config_list/config_ZITS_places2.yml --DDP
```
256×256 to 512×512:
```
python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
    --config_file ./config_list/config_ZITS_HR_places2.yml --DDP
```
More 1K Results
Acknowledgments
Cite
If you find our work helpful, please consider citing:
```
@inproceedings{dong2022incremental,
  title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding},
  author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}
```