<p align="center">gDDIM: Generalized denoising diffusion implicit models</p>
<div align="center"> <a href="https://qsh-zh.github.io/" target="_blank">Qinsheng Zhang</a>   <b>·</b>   <a href="https://mtao8.math.gatech.edu/" target="_blank">Molei Tao</a>   <b>·</b>   <a href="https://yongxin.ae.gatech.edu/" target="_blank">Yongxin Chen</a> <br> <br> <a href="https://arxiv.org/abs/2206.05564" target="_blank">Paper</a>   </div> <br><br>TLDR: We explain where DDIM's acceleration comes from using a Dirac approximation, and we generalize it to general diffusion models, both isotropic and non-isotropic.
<!-- When applied to the critically-damped Langevin diffusion model, it achieves an FID score of 2.26 on CIFAR10 with 50 steps. -->
Setup
The codebase has only been tested in a Docker environment.
Docker
- The Dockerfile lists the steps and packages needed to set up the training and testing environments.
- We provide a Docker image on DockerHub.
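If you prefer building the environment from the Dockerfile yourself, the sketch below shows one way to build an image and open a shell in it. The function name gddim_docker, the image tag gddim, the /workspace mount point and the --gpus all flag (which needs the NVIDIA Container Toolkit on the host) are assumptions for illustration, not settings taken from this repository. Setting DRY_RUN=1 prints the commands instead of running them.

```sh
# Sketch only: the function name, the "gddim" tag, the /workspace mount and
# --gpus all are assumptions, not settings from this repository.
# --gpus all requires the NVIDIA Container Toolkit on the host.
# With DRY_RUN=1 the docker commands are printed instead of executed.
gddim_docker() {
  repo_dir=${1:-$PWD}   # repository root that contains the Dockerfile
  # Build the image; only start a container if the build succeeds.
  ${DRY_RUN:+echo} docker build -t gddim "$repo_dir" &&
    ${DRY_RUN:+echo} docker run --gpus all -it --rm \
      -v "$repo_dir:/workspace" -w /workspace gddim bash
}
```

Usage: gddim_docker "${gDDIM_PROJECT_FOLDER}". Inside the container the repository is then mounted at /workspace, so ${gDDIM_PROJECT_FOLDER} in the commands below would be /workspace.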
Reproduce results
gDDIM on CLD
Training on cifar10
cd ${gDDIM_PROJECT_FOLDER}/cld_jax
wandb login ${WANDB_KEY}
python main.py --config configs/accr_dcifar10_config.py --mode train --workdir logs/accr_dcifar_nomixed --wandb --config.seed=8
- I tried seeds 1, 8, and 123, chosen arbitrarily. Seed 8 (checkpoint 15) gives the best FID on CIFAR10, while the lowest FIDs from the other two seeds are slightly higher (around 2.30).
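To repeat the seed comparison, the training command can be run once per seed. The sketch below is a hypothetical helper (train_cld_seeds is not part of the repository) meant to be run from ${gDDIM_PROJECT_FOLDER}/cld_jax after wandb login. Giving each seed its own workdir (logs/accr_dcifar_nomixed_seed<seed>) is our assumption, so that runs do not overwrite each other. Setting DRY_RUN=1 prints the commands instead of launching them.

```sh
# Hypothetical helper: train CLD on CIFAR10 once per seed, sequentially.
# Assumption: one workdir per seed so checkpoints from different runs do not collide.
# With DRY_RUN=1 the commands are printed instead of executed.
train_cld_seeds() {
  for seed in "$@"; do
    ${DRY_RUN:+echo} python main.py --config configs/accr_dcifar10_config.py \
      --mode train --workdir "logs/accr_dcifar_nomixed_seed${seed}" \
      --wandb --config.seed="${seed}"
  done
}
```

Usage: train_cld_seeds 1 8 123. The runs execute one after another; launching them in parallel would need one GPU (or device set) per run.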
Eval on cifar10
- Download the CIFAR stats to ${gDDIM_PROJECT_FOLDER}/cld_jax/assets/stats/.
- We provide a pretrained model checkpoint. It achieves an FID of 2.2565 on my machine with 50 NFE.
- Users can evaluate FID via
cd ${gDDIM_PROJECT_FOLDER}/cld_jax
python main.py --config configs/accr_dcifar10_config.py --mode check --result_folder logs/fid --ckpt ${CLD_BEST_PATH} --config.sampling.deis_order=2 --config.sampling.nfe=50
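A missing stats folder or a wrong checkpoint path tends to surface only after the evaluation has started. The sketch below adds a pre-flight check around the same evaluation command; the check itself and the function name eval_cld_fid are hypothetical and not part of main.py. The first argument plays the role of ${CLD_BEST_PATH}, and the optional second argument sets the NFE (50 by default, matching the command above). Run it from ${gDDIM_PROJECT_FOLDER}/cld_jax; DRY_RUN=1 prints the command instead of running it.

```sh
# Hypothetical pre-flight check before FID evaluation (not part of main.py).
# $1: checkpoint path (same role as ${CLD_BEST_PATH}); $2: NFE, default 50.
eval_cld_fid() {
  ckpt=$1
  nfe=${2:-50}
  # FID needs the CIFAR stats; fail early if the folder is missing or empty.
  if [ -z "$(ls -A assets/stats 2>/dev/null)" ]; then
    echo "error: assets/stats/ is missing or empty; download the CIFAR stats first" >&2
    return 1
  fi
  # The checkpoint may be a file or a directory, so test for existence only.
  if [ ! -e "$ckpt" ]; then
    echo "error: checkpoint not found: $ckpt" >&2
    return 1
  fi
  ${DRY_RUN:+echo} python main.py --config configs/accr_dcifar10_config.py --mode check \
    --result_folder logs/fid --ckpt "$ckpt" \
    --config.sampling.deis_order=2 --config.sampling.nfe="$nfe"
}
```

Usage: eval_cld_fid "${CLD_BEST_PATH}" or eval_cld_fid "${CLD_BEST_PATH}" 20 to try a smaller step budget.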
Blur diffusion model
Training on cifar10
cd ${gDDIM_PROJECT_FOLDER}/blur_jax
wandb login ${WANDB_KEY}
python main.py --config configs/ddpm_deep_cifar10_config.py --mode train --workdir logs/ddpm_deep_sigma${sigma}_seed${seed} --wandb --config.model.sigma_blur_max=${sigma} --config.seed=${seed}
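The command above expects ${sigma} and ${seed} to be set in the shell. The sketch below is a hypothetical helper (train_blur_grid is not part of the repository) that runs the same command for every (sigma, seed) pair from two space-separated lists. The README gives no recommended sigma values, so any values you pass, including the ones in the usage line, are placeholders. Run it from ${gDDIM_PROJECT_FOLDER}/blur_jax after wandb login; DRY_RUN=1 prints the commands instead of running them.

```sh
# Hypothetical helper: train the blur diffusion model for each (sigma, seed) pair.
# $1: space-separated sigma_blur_max values; $2: space-separated seeds.
# $1 and $2 are left unquoted on purpose so the shell splits them into words.
train_blur_grid() {
  for sigma in $1; do
    for seed in $2; do
      ${DRY_RUN:+echo} python main.py --config configs/ddpm_deep_cifar10_config.py --mode train \
        --workdir "logs/ddpm_deep_sigma${sigma}_seed${seed}" --wandb \
        --config.model.sigma_blur_max="${sigma}" --config.seed="${seed}"
    done
  done
}
```

Usage with placeholder values: train_blur_grid "10 20" "8". The workdir naming follows the command above, so each pair writes to its own folder.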
Eval on cifar10
- Download the CIFAR stats to ${gDDIM_PROJECT_FOLDER}/blur_jax/assets/stats/.
- We provide a pretrained model checkpoint.
- Users can evaluate FID via
cd ${gDDIM_PROJECT_FOLDER}/blur_jax
python main.py --config configs/ddpm_deep_cifar10_config.py --mode check --result_folder logs/fid --ckpt ${BLUR_BEST_PATH} --config.sampling.deis_order=2 --config.sampling.nfe=50
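To see how FID changes with the sampling budget, the evaluation can be repeated for several NFE values. The sketch below is a hypothetical helper (sweep_blur_nfe is not part of the repository). It assumes that evaluation uses the same config as blur training (configs/ddpm_deep_cifar10_config.py) and writes each NFE to its own result folder (logs/fid_nfe<nfe>), both of which are our choices. Run it from ${gDDIM_PROJECT_FOLDER}/blur_jax; DRY_RUN=1 prints the commands instead of running them.

```sh
# Hypothetical helper: evaluate one blur checkpoint at several NFE values.
# $1: checkpoint path; remaining arguments: NFE values.
# Assumptions: the training config is reused and each NFE gets its own result folder.
sweep_blur_nfe() {
  if [ $# -lt 1 ]; then
    echo "usage: sweep_blur_nfe CKPT NFE..." >&2
    return 1
  fi
  ckpt=$1
  shift
  for nfe in "$@"; do
    ${DRY_RUN:+echo} python main.py --config configs/ddpm_deep_cifar10_config.py --mode check \
      --result_folder "logs/fid_nfe${nfe}" --ckpt "$ckpt" \
      --config.sampling.deis_order=2 --config.sampling.nfe="${nfe}"
  done
}
```

Usage: sweep_blur_nfe path/to/checkpoint 10 20 50. If the model was trained with a non-default sigma_blur_max, the evaluation presumably needs the same --config.model.sigma_blur_max override.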
Reference
@misc{zhang2022gddim,
title={gDDIM: Generalized denoising diffusion implicit models},
author={Qinsheng Zhang and Molei Tao and Yongxin Chen},
year={2022},
eprint={2206.05564},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Related works
@inproceedings{song2020denoising,
title={Denoising diffusion implicit models},
author={Song, Jiaming and Meng, Chenlin and Ermon, Stefano},
booktitle={International Conference on Learning Representations (ICLR)},
year={2021}
}
@inproceedings{dockhorn2022score,
title={Score-Based Generative Modeling with Critically-Damped Langevin Diffusion},
author={Tim Dockhorn and Arash Vahdat and Karsten Kreis},
booktitle={International Conference on Learning Representations (ICLR)},
year={2022}
}
@article{hoogeboom2022blurring,
title={Blurring diffusion models},
author={Hoogeboom, Emiel and Salimans, Tim},
journal={arXiv preprint arXiv:2209.05557},
year={2022}
}
Miscellaneous
The project is built upon score-sde, developed by Yang Song. The sampling code is adapted from DEIS.