SAM-DiffSR: Structure-Modulated Diffusion Model for Image Super-Resolution

This is the official implementation of the paper "SAM-DiffSR: Structure-Modulated Diffusion Model for Image Super-Resolution".

<div align=center> <img src="./README.assets/image-20240218110908900.png" alt="image-20240218110908900" width="80%" /> </div>

Abstract

Conventional diffusion models sample noise from a single distribution, constraining their ability to handle real-world scenes and complex textures across semantic regions. With the success of the segment anything model (SAM), generating sufficiently fine-grained region masks can enhance the detail recovery of diffusion-based SR models. However, directly integrating SAM into SR models results in a much higher computational cost. We propose SAM-DiffSR, which utilizes the fine-grained structural information from SAM during noise sampling to improve image quality without additional computational cost at inference. During training, we encode structural position information into the segmentation mask from SAM. The encoded mask is then integrated into the forward diffusion process by modulating it onto the sampled noise. This adjustment allows us to independently adapt the noise mean within each corresponding segmentation area. The diffusion model is trained to estimate this modulated noise. Crucially, our proposed framework does NOT change the reverse diffusion process and does NOT require SAM at inference.
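The core idea in the abstract (shifting the noise mean independently within each segmentation region during forward diffusion) can be sketched in a few lines. This is a minimal NumPy illustration of the mechanism as described, not the repository's implementation; the function name and signature are assumptions.

```python
import numpy as np

def modulated_forward_diffusion(x0, mask_embed, alpha_bar_t, rng=None):
    """Sketch of structure-modulated noise sampling: a SAM-derived mask
    embedding (constant within each region) shifts the mean of the
    sampled noise before the standard DDPM forward step."""
    if rng is None:
        rng = np.random.default_rng(0)
    eps = rng.standard_normal(x0.shape)   # standard Gaussian noise
    eps_mod = eps + mask_embed            # per-region mean shift from the encoded mask
    # usual DDPM forward process, applied to the modulated noise
    x_t = np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * eps_mod
    return x_t, eps_mod                   # the model learns to predict eps_mod
```

At inference nothing changes: the reverse process runs as in a standard diffusion SR model, so SAM and the mask embedding are only needed at training time.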

Results

| SAM-DiffSR | Set5 | Set14 | Urban100 | BSDS100 | Manga109 | General100 | DIV2K |
| --- | --- | --- | --- | --- | --- | --- | --- |
| PSNR | 30.99 | 27.14 | 25.54 | 26.47 | 29.43 | 30.30 | 29.34 |
| SSIM | 0.8731 | 0.7484 | 0.7721 | 0.7003 | 0.8899 | 0.8353 | 0.8109 |
| FID | 48.20 | 49.84 | 4.5276 | 60.81 | 2.3994 | 38.42 | 0.3809 |
<div align=center> <img src="README.assets/vis-1.png" width="80%" /> </div> <div align=center> <img src="README.assets/vis-2.png" width="80%" /> </div>

Data and Checkpoint

| info | link |
| --- | --- |
| segment mask data in RLE format, generated by SAM | CowTransfer or Google Drive |
| embedded mask in npy format, generated by SPE | CowTransfer or Google Drive |
| model checkpoint | CowTransfer or Google Drive |

Environment Installation

    pip install -r requirements.txt

Dataset Preparation

Training dataset

  1. Download the DF2K training set and the DIV2K validation set.

    Make the data tree look like this:

    data/sr
    ├── DF2K
    │   └── DF2K_train_HR
    │       ├── 0001.png
    │       ├── 0002.png
    │       ├── 0003.png
    │       ├── ...
    └── DIV2K
        └── DIV2K_valid_HR
            ├── 0001.png
            ├── 0002.png
            ├── 0003.png
            ├── ...
    
  2. Generate SAM masks

    • Use processed data:

      Download the mask data from CowTransfer or Google Drive.

    • Or, generate the data from scratch:

      1. Download the segment-anything code and the *vit_h* checkpoint.

        git clone https://github.com/facebookresearch/segment-anything.git
        
      2. Generate mask data in RLE format with SAM:

        python scripts/amg.py \
        --checkpoint weights/sam_vit_h_4b8939.pth \
        --model-type vit_h \
        --input data/sr/DF2K/DF2K_train_HR \
        --output data/sam_out/DF2K/DF2K_train_HR \
        --convert-to-rle
        
      3. Use SPE to embed the RLE-format masks:

        python scripts/merge_mask_to_one.py \
        --input data/sam_out/DF2K/DF2K_train_HR \
        --output data/sam_embed/DF2K/DF2K_train_HR
        
  3. Build the binary dataset:

    python data_gen/df2k.py --config configs/data/df2k4x_sam.yaml
    
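For intuition about what the mask-generation step produces, the sketch below decodes uncompressed COCO-style RLE masks (column-major run lengths, starting with background) and merges them into one integer label map per image. This is a conceptual illustration only; the helper names are hypothetical, and the actual pipeline relies on the repository's scripts and the segment-anything tooling rather than this code.

```python
import numpy as np

def decode_rle(counts, shape):
    """Decode an uncompressed COCO-style RLE mask. Runs alternate between
    background (0) and foreground (1) in column-major order."""
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    pos, val = 0, 0
    for run in counts:
        flat[pos:pos + run] = val
        pos += run
        val = 1 - val
    return flat.reshape(shape, order="F")  # COCO RLE is column-major

def merge_masks(rle_masks, shape):
    """Merge per-object masks into a single label map: each region gets a
    distinct integer id, with 0 meaning unsegmented background."""
    label_map = np.zeros(shape, dtype=np.int32)
    for idx, counts in enumerate(rle_masks, start=1):
        m = decode_rle(counts, shape)
        label_map[m == 1] = idx  # later masks overwrite earlier ones
    return label_map
```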

Benchmark dataset

  1. Download the benchmark datasets, e.g. Set5, Set14, Urban100, Manga109, BSDS100.

  2. Change the data_name and data_path in data_gen/benchmark.py, then run:

    python data_gen/benchmark.py --config configs/data/df2k4x_sam.yaml
    

Training

  1. Download the pretrained RRDB model from CowTransfer or Google Drive, and move the weights to ./weights/rrdb_div2k.ckpt

  2. Train the diffusion model:

    python tasks/trainer.py \
    --config configs/sam/sam_diffsr_df2k4x.yaml \
    --exp_name sam_diffsr_df2k4x \
    --reset \
    --hparams="rrdb_ckpt=weights/rrdb_div2k.ckpt" \
    --work_dir exp/
    

Evaluation

Inference

python tasks/infer.py \
--config configs/sam/sam_diffsr_df2k4x.yaml \
--img_dir your/lr/img/path \
--save_dir your/sr/img/save/path \
--ckpt_path model_ckpt_steps_400000.ckpt

Citation

If you find this project useful in your research, please consider citing:

@article{wang2024sam,
  title={SAM-DiffSR: Structure-Modulated Diffusion Model for Image Super-Resolution},
  author={Wang, Chengcheng and Hao, Zhiwei and Tang, Yehui and Guo, Jianyuan and Yang, Yujie and Han, Kai and Wang, Yunhe},
  journal={arXiv preprint arXiv:2402.17133},
  year={2024}
}

Acknowledgement

The implementation is based on LeiaLi/SRDiff. Thanks for their open-source code.