Score-Based Generative Modeling through Stochastic Differential Equations

This repo contains a PyTorch implementation for the paper Score-Based Generative Modeling through Stochastic Differential Equations

by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole


We propose a unified framework that generalizes and improves previous work on score-based generative models through the lens of stochastic differential equations (SDEs). In particular, we can transform data to a simple noise distribution with a continuous-time stochastic process described by an SDE. This SDE can be reversed for sample generation if we know the score of the marginal distributions at each intermediate time step, which can be estimated with score matching. The basic idea is captured in the figure below:

[Figure: schematic]
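
In symbols, the forward and reverse-time SDEs described above can be summarized as follows. The notation follows the paper and is not defined elsewhere in this README: $\mathbf{f}$ is the drift, $g$ the diffusion coefficient, $\mathbf{w}$ and $\bar{\mathbf{w}}$ are forward and reverse-time Wiener processes, and $p_t$ is the marginal distribution at time $t$.

```latex
\begin{aligned}
\text{forward:}\quad & \mathrm{d}\mathbf{x} = \mathbf{f}(\mathbf{x}, t)\,\mathrm{d}t + g(t)\,\mathrm{d}\mathbf{w}, \\
\text{reverse:}\quad & \mathrm{d}\mathbf{x} = \left[\mathbf{f}(\mathbf{x}, t) - g(t)^2\,\nabla_{\mathbf{x}} \log p_t(\mathbf{x})\right]\mathrm{d}t + g(t)\,\mathrm{d}\bar{\mathbf{w}}.
\end{aligned}
```

The only unknown in the reverse-time equation is the score $\nabla_{\mathbf{x}} \log p_t(\mathbf{x})$, which is the quantity estimated with score matching.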

Our work enables a better understanding of existing approaches, new sampling algorithms, exact likelihood computation, uniquely identifiable encoding, latent code manipulation, and brings new conditional generation abilities (including but not limited to class-conditional generation, inpainting and colorization) to the family of score-based generative models.

All combined, we achieved an FID of 2.20 and an Inception score of 9.89 for unconditional generation on CIFAR-10, as well as high-fidelity generation of 1024px CelebA-HQ images (samples below). In addition, we obtained a likelihood value of 2.99 bits/dim on uniformly dequantized CIFAR-10 images.

[Figure: FFHQ samples]
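
As a rough guide to the bits/dim unit used for the likelihood above, here is a minimal sketch of the usual conversion from a per-image negative log-likelihood in nats. It assumes the likelihood is measured for pixel values on their original 0 to 255 scale; a model trained on rescaled data needs an extra constant offset, which is not shown. The function name `bits_per_dim` and the example value are illustrative and are not taken from this codebase.

```python
import math


def bits_per_dim(nll_nats, shape):
    """Convert a per-image negative log-likelihood in nats to bits per dimension."""
    num_dims = math.prod(shape)  # total number of pixel values per image
    return nll_nats / (num_dims * math.log(2))


# CIFAR-10 images have 3 * 32 * 32 = 3072 dimensions.
cifar10_shape = (3, 32, 32)

# Hypothetical NLL (in nats), constructed so that it corresponds to 2.99 bits/dim.
example_nll = 2.99 * 3072 * math.log(2)

print(round(bits_per_dim(example_nll, cifar10_shape), 2))
```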

What does this code do?

Aside from the NCSN++ and DDPM++ models in our paper, this codebase also re-implements many previous score-based models in one place, including NCSN from Generative Modeling by Estimating Gradients of the Data Distribution, NCSNv2 from Improved Techniques for Training Score-Based Generative Models, and DDPM from Denoising Diffusion Probabilistic Models.

It supports training new models and evaluating the sample quality and likelihoods of existing models. We carefully designed the code to be modular and easily extensible to new SDEs, predictors, or correctors.
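
To give a feel for what a predictor does, here is a toy sketch of reverse-time Euler-Maruyama sampling for a 1-D VE SDE. It does not use this repository's classes: every name (`sigma`, `diffusion`, `toy_score`, `euler_maruyama_predictor`, `sample`) and every constant is an assumption made for illustration, and the analytic score of a Gaussian stands in for a trained score network.

```python
import math
import random

SIGMA_MIN, SIGMA_MAX = 0.01, 10.0  # assumed noise range of the toy VE SDE
DATA_STD = 2.0                     # toy data distribution: N(0, DATA_STD^2)


def sigma(t):
    """Noise scale sigma(t) = sigma_min * (sigma_max / sigma_min)^t for t in [0, 1]."""
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def diffusion(t):
    """Diffusion coefficient g(t) = sqrt(d[sigma(t)^2]/dt) of the VE SDE."""
    return sigma(t) * math.sqrt(2.0 * math.log(SIGMA_MAX / SIGMA_MIN))


def toy_score(x, t):
    """Exact score of the Gaussian marginal p_t; stands in for a trained network."""
    variance = DATA_STD**2 + sigma(t) ** 2 - SIGMA_MIN**2
    return -x / variance


def euler_maruyama_predictor(x, t, dt, score_fn, rng):
    """One reverse-time Euler-Maruyama step from time t to time t - dt (dt > 0)."""
    g = diffusion(t)
    drift = g**2 * score_fn(x, t)  # the VE SDE has zero forward drift
    return x + drift * dt + g * math.sqrt(dt) * rng.gauss(0.0, 1.0)


def sample(num_samples=1000, num_steps=400, seed=0):
    """Draw samples by integrating the reverse-time SDE from t=1 down to t=0."""
    rng = random.Random(seed)
    dt = 1.0 / num_steps
    xs = [rng.gauss(0.0, SIGMA_MAX) for _ in range(num_samples)]  # prior at t=1
    for i in range(num_steps):
        t = 1.0 - i * dt
        xs = [euler_maruyama_predictor(x, t, dt, toy_score, rng) for x in xs]
    return xs


samples = sample()
mean = sum(samples) / len(samples)
std = math.sqrt(sum((x - mean) ** 2 for x in samples) / len(samples))
print(f"sample std = {std:.2f} (target {DATA_STD})")
```

Because the step function takes the score as an argument, a different score model or update rule can be swapped in without touching the sampling loop, which is the kind of separation the modular design aims for.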

Integration with 🤗 Diffusers library

Most models are now also available in 🧨 Diffusers and accessible via the ScoreSdeVE pipeline.

Diffusers allows you to test score-SDE-based models in PyTorch in just a couple of lines of code.

You can install diffusers as follows:

pip install diffusers torch accelerate

And then try out the models with just a couple lines of code:

from diffusers import DiffusionPipeline

model_id = "google/ncsnpp-ffhq-1024"

# load model and scheduler
sde_ve = DiffusionPipeline.from_pretrained(model_id)

# run pipeline in inference (sample random noise and denoise)
# .images is a list of PIL images; take the first one
image = sde_ve().images[0]

# save image
image.save("sde_ve_generated_image.png")

More models can be found directly on the Hub.

JAX version

Please find a JAX implementation here, which additionally supports class-conditional generation with a pre-trained classifier, and resuming an evaluation process after pre-emption.

JAX vs. PyTorch

In general, this PyTorch version consumes less memory but runs slower than JAX. Here is a benchmark for training an NCSN++ cont. model with VE SDE. The hardware is 4x Nvidia Tesla V100 GPUs (32GB).

| Framework | Time (second per step) | Memory usage in total (GB) |
| --- | --- | --- |
| PyTorch | 0.56 | 20.6 |
| JAX (n_jitted_steps=1) | 0.30 | 29.7 |
| JAX (n_jitted_steps=5) | 0.20 | 74.8 |

How to run the code

Dependencies

Run the following to install a subset of the necessary Python packages for our code:

pip install -r requirements.txt

Stats files for quantitative evaluation

We provide the stats file for CIFAR-10. You can download cifar10_stats.npz and save it to assets/stats/. Check out #5 on how to compute this stats file for new datasets.

Usage

Train and evaluate our models through main.py.

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory
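
As a hedged illustration of how the four flags above fit together, the sketch below assembles a command line for main.py. The helper `build_main_command`, the config path, and the working directory are hypothetical placeholders; substitute the paths that exist in your checkout.

```python
import shlex


def build_main_command(config, workdir, mode="train", eval_folder="eval"):
    """Assemble the argument list for main.py from the four documented flags."""
    if mode not in ("train", "eval"):
        raise ValueError("mode must be 'train' or 'eval'")
    return [
        "python", "main.py",
        "--config", config,
        "--workdir", workdir,
        "--mode", mode,
        "--eval_folder", eval_folder,
    ]


# Hypothetical paths, used only to show the shape of the command.
cmd = build_main_command("path/to/config.py", "runs/example", mode="eval")
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root, or the printed string can be pasted into a shell.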

How to extend the code

Pretrained checkpoints

All checkpoints are provided in this Google drive.

Instructions: You may find two checkpoints for some models. The first checkpoint (with the smaller number) is the one for which we reported FID scores in Table 3 of our paper (corresponding to the FID and IS columns in the table below). The second checkpoint (with the larger number) is the one for which we reported likelihood values and FIDs of black-box ODE samplers in Table 2 of our paper (corresponding to the FID (ODE) and NLL (bits/dim) columns in the table below). The former achieved the smallest FID during training (evaluated every 50k iterations). The latter is the last checkpoint saved during training.

Per Google's policy, we cannot release our original CelebA and CelebA-HQ checkpoints. That said, I have re-trained models on FFHQ 1024px, FFHQ 256px and CelebA-HQ 256px with personal resources, and they achieved similar performance to our internal checkpoints.

Here is a detailed list of checkpoints and their results reported in the paper. FID (ODE) corresponds to the sample quality of black-box ODE solver applied to the probability flow ODE.

| Checkpoint path | FID | IS | FID (ODE) | NLL (bits/dim) |
| --- | --- | --- | --- | --- |
| ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
| ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
| ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
| vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
| vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
| vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
| vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
| vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
| subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
| subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
| subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |

| Checkpoint path | Samples |
| --- | --- |
| ve/bedroom_ncsnpp_continuous | [Image: bedroom_samples] |
| ve/church_ncsnpp_continuous | [Image: church_samples] |
| ve/ffhq_1024_ncsnpp_continuous | [Image: ffhq_1024] |
| ve/ffhq_256_ncsnpp_continuous | [Image: ffhq_256_samples] |
| ve/celebahq_256_ncsnpp_continuous | [Image: celebahq_256_samples] |
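
For reference, the probability flow ODE behind the FID (ODE) and bits/dim columns can be written as follows. The notation follows the paper and is not defined in this README: $\mathbf{f}$ is the drift of the forward SDE, $g$ its diffusion coefficient, and $p_t$ the marginal distribution at time $t$.

```latex
\mathrm{d}\mathbf{x} = \left[\mathbf{f}(\mathbf{x}, t) - \tfrac{1}{2}\, g(t)^2\, \nabla_{\mathbf{x}} \log p_t(\mathbf{x})\right]\mathrm{d}t
```

This ODE is deterministic and shares its marginals with the SDE, which is why a black-box ODE solver can be used both to draw samples and to compute likelihoods.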

Demonstrations and tutorials

| Link | Description |
| --- | --- |
| Open In Colab | Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (JAX + FLAX) |
| Open In Colab | Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (PyTorch) |
| Open In Colab | Tutorial of score-based generative models in JAX + FLAX |
| Open In Colab | Tutorial of score-based generative models in PyTorch |

Tips

References

If you find the code useful for your research, please consider citing:

@inproceedings{
  song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}

This work is built upon some previous papers which might also interest you: