Maximum Likelihood Training for Score-Based Diffusion ODEs by High-Order Denoising Score Matching (ICML 2022)

The official code for the paper Maximum Likelihood Training for Score-Based Diffusion ODEs by High-Order Denoising Score Matching by Cheng Lu, Kaiwen Zheng, Fan Bao, Jianfei Chen, Chongxuan Li and Jun Zhu, published in ICML 2022.

The code implementation is based on score_flow by Yang Song.


Score-based diffusion models come in two types: ScoreSDE and ScoreODE. Previous work showed that a weighted combination of first-order score matching losses upper bounds the Kullback–Leibler divergence between the data distribution and the ScoreSDE model distribution. However, the relationship between score matching and the ScoreODE likelihood has been unclear. This work studies that relationship and proposes high-order (first-, second- and third-order) denoising score matching for maximum likelihood training of ScoreODE.

In short, the previous work Maximum Likelihood Training of Score-Based Diffusion Models is a method for maximum likelihood training of ScoreSDE (a.k.a. diffusion SDE), and our work is a method for maximum likelihood training of ScoreODE (a.k.a. diffusion ODE).

Code Structure

The proposed high-order denoising score matching losses are implemented in losses.py; the rest of the code follows score_flow.

How to run the code

Dependencies

We use the same dependencies as score_flow. We recommend jaxlib==0.1.69. Find the build that matches your Python 3 and CUDA versions at: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html. For example, to install jaxlib==0.1.69 for python==3.7 and cuda==11.1, first download the wheel file:

wget https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.69+cuda111-cp37-none-manylinux2010_x86_64.whl

and then run the following command to install jaxlib:

pip3 install jaxlib-0.1.69+cuda111-cp37-none-manylinux2010_x86_64.whl

After installing jaxlib, run the following command to install the other packages:

pip3 install -r requirements.txt
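
After installation, a quick sanity check can confirm which versions actually ended up in the environment. The sketch below uses only the standard library (with the importlib_metadata backport as a fallback on Python 3.7); the helper name installed_versions is hypothetical and not part of this repository.

```python
try:
    from importlib import metadata  # available from Python 3.8
except ImportError:
    # Python 3.7 needs the importlib_metadata backport instead.
    import importlib_metadata as metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(["jax", "jaxlib"]).items():
        print(f"{name}: {version or 'not installed'}")
```

If jaxlib is reported as not installed, or with a version other than the recommended 0.1.69, the wheel step above is the first thing to recheck.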

Stats files for quantitative evaluation

We use the same stats files as score_flow for computing FID and Inception scores on CIFAR-10 and ImageNet 32x32. You can find cifar10_stats.npz and imagenet32_stats.npz under the directory assets/stats in Yang Song's Google drive. Download them and save them to assets/stats/ in the code repo.
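
Before launching an evaluation, it can help to confirm that both stats files are where the code expects them. The following is a small sketch, assuming it is run from the repository root; the function name missing_stats_files is hypothetical, while the file names and directory are those given above.

```python
from pathlib import Path

# File names and directory as listed in the text above.
STATS_FILES = ("cifar10_stats.npz", "imagenet32_stats.npz")


def missing_stats_files(repo_root=".", stats_dir="assets/stats"):
    """Return the expected stats files that are absent under repo_root/stats_dir."""
    base = Path(repo_root) / stats_dir
    return [name for name in STATS_FILES if not (base / name).is_file()]


if __name__ == "__main__":
    missing = missing_stats_files()
    if missing:
        print(f"missing stats files: {missing}")
    else:
        print("all stats files present")
```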

Usage

The commands are the same as in score_flow. Here are some common options:

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval>: Running mode: train or eval. Our models were not additionally trained with variational dequantization.
  --workdir: Working directory

These functionalities can be configured through config files, or more conveniently, through the command-line support of the ml_collections package.
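
To illustrate the command-line override style, the sketch below assembles a main.py invocation as an argument list. It is only an illustration: the helper build_command is hypothetical, and the config path and working directory are placeholders rather than files known to exist in the repository. Overrides are written in the --config.<path>=<value> form used throughout this README.

```python
import shlex
import sys


def build_command(config, workdir, mode, overrides=None):
    """Assemble an argv list for main.py from the flags listed above."""
    if mode not in ("train", "eval"):
        raise ValueError("mode must be 'train' or 'eval'")
    argv = [
        sys.executable, "main.py",
        f"--config={config}",
        f"--workdir={workdir}",
        f"--mode={mode}",
    ]
    # Each override becomes --config.<dotted.path>=<value>.
    for key, value in (overrides or {}).items():
        argv.append(f"--config.{key}={value}")
    return argv


if __name__ == "__main__":
    cmd = build_command(
        config="configs/my_config.py",  # placeholder path (assumption)
        workdir="runs/example",         # placeholder directory (assumption)
        mode="train",
        overrides={"training.score_matching_order": 2},
    )
    # Print a shell-quoted version of the command for inspection.
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

The returned list can be passed to subprocess.run as is, which avoids shell-quoting problems when a path contains spaces.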

Configurations for high-order denoising score matching

To set the order of the score matching training loss, set --config.training.score_matching_order to 1 (the previous first-order loss), 2, or 3. Note that third-order score matching training needs a smaller batch size to avoid running out of memory (OOM).
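
For orientation, order 1 is the usual first-order denoising score matching (DSM) objective. Below is a minimal one-dimensional sketch of that objective in plain Python, assuming a Gaussian perturbation x_t = alpha * x_0 + sigma * eps and the common sigma^2 weighting. It is not the code in losses.py, all names are illustrative, and the second- and third-order losses from the paper are not reproduced here.

```python
import random


def first_order_dsm_loss(score_fn, data, alpha, sigma, rng):
    """Monte Carlo estimate of the sigma^2-weighted first-order DSM loss in 1-D.

    The conditional score of N(alpha * x0, sigma^2) at x_t is -eps / sigma,
    so the weighted squared error is (sigma * score_fn(x_t) + eps) ** 2.
    """
    total = 0.0
    for x0 in data:
        eps = rng.gauss(0.0, 1.0)           # standard normal noise
        x_t = alpha * x0 + sigma * eps      # perturbed sample
        residual = sigma * score_fn(x_t) + eps
        total += residual ** 2
    return total / len(data)


if __name__ == "__main__":
    rng = random.Random(0)
    sigma = 0.5
    # Toy case: all data at 0, so x_t ~ N(0, sigma^2) and the exact
    # score is -x / sigma^2; the loss should then be (numerically) zero.
    exact_score = lambda x: -x / sigma ** 2
    print(first_order_dsm_loss(exact_score, [0.0] * 1000, 1.0, sigma, rng))
```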

Configurations for evaluation

To generate samples and evaluate sample quality, use the --config.eval.enable_sampling flag. To compute log-likelihoods, use the --config.eval.enable_bpd flag, and specify --config.eval.dataset=train/test to indicate whether to compute the likelihoods on the training or test dataset. Turn on --config.eval.bound to evaluate the variational bound for the log-likelihood. Enable --config.eval.dequantizer to use variational dequantization for likelihood computation. --config.eval.num_repeats sets the number of passes over the dataset (more passes reduce the variance of the likelihood estimate; defaults to 5).
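
Likelihoods of this kind are commonly reported in bits per dimension (bpd). As a reminder of the unit conversion only, the sketch below turns negative log-likelihoods in nats into bits/dim and averages them over repeated passes. The function names are hypothetical, and any offset that the evaluation code applies for data rescaling or dequantization is deliberately left out.

```python
import math


def nats_to_bpd(nll_nats, num_dims):
    """Convert a negative log-likelihood in nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2.0))


def mean_bpd(nll_per_repeat, num_dims):
    """Average bits/dim over several passes (cf. --config.eval.num_repeats)."""
    values = [nats_to_bpd(nll, num_dims) for nll in nll_per_repeat]
    return sum(values) / len(values)


if __name__ == "__main__":
    dims = 3 * 32 * 32  # a 32x32 RGB image
    # Made-up NLL values in nats, purely for illustration.
    print(mean_bpd([6500.0, 6510.0, 6495.0], dims))
```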

Pretrained checkpoints

The pretrained checkpoints can be found on the Releases page.

Training high-order DSM from pretrained checkpoints

For VESDE on CIFAR-10, we start from the first-order DSM checkpoints provided by score_sde.

For VESDE on ImageNet32, score_sde does not provide checkpoints, so we train the first-order model ourselves and then continue training with high-order DSM. The baseline first-order ImageNet32 models are also available on the Releases page.

For VPSDE, we start from the first-order DSM checkpoints provided by score_flow.

References

If you find the code useful for your research, please consider citing

@inproceedings{lu2022maximum,
  title={Maximum Likelihood Training for Score-Based Diffusion ODEs by High-Order Denoising Score Matching},
  author={Lu, Cheng and Zheng, Kaiwen and Bao, Fan and Chen, Jianfei and Li, Chongxuan and Zhu, Jun},
  booktitle={International Conference on Machine Learning},
  year={2022},
  organization={PMLR}
}

This work is built upon some previous papers which might also interest you: