
<div align="center"> <h1> What Matters When Repurposing Diffusion Models for General Dense Perception Tasks?</h1>

Former Title: "Diffusion Models Trained with Large Data Are Transferable Visual Models"

Guangkai Xu,   Yongtao Ge,   Mingyu Liu,   Chengxiang Fan,  <br> Kangyang Xie,   Zhiyue Zhao,   Hao Chen,   Chunhua Shen,  

Zhejiang University

HuggingFace (Space) | HuggingFace (Model) | arXiv

🔥 Fine-tune diffusion models for perception tasks, and run inference in only one step! ✈️

</div> <div align="center"> <img width="800" alt="image" src="figs/pipeline.jpg"> </div>

📢 News

📚 Download Resource Summary

🖥️ Dependencies

conda create -n genpercept python=3.10
conda activate genpercept
pip install -r requirements.txt
pip install -e .
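Optionally, you can confirm that PyTorch was installed with GPU support before moving on (a quick sanity check, not part of the official setup):

# Optional sanity check: print the installed PyTorch version and CUDA availability
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"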

🚀 Inference

Using Command-line Scripts

Download the stable-diffusion-2-1 backbone and our trained models from Hugging Face, and put the checkpoints under ./pretrained_weights/ and ./weights/, respectively. You can download them with the scripts script/download_sd21.sh and script/download_weights.sh, or download the weights for depth, normal, Dichotomous Image Segmentation (DIS), matting, segmentation, disparity, and disparity_dpt_head separately.
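For example, both helper scripts can be run from the repository root (a minimal sketch; they fetch exactly the checkpoints described above):

# Download the SD 2.1 backbone into ./pretrained_weights/
source script/download_sd21.sh
# Download the trained GenPercept models into ./weights/
source script/download_weights.sh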

Next, place images in the ./input/ directory. We offer demo images on Hugging Face, which you can also download with the script script/download_sample_data.sh. Then, run inference with the scripts below.

# Depth
source script/infer/main_paper/inference_genpercept_depth.sh
# Normal
source script/infer/main_paper/inference_genpercept_normal.sh
# Dichotomous Image Segmentation (DIS)
source script/infer/main_paper/inference_genpercept_dis.sh
# Matting
source script/infer/main_paper/inference_genpercept_matting.sh
# Seg
source script/infer/main_paper/inference_genpercept_seg.sh
# Disparity
source script/infer/main_paper/inference_genpercept_disparity.sh
# Disparity_dpt_head
source script/infer/main_paper/inference_genpercept_disparity_dpt_head.sh

If you would like to change the input folder path, unet path, or output path, pass these parameters as follows:

# Assign values
input_rgb_dir=...
unet=...
output_dir=...
# Take depth as an example
source script/infer/main_paper/inference_genpercept_depth.sh $input_rgb_dir $unet $output_dir
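For instance, a filled-in call might look like the following (the unet and output paths are illustrative placeholders, not files shipped with the repository):

# Illustrative values; adjust to your local layout
input_rgb_dir=./input
unet=./weights/genpercept_depth   # placeholder path to a downloaded depth checkpoint
output_dir=./output/depth
source script/infer/main_paper/inference_genpercept_depth.sh $input_rgb_dir $unet $output_dir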

For a general-purpose inference script, please refer to script/infer/inference_general.sh for details.

Thanks to our one-step perception paradigm, inference runs much faster than multi-step diffusion inference (around 0.4 s per image on an A800 GPU).

Using torch.hub

TODO

<!--
GenPercept models can be easily used with torch.hub for quick integration into your Python projects. Here's how to use the models for normal estimation, depth estimation, and segmentation:

#### Normal Estimation

```python
import torch
import cv2
import numpy as np

# Load the normal predictor model from torch hub
normal_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Normal", trust_repo=True)

# Load the input image using OpenCV
image = cv2.imread("path/to/your/image.jpg", cv2.IMREAD_COLOR)

# Use the model to infer the normal map from the input image
with torch.inference_mode():
    normal = normal_predictor.infer_cv2(image)

# Save the output normal map to a file
cv2.imwrite("output_normal_map.png", normal)
```

#### Depth Estimation

```python
import torch
import cv2

# Load the depth predictor model from torch hub
depth_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Depth", trust_repo=True)

# Load the input image using OpenCV
image = cv2.imread("path/to/your/image.jpg", cv2.IMREAD_COLOR)

# Use the model to infer the depth map from the input image
with torch.inference_mode():
    depth = depth_predictor.infer_cv2(image)

# Save the output depth map to a file
cv2.imwrite("output_depth_map.png", depth)
```

#### Segmentation

```python
import torch
import cv2

# Load the segmentation predictor model from torch hub
seg_predictor = torch.hub.load("hugoycj/GenPercept-hub", "GenPercept_Segmentation", trust_repo=True)

# Load the input image using OpenCV
image = cv2.imread("path/to/your/image.jpg", cv2.IMREAD_COLOR)

# Use the model to infer the segmentation map from the input image
with torch.inference_mode():
    segmentation = seg_predictor.infer_cv2(image)

# Save the output segmentation map to a file
cv2.imwrite("output_segmentation_map.png", segmentation)
```
-->

🔥 Train

NOTE: We implement training with the accelerate library, but we observe worse training accuracy with multiple GPUs than with a single GPU, even with the same effective_batch_size and max_iter. Your assistance in resolving this issue would be greatly appreciated. Thank you very much!
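For context, a single-GPU versus multi-GPU comparison with accelerate might look like the sketch below; the entry point and config path are hypothetical placeholders, since the actual launch commands live inside the released training scripts:

# Hypothetical accelerate launches; "train.py" and the config path are placeholders
accelerate launch --num_processes 1 train.py --config config/<your_config>.yaml   # single GPU
accelerate launch --num_processes 4 train.py --config config/<your_config>.yaml   # multiple GPUs (observed lower accuracy)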

Preparation

Datasets: TODO

Place the training datasets under datasets/.

Download stable-diffusion-2-1 from Hugging Face and put the checkpoints under ./pretrained_weights/. You can also download it with the script script/download_sd21.sh.

Start Training

The training scripts that reproduce the arXiv v3 paper are released in script/, with their configs stored in config/. Models with max_train_batch_size > 2 are trained on an H100 GPU, and those with max_train_batch_size <= 2 on an RTX 4090. Run a training script:

# Take the depth training in the main paper as an example
source script/train_sd21_main_paper/sd21_train_accelerate_genpercept_1card_ensure_depth_bs8_per_accu_pixel_mse_ssi_grad_loss.sh

🎖️ Eval

Preparation

  1. Download the evaluation datasets and place them in datasets_eval.
  2. Download our trained models of the main paper and of the ablation study in Section 3 of the arXiv v3 paper, and place them in weights/genpercept-exps (a layout sketch follows this list).
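A minimal layout sketch after preparation (the directory names follow the two steps above; the files you place inside are the downloads themselves):

# Create the expected directories, then move the downloads into them
mkdir -p datasets_eval weights/genpercept-exps
# datasets_eval/            <- evaluation datasets from step 1
# weights/genpercept-exps/  <- trained checkpoints from step 2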

Start Evaluation

The evaluation scripts are stored in script/eval_sd21.

# Take "ensemble1 + step1" as an example
source script/eval_sd21/eval_ensemble1_step1/0_infer_eval_all.sh

📖 Recommended Works

👍 Results in Paper

Depth and Surface Normal

<div align="center"> <img width="800" alt="image" src="figs/demo_depth_normal_new.jpg"> </div>

Dichotomous Image Segmentation

<div align="center"> <img width="800" alt="image" src="figs/demo_dis_new.jpg"> </div>

Image Matting

<div align="center"> <img width="800" alt="image" src="figs/demo_matting.jpg"> </div>

Image Segmentation

<div align="center"> <img width="800" alt="image" src="figs/demo_seg.jpg"> </div>

🎫 License

For non-commercial academic use, this project is licensed under the 2-clause BSD License. For commercial use, please contact Chunhua Shen.

🎓 Citation

@article{xu2024diffusion,
  title={What Matters When Repurposing Diffusion Models for General Dense Perception Tasks?},
  author={Xu, Guangkai and Ge, Yongtao and Liu, Mingyu and Fan, Chengxiang and Xie, Kangyang and Zhao, Zhiyue and Chen, Hao and Shen, Chunhua},
  journal={arXiv preprint arXiv:2403.06090},
  year={2024}
}