Cross Aggregation Transformer for Image Restoration
Zheng Chen, Yulun Zhang, Jinjin Gu, Yongbing Zhang, Linghe Kong, and Xin Yuan, "Cross Aggregation Transformer for Image Restoration", NeurIPS, 2022 (Spotlight)
[paper] [arXiv] [supplementary material] [visual results] [pretrained models]
Abstract: Recently, Transformer architecture has been introduced into image restoration to replace convolution neural network (CNN) with surprising results. Considering the high computational complexity of Transformer with global attention, some methods use the local square window to limit the scope of self-attention. However, these methods lack direct interaction among different windows, which limits the establishment of long-range dependencies. To address the above issue, we propose a new image restoration model, Cross Aggregation Transformer (CAT). The core of our CAT is the Rectangle-Window Self-Attention (Rwin-SA), which utilizes horizontal and vertical rectangle window attention in different heads parallelly to expand the attention area and aggregate the features cross different windows. We also introduce the Axial-Shift operation for different window interactions. Furthermore, we propose the Locality Complementary Module to complement the self-attention mechanism, which incorporates the inductive bias of CNN (e.g., translation invariance and locality) into Transformer, enabling global-local coupling. Extensive experiments demonstrate that our CAT outperforms recent state-of-the-art methods on several image restoration applications.
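To make the idea of rectangle windows concrete, here is a minimal pure-Python sketch (not the authors' implementation) that counts how many pixels one position can attend to when horizontal and vertical rectangle windows are combined, compared with a square window of the same area. The function names and the 16x16 map with 2x8, 8x2, and 4x4 windows are hypothetical choices made only for illustration.

```python
def window_ids(height, width, win_h, win_w):
    """Return a height x width grid where each entry is the index of the
    non-overlapping win_h x win_w window that contains that pixel."""
    if height % win_h or width % win_w:
        raise ValueError("feature map must be divisible by the window size")
    windows_per_row = width // win_w
    return [
        [(r // win_h) * windows_per_row + (c // win_w) for c in range(width)]
        for r in range(height)
    ]


def attended_pixels(height, width, row, col, win_h, win_w):
    """Set of (r, c) positions that share a window with pixel (row, col)."""
    ids = window_ids(height, width, win_h, win_w)
    target = ids[row][col]
    return {(r, c) for r in range(height) for c in range(width) if ids[r][c] == target}


# Hypothetical sizes: a 16x16 feature map, all windows cover 16 pixels.
H = W = 16
square = attended_pixels(H, W, 5, 5, win_h=4, win_w=4)      # square window
horizontal = attended_pixels(H, W, 5, 5, win_h=2, win_w=8)  # wide rectangle
vertical = attended_pixels(H, W, 5, 5, win_h=8, win_w=2)    # tall rectangle

# Different heads use different rectangle orientations, so together they see the union.
combined = horizontal | vertical
print(len(square), len(horizontal), len(vertical), len(combined))  # 16 16 16 28
```

With equal per-window cost, the two rectangle orientations together reach 28 positions around pixel (5, 5), while the square window reaches 16.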
SR (x4) | HQ | LQ | SwinIR | CAT (ours) |
---|---|---|---|---|
<img src="figs/img_024_x4.png" height=80 width=110/> | <img src="figs/img_024_HR_x4.png" height=80/> | <img src="figs/img_024_Bicubic_x4.png" height=80/> | <img src="figs/img_024_SwinIR_x4.png" height=80/> | <img src="figs/img_024_CAT_x4.png" height=80/> |
<img src="figs/img_074_x4.png" height=80 width=110/> | <img src="figs/img_074_HR_x4.png" height=80/> | <img src="figs/img_074_Bicubic_x4.png" height=80/> | <img src="figs/img_074_SwinIR_x4.png" height=80/> | <img src="figs/img_074_CAT_x4.png" height=80/> |
Dependencies
- Python 3.8
- PyTorch 1.8.0
- NVIDIA GPU + CUDA
# Clone the GitHub repo and go to the default directory 'CAT'.
git clone https://github.com/zhengchen1999/CAT.git
cd CAT
conda create -n CAT python=3.8
conda activate CAT
pip install -r requirements.txt
python setup.py develop
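After installation, a quick sanity check can confirm that the interpreter version and PyTorch are visible in the active environment. The sketch below uses only the standard library; the function name `check_environment` is hypothetical and not part of this repository.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 8), packages=("torch",)):
    """Report whether the interpreter is recent enough and the packages are importable."""
    report = {"python_ok": sys.version_info[:2] >= tuple(min_python)}
    for name in packages:
        # find_spec returns None when a top-level package cannot be located.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, ok in check_environment().items():
        print(f"{key}: {'ok' if ok else 'MISSING'}")
```

This only checks that the package can be found; it does not verify the PyTorch version or CUDA availability.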
TODO
- Image SR
- JPEG Compression Artifact Reduction
- Image Denoising
- Other tasks
Datasets
The training and testing sets used in this work can be downloaded as follows:
Task | Training Set | Testing Set | Visual Results |
---|---|---|---|
image SR | DIV2K (800 training images, 100 validation images) + Flickr2K (2650 images) [complete training dataset DF2K] | Set5 + Set14 + BSD100 + Urban100 + Manga109 [complete testing dataset download] | here |
grayscale JPEG compression artifact reduction | DIV2K (800 training images) + Flickr2K (2650 images) + WED (4744 images) + BSD500 (400 training & testing images) [complete training dataset DFWB] | Classic5 + LIVE1 + Urban100 [complete testing dataset download] | here |
real image denoising | SIDD (320 training images) [complete training dataset SIDD] | SIDD + DND [complete testing dataset download] | here |
Here the visual results are generated under SR (x4), JPEG compression artifact reduction (q10), and real image denoising.
Download the training and testing datasets and put them into the corresponding folders of `datasets/` and `restormer/datasets`. See datasets for details of the directory structure.
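Before launching a run, it can help to confirm that the two dataset folders exist. This is a minimal sketch under the assumption that it is run from the repository root; `missing_dirs` is a hypothetical helper, and it does not check the sub-folder layout, which is documented separately.

```python
from pathlib import Path

# Folder names taken from the sentence above; anything deeper is left unchecked.
EXPECTED_DIRS = ("datasets", "restormer/datasets")


def missing_dirs(root, expected=EXPECTED_DIRS):
    """Return the entries of `expected` that are not directories under `root`."""
    root = Path(root)
    return [name for name in expected if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_dirs(".")
    print("all dataset folders found" if not missing else f"missing: {missing}")
```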
Models
Task | Method | Params (M) | FLOPs (G) | Dataset | PSNR (dB) | SSIM | Model Zoo | Visual Results |
---|---|---|---|---|---|---|---|---|
SR | CAT-R | 16.60 | 292.7 | Urban100 | 27.45 | 0.8254 | Google Drive | Google Drive |
SR | CAT-A | 16.60 | 360.7 | Urban100 | 27.89 | 0.8339 | Google Drive | Google Drive |
SR | CAT-R-2 | 11.93 | 216.3 | Urban100 | 27.59 | 0.8285 | Google Drive | Google Drive |
SR | CAT-A-2 | 16.60 | 387.9 | Urban100 | 27.99 | 0.8357 | Google Drive | Google Drive |
CAR | CAT | 16.20 | 346.4 | LIVE1 | 29.89 | 0.8295 | Google Drive | Google Drive |
real-DN | CAT | 25.77 | 53.2 | SIDD | 40.01 | 0.9600 | Google Drive | Google Drive |
The performance is reported on Urban100 (x4, SR), LIVE1 (q=10, CAR), and SIDD (real-DN). FLOPs are measured with a 128 x 128 test input.
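For readers unfamiliar with the PSNR column, the sketch below computes the textbook definition, 10 * log10(MAX^2 / MSE), on tiny nested-list images. It is an illustration only: the evaluation protocol used by this repository (for example color space or border cropping) is set in its configuration files and may differ, and the function name `psnr` is a hypothetical helper.

```python
import math


def psnr(reference, restored, max_value=255.0):
    """PSNR in dB between two equally sized images given as nested lists."""
    ref = [p for row in reference for p in row]
    out = [p for row in restored for p in row]
    if len(ref) != len(out):
        raise ValueError("images must have the same number of pixels")
    mse = sum((a - b) ** 2 for a, b in zip(ref, out)) / len(ref)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


clean = [[100, 110], [120, 130]]
noisy = [[101, 109], [121, 129]]     # every pixel is off by one, so MSE = 1
print(round(psnr(clean, noisy), 2))  # 48.13
```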
Training
Image SR
- Cd to 'CAT' and run the setup script.

      # If already in CAT and set up, please ignore
      python setup.py develop

- Download training (DF2K, already processed) and testing (Set5, Set14, BSD100, Urban100, Manga109, already processed) datasets, place them in `datasets/`.

- Run the following scripts. The training configuration is in `options/train/`.

      # CAT-R, SR, input=64x64, 4 GPUs
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_sr_x2.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_sr_x3.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_sr_x4.yml --launcher pytorch

      # CAT-A, SR, input=64x64, 4 GPUs
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_sr_x2.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_sr_x3.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_sr_x4.yml --launcher pytorch

      # CAT-R-2, SR, input=64x64, 4 GPUs
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_2_sr_x2.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_2_sr_x3.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_R_2_sr_x4.yml --launcher pytorch

      # CAT-A-2, SR, input=64x64, 4 GPUs
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_2_sr_x2.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_2_sr_x3.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_A_2_sr_x4.yml --launcher pytorch

- The training experiment is in `experiments/`.
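The SR training commands above differ only in the model variant and the scale factor, so they can be generated rather than typed. The following sketch builds the same command lines with the standard library; `train_command` is a hypothetical helper, not a script shipped with the repository, and the GPU count and port simply mirror the values shown above.

```python
import shlex


def train_command(variant, scale, gpus=4, port=4321):
    """Build the launch command for one SR variant and scale as an argument list."""
    config = f"options/train/train_{variant}_sr_x{scale}.yml"
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={gpus}", f"--master_port={port}",
        "basicsr/train.py", "-opt", config, "--launcher", "pytorch",
    ]


# Print all twelve commands (four variants, three scales) for review.
for variant in ("CAT_R", "CAT_A", "CAT_R_2", "CAT_A_2"):
    for scale in (2, 3, 4):
        print(shlex.join(train_command(variant, scale)))
```

Each returned list can be passed to `subprocess.run(cmd, check=True)` to run the jobs one after another.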
JPEG Compression Artifact Reduction
- Cd to 'CAT' and run the setup script

      # If already in CAT and set up, please ignore
      python setup.py develop

- Download training (DFWB, already processed) and testing (Classic5, LIVE1, Urban100, already processed) datasets, place them in `datasets/`.

- Run the following scripts. The training configuration is in `options/train/`.

      # CAT, CAR, input=128x128, 4 GPUs
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q10.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q20.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q30.yml --launcher pytorch
      python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_CAT_car_q40.yml --launcher pytorch

- The training experiment is in `experiments/`.
Real Image Denoising
- Cd to 'CAT/restormer' and run the setup script

      # If already in restormer and set up, please ignore
      python setup.py develop --no_cuda_ext

- Download training (SIDD-train, contains validation dataset, already processed) datasets, and place them in `datasets/` (`restormer/datasets/`).

- Run the following scripts. The training configuration is in `options/` (`restormer/options/`).

      # CAT, Real DN, Progressive Learning, 8 GPUs
      python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train_RealDenoising_CAT.yml --launcher pytorch

- The training experiment is in `experiments/` (`restormer/experiments/`).
Testing
Image SR
- Cd to 'CAT' and run the setup script

      # If already in CAT and set up, please ignore
      python setup.py develop

- Download the pre-trained models and place them in `experiments/pretrained_models/`.

  We provide pre-trained models for image SR: CAT-R, CAT-A, CAT-R-2, and CAT-A-2 (x2, x3, x4).

- Download testing (Set5, Set14, BSD100, Urban100, Manga109, already processed) datasets, place them in `datasets/`.

- Run the following scripts. The testing configuration is in `options/test/` (e.g., test_CAT_R_sr_x2.yml).

  Note 1: You can set `use_chop: True` (default: False) in YML to chop the image for testing.

      # No self-ensemble
      # CAT-R, SR, reproduces results in Table 2 of the main paper
      python basicsr/test.py -opt options/test/test_CAT_R_sr_x2.yml
      python basicsr/test.py -opt options/test/test_CAT_R_sr_x3.yml
      python basicsr/test.py -opt options/test/test_CAT_R_sr_x4.yml

      # CAT-A, SR, reproduces results in Table 2 of the main paper
      python basicsr/test.py -opt options/test/test_CAT_A_sr_x2.yml
      python basicsr/test.py -opt options/test/test_CAT_A_sr_x3.yml
      python basicsr/test.py -opt options/test/test_CAT_A_sr_x4.yml

      # CAT-R-2, SR, reproduces results in Table 1 of the supplementary material
      python basicsr/test.py -opt options/test/test_CAT_R_2_sr_x2.yml
      python basicsr/test.py -opt options/test/test_CAT_R_2_sr_x3.yml
      python basicsr/test.py -opt options/test/test_CAT_R_2_sr_x4.yml

      # CAT-A-2, SR, reproduces results in Table 1 of the supplementary material
      python basicsr/test.py -opt options/test/test_CAT_A_2_sr_x2.yml
      python basicsr/test.py -opt options/test/test_CAT_A_2_sr_x3.yml
      python basicsr/test.py -opt options/test/test_CAT_A_2_sr_x4.yml

- The output is in `results/`.
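The `use_chop` note above refers to splitting a large test image into smaller pieces so that inference fits in memory. How this repository splits and merges the pieces is not described here, so the following is only a generic sketch of overlapping tiling; the function name `tile_boxes` and the tile and overlap sizes are assumptions chosen for illustration.

```python
def tile_boxes(height, width, tile, overlap):
    """Return (top, left, bottom, right) boxes that cover a height x width image
    with tile x tile crops, neighbouring crops sharing at least `overlap` pixels."""
    if not 0 <= overlap < tile:
        raise ValueError("overlap must satisfy 0 <= overlap < tile")

    def starts(length):
        if length <= tile:
            return [0]  # a single (smaller) tile covers this axis
        stride = tile - overlap
        positions = list(range(0, length - tile, stride))
        positions.append(length - tile)  # last tile is flush with the border
        return positions

    return [
        (top, left, min(top + tile, height), min(left + tile, width))
        for top in starts(height)
        for left in starts(width)
    ]


# Hypothetical example: a 100x160 image covered by 64x64 tiles.
boxes = tile_boxes(height=100, width=160, tile=64, overlap=8)
print(len(boxes))  # 6
```

Each box would be restored independently and the overlapping regions blended or cropped when the outputs are stitched back together.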
JPEG Compression Artifact Reduction
- Cd to 'CAT' and run the setup script

      # If already in CAT and set up, please ignore
      python setup.py develop

- Download the pre-trained models and place them in `experiments/pretrained_models/`.

  We provide pre-trained models for JPEG compression artifact reduction: CAT (q10, q20, q30, q40).

- Download testing (Classic5, LIVE1, Urban100, already processed) datasets, place them in `datasets/`.

- Run the following scripts. The testing configuration is in `options/test/` (e.g., test_CAT_car_q10.yml).

      # No self-ensemble
      # CAT-A, CAR, reproduces results in Table 3 of the main paper
      python basicsr/test.py -opt options/test/test_CAT_car_q10.yml
      python basicsr/test.py -opt options/test/test_CAT_car_q20.yml
      python basicsr/test.py -opt options/test/test_CAT_car_q30.yml
      python basicsr/test.py -opt options/test/test_CAT_car_q40.yml

- The output is in `results/`.
Real Image Denoising
- Cd to 'CAT' and run the setup script

      # If already in CAT and set up, please ignore
      python setup.py develop

- Download the pre-trained models and place them in `experiments/pretrained_models/`.

- Download testing (SIDD, DND) datasets, place them in `datasets/`.

- Run the following scripts. The testing configuration is in `options/test/`.

      # No self-ensemble
      # CAT, real DN, reproduces results in Table 4 of the main paper
      # testing on SIDD
      python test_real_denoising_sidd.py --save_images
      evaluate_sidd.m
      # testing on DND
      python test_real_denoising_dnd.py --save_images

- The output is in `results/`.
Results
We achieve state-of-the-art performance on image SR, JPEG compression artifact reduction and real image denoising. Detailed results can be found in the paper. All visual results of CAT can be downloaded here.
<details>
<summary>Image SR (click to expand)</summary>

- results in Table 2 of the main paper
- results in Table 1 of the supplementary material
- visual comparison (x4) in the main paper
- visual comparison (x4) in the supplementary material

</details>

<details>
<summary>JPEG Compression Artifact Reduction (click to expand)</summary>

- results in Table 3 of the main paper
- results in Table 3 of the supplementary material (test on Urban100)
- visual comparison (q=10) in the main paper
- visual comparison (q=10) in the supplementary material

</details>

<details>
<summary>Real Image Denoising (click to expand)</summary>

- results in Table 4 of the main paper

*: We re-test the SIDD with all official pre-trained models.

</details>

Citation
If you find the code helpful in your research or work, please cite the following paper(s).
@inproceedings{chen2022cross,
title={Cross Aggregation Transformer for Image Restoration},
author={Chen, Zheng and Zhang, Yulun and Gu, Jinjin and Zhang, Yongbing and Kong, Linghe and Yuan, Xin},
booktitle={NeurIPS},
year={2022}
}