Efficient Diffusion Training via Min-SNR Weighting Strategy

By Tiankai Hang, Shuyang Gu, Chen Li, Jianmin Bao, Dong Chen, Han Hu, Xin Geng, Baining Guo.

Paper | arXiv | Code

Abstract.

Denoising diffusion models have become a mainstream approach for image generation; however, training these models often suffers from slow convergence. In this paper, we show that the slow convergence is partly due to conflicting optimization directions between timesteps. To address this issue, we treat diffusion training as a multi-task learning problem and introduce a simple yet effective approach referred to as Min-SNR-$\gamma$. This method adapts the loss weights of timesteps based on clamped signal-to-noise ratios, which effectively balances the conflicts among timesteps. Our results demonstrate a significant improvement in convergence speed, 3.4x faster than previous weighting strategies. It is also more effective, achieving a new record FID score of 2.06 on the ImageNet 256x256 benchmark using smaller architectures than those employed in the previous state of the art.
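As a rough illustration of the clamped weighting described in the abstract, the sketch below computes a per-timestep loss weight from the signal-to-noise ratio. It is a minimal sketch rather than the code in this repository: the function names `snr_from_alpha_bar` and `min_snr_weight` are hypothetical, the formulas follow the x0-prediction and noise-prediction cases as given in the paper, and the default gamma of 5 is assumed from the `min_snr_5` part of the config file name.

```python
def snr_from_alpha_bar(alpha_bar: float) -> float:
    """Signal-to-noise ratio for a timestep with cumulative alpha product alpha_bar."""
    # Assumes the usual forward process x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps.
    return alpha_bar / (1.0 - alpha_bar)


def min_snr_weight(snr: float, gamma: float = 5.0, prediction: str = "x0") -> float:
    """Min-SNR-gamma loss weight for a single timestep (hypothetical helper)."""
    clamped = min(snr, gamma)  # clamp the SNR so low-noise timesteps do not dominate
    if prediction == "x0":
        # Weight applied to the loss on the predicted clean image.
        return clamped
    if prediction == "eps":
        # Equivalent weight when the loss is computed on the predicted noise.
        return clamped / snr
    raise ValueError(f"unsupported prediction type: {prediction}")


# Example: weights for a nearly clean, a balanced, and a very noisy timestep.
for alpha_bar in (0.99, 0.5, 0.01):
    snr = snr_from_alpha_bar(alpha_bar)
    print(f"alpha_bar={alpha_bar}: snr={snr:.4f}, weight={min_snr_weight(snr):.4f}")
```

With this weighting, timesteps whose SNR exceeds gamma all receive the same weight, while noisier timesteps keep a weight equal to their SNR.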

News

Data Preparation

For the CelebA dataset, we follow ScoreSDE to process the data.

For the ImageNet dataset, we download it from the official website. For ImageNet-64, we do not apply any offline pre-processing. For ImageNet-256, we crop the images to 256x256 and compress them using AutoencoderKL from Diffusers. The compressed latent codes are handled in the same way as images, apart from the file extension.

Training

To train the ViT-B model, first place the downloaded and processed data from the step above in a directory of your choice, then set DATA_DIR in the config file vit-b_layer12_lr1e-4_099_099_pred_x0__min_snr_5__fp16_bs8x32.sh. You can then run:

GPUS=8
BATCH_SIZE_PER_GPU=32
bash configs/in256/vit-b_layer12_lr1e-4_099_099_pred_x0__min_snr_5__fp16_bs8x32.sh $GPUS $BATCH_SIZE_PER_GPU

Sampling with Pre-trained Models

To sample for ImageNet-256, run:

bash configs/in256/inference.sh

Thanks to the sampling method from "Applying Guidance in a Limited Interval Improves Sample and Distribution Quality in Diffusion Models", we achieve a new FID score of 1.57342 on the ImageNet 256x256 benchmark. To reproduce it, run:

bash configs/in256/inference_limited_interval_guidance.sh
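The idea behind limited-interval guidance can be sketched as follows: classifier-free guidance is applied only while the current noise level lies inside a chosen interval, and the plain conditional prediction is used elsewhere. This is a minimal sketch and not the implementation called by the script above; the function name `guided_prediction`, the parameter names, and the interval bounds are all hypothetical.

```python
from typing import List


def guided_prediction(
    cond: List[float],
    uncond: List[float],
    noise_level: float,
    scale: float,
    level_lo: float,
    level_hi: float,
) -> List[float]:
    """Apply classifier-free guidance only when level_lo < noise_level <= level_hi."""
    if level_lo < noise_level <= level_hi:
        # Inside the interval: extrapolate from the unconditional prediction
        # towards the conditional one by the guidance scale.
        return [u + scale * (c - u) for c, u in zip(cond, uncond)]
    # Outside the interval: guidance is switched off.
    return list(cond)


# Example with hypothetical values: guidance is active at noise level 1.0
# and inactive at noise level 10.0 for the interval (0.5, 2.0].
print(guided_prediction([1.0, 2.0], [0.0, 1.0], 1.0, 2.0, 0.5, 2.0))
print(guided_prediction([1.0, 2.0], [0.0, 1.0], 10.0, 2.0, 0.5, 2.0))
```

Outside the interval only the conditional prediction is needed, so a real sampler could skip the unconditional model evaluation at those steps.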

To sample for ImageNet-64, run:

bash configs/in64/inference.sh

Here we use 8 GPUs for sampling. For single-GPU evaluation, change GPUS=8 to GPUS=1 in configs/in256/inference.sh. The pre-trained models will be downloaded automatically, and FID-50K will be calculated.

Citing Min-SNR Diffusion Training

If you find our work useful for your research, please consider citing our paper.

@InProceedings{Hang_2023_ICCV,
    author    = {Hang, Tiankai and Gu, Shuyang and Li, Chen and Bao, Jianmin and Chen, Dong and Hu, Han and Geng, Xin and Guo, Baining},
    title     = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {7441-7451}
}

Acknowledgements

This repository is based on openai/guided-diffusion. The sampling and FID evaluation code is adopted from NVlabs/edm.