Extended Analytic-DPM
This is the official implementation of Estimating the Optimal Covariance with Imperfect Mean in Diffusion Probabilistic Models (accepted at ICML 2022). It extends Analytic-DPM to the following two settings:
- The reverse process adopts complicated state-dependent covariance matrices instead of simple scalar variances (this motivates SN-DPM in the paper).
- The score-based model has some error with respect to the exact score function (this motivates NPR-DPM in the paper).
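As a rough illustration of the two settings, the sketch below computes a per-dimension reverse variance for one DDPM step from the outputs of the extra prediction heads. It assumes the DDPM forward process, takes the base variance to be the DDPM posterior variance `lambda2`, and adds `gamma**2` times a residual term, where `gamma` is the coefficient of the noise prediction `eps_hat` in the DDPM reverse mean. In our reading of the paper, SN-DPM estimates the residual as `eps2_hat - eps_hat**2` from a squared-noise head `eps2_hat ≈ E[eps^2 | x_n]` (the conditional variance of the noise when `eps_hat` is the exact conditional mean), while NPR-DPM predicts the residual `E[(eps - eps_hat)^2 | x_n]` of the actual, possibly imperfect, noise prediction directly. All function names and the clamping at zero are illustrative assumptions and do not mirror this codebase's internals.

```python
import math
from typing import List, Sequence, Tuple


def ddpm_step_coefficients(alpha_bar_prev: float, alpha_bar: float) -> Tuple[float, float]:
    """Return (lambda2, gamma) for the DDPM reverse step n-1 <- n.

    lambda2: DDPM posterior variance beta_tilde_n.
    gamma:   coefficient of eps_hat in the DDPM reverse mean,
             mu = x_n / sqrt(alpha_n) - gamma * eps_hat.
    """
    alpha = alpha_bar / alpha_bar_prev  # alpha_n
    beta = 1.0 - alpha                  # beta_n
    lambda2 = (1.0 - alpha_bar_prev) / (1.0 - alpha_bar) * beta
    gamma = beta / math.sqrt(alpha * (1.0 - alpha_bar))
    return lambda2, gamma


def sn_dpm_variance(lambda2: float, gamma: float, eps_hat: Sequence[float],
                    eps2_hat: Sequence[float]) -> List[float]:
    """SN-DPM style: per-dimension residual estimated as E[eps^2 | x_n] - eps_hat^2."""
    return [lambda2 + gamma ** 2 * max(e2 - e * e, 0.0) for e, e2 in zip(eps_hat, eps2_hat)]


def npr_dpm_variance(lambda2: float, gamma: float, residual_hat: Sequence[float]) -> List[float]:
    """NPR-DPM style: per-dimension residual E[(eps - eps_hat)^2 | x_n] predicted directly."""
    return [lambda2 + gamma ** 2 * max(r, 0.0) for r in residual_hat]


if __name__ == "__main__":
    lam2, gam = ddpm_step_coefficients(alpha_bar_prev=0.9, alpha_bar=0.81)
    # Toy head outputs for a 2-dimensional x_n; both calls describe the same residuals.
    print(sn_dpm_variance(lam2, gam, eps_hat=[0.1, -0.4], eps2_hat=[0.5, 0.3]))
    print(npr_dpm_variance(lam2, gam, residual_hat=[0.49, 0.14]))
```

When the residual is zero (the noise is fully determined by `x_n` and predicted exactly), both variants reduce to `lambda2`.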
This codebase also reimplements Analytic-DPM and reproduces most of its results. The pretrained DPMs used in the Analytic-DPM paper are provided <a href="#pretrained_dpm">here</a> and have already been converted to a format that can be used directly with this codebase. We additionally apply Analytic-DPM to score-based SDEs.
Models and FID statistics for reproducing the results in the paper are available <a href="#model">here</a>.
Dependencies
The codebase is based on PyTorch. The dependencies are listed below.
```shell
pip install "torch>=1.9.0" torchvision ml-collections ninja tensorboard
```
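Because `pip` output alone does not confirm which PyTorch build ended up in the environment, a small standard-library check like the following can verify the `>= 1.9.0` requirement before running anything. This is a convenience sketch, not part of the codebase; it reads distribution metadata via `importlib.metadata` (Python 3.8+) and compares only the leading numeric part of each version, so local tags such as `+cu113` are ignored.

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric part of a version, e.g. '1.10.2+cu113' -> (1, 10, 2)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def version_at_least(version: str, minimum: str) -> bool:
    """Compare versions after zero-padding, so '1.9' counts as '1.9.0'."""
    a, b = parse_version(version), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    torch_version = installed_version("torch")
    print("torch:", torch_version or "NOT INSTALLED")
    print("torchvision:", installed_version("torchvision") or "NOT INSTALLED")
    print("torch >= 1.9.0:", torch_version is not None and version_at_least(torch_version, "1.9.0"))
```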
Basic usage
The basic usage for training is:
```shell
python run_train.py --pretrained_path path/to/pretrained_dpm --dataset dataset --workspace path/to/working_directory $train_hparams
```
- `pretrained_path` is the path to a pretrained diffusion probabilistic model (DPM). All pretrained DPMs used in this work are provided <a href="#pretrained_dpm">here</a>.
- `dataset` is the training dataset, one of `<cifar10|celeba64|imagenet64|lsun_bedroom>`.
- `workspace` is where training outputs, e.g., logs and intermediate checkpoints, are stored.
- `train_hparams` specifies the other hyperparameters used in training. The `train_hparams` of all models are listed <a href="#model">here</a>.
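When launching runs from a script, it can help to split `train_hparams` into separate arguments instead of pasting it into a shell string. The helper below is a hypothetical convenience wrapper (not part of this codebase) that validates `dataset` against the choices above and only assembles the command; the example `train_hparams` string is taken from the CIFAR10 (CS) NPR-DPM row of the model table.

```python
import shlex
from typing import List

DATASETS = ("cifar10", "celeba64", "imagenet64", "lsun_bedroom")


def build_train_command(pretrained_path: str, dataset: str, workspace: str,
                        train_hparams: str) -> List[str]:
    """Assemble the run_train.py invocation as an argv list."""
    if dataset not in DATASETS:
        raise ValueError(f"dataset must be one of {DATASETS}, got {dataset!r}")
    return [
        "python", "run_train.py",
        "--pretrained_path", pretrained_path,
        "--dataset", dataset,
        "--workspace", workspace,
        # Split the quoted hyperparameter string into individual arguments.
        *shlex.split(train_hparams),
    ]


if __name__ == "__main__":
    cmd = build_train_command(
        "path/to/pretrained_dpm", "cifar10", "path/to/working_directory",
        "--method pred_eps_epsc_pretrained --schedule cosine_1000",
    )
    print(shlex.join(cmd))
    # import subprocess; subprocess.run(cmd, check=True)  # uncomment to launch training
```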
The basic usage for evaluation is:
```shell
python run_eval.py --pretrained_path path/to/evaluated_model --dataset dataset --workspace path/to/working_directory \
    --phase phase --sample_steps sample_steps --batch_size batch_size --method method $eval_hparams
```
- `pretrained_path` is the path to the model to evaluate. All models evaluated in this work are provided <a href="#model">here</a>.
- `dataset` is the dataset the model was trained on, one of `<cifar10|celeba64|imagenet64|lsun_bedroom>`.
- `workspace` is where evaluation outputs, e.g., logs, samples and bpd values, are stored.
- `phase` selects sampling or likelihood evaluation, one of `<sample4test|nll4test>`.
- `sample_steps` is the number of steps run during inference; the smaller this value, the faster the inference.
- `batch_size` is the batch size, e.g., 500.
- `method` specifies the type of the model, one of:
    - `pred_eps`: the original DPM (i.e., a noise prediction model) with discrete timesteps
    - `pred_eps_eps2_pretrained`: the SN-DPM with discrete timesteps
    - `pred_eps_epsc_pretrained`: the NPR-DPM with discrete timesteps
    - `pred_eps_ct2dt`: the original DPM (i.e., a noise prediction model) with continuous timesteps (i.e., a score-based SDE)
    - `pred_eps_eps2_pretrained_ct2dt`: the SN-DPM with continuous timesteps
    - `pred_eps_epsc_pretrained_ct2dt`: the NPR-DPM with continuous timesteps
- `eval_hparams` specifies other optional hyperparameters used in evaluation.

The `method` and `eval_hparams` used for the NPR/SN-DPM and Analytic-DPM results in the paper are listed <a href="#evaluation">here</a>.
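The argument descriptions above translate directly into a validated argv list. The hypothetical helper below (not part of this codebase) checks `dataset`, `phase` and `method` against the choices listed above and splits `eval_hparams` with `shlex`; the `METHODS` table is transcribed from the `method` list, and the example `eval_hparams` comes from the CIFAR10 (LS) NPR-DDPM row of the evaluation tables. The `sample_steps=10` value is an arbitrary example.

```python
import shlex
from typing import Dict, List, Tuple

DATASETS = ("cifar10", "celeba64", "imagenet64", "lsun_bedroom")
PHASES = ("sample4test", "nll4test")
# method -> (model type, timestep type), transcribed from the list above
METHODS: Dict[str, Tuple[str, str]] = {
    "pred_eps": ("original DPM", "discrete"),
    "pred_eps_eps2_pretrained": ("SN-DPM", "discrete"),
    "pred_eps_epsc_pretrained": ("NPR-DPM", "discrete"),
    "pred_eps_ct2dt": ("original DPM", "continuous"),
    "pred_eps_eps2_pretrained_ct2dt": ("SN-DPM", "continuous"),
    "pred_eps_epsc_pretrained_ct2dt": ("NPR-DPM", "continuous"),
}


def build_eval_command(pretrained_path: str, dataset: str, workspace: str, phase: str,
                       sample_steps: int, batch_size: int, method: str,
                       eval_hparams: str = "") -> List[str]:
    """Assemble a run_eval.py invocation as an argv list after basic validation."""
    for name, value, choices in (("dataset", dataset, DATASETS),
                                 ("phase", phase, PHASES),
                                 ("method", method, tuple(METHODS))):
        if value not in choices:
            raise ValueError(f"{name} must be one of {choices}, got {value!r}")
    if sample_steps <= 0 or batch_size <= 0:
        raise ValueError("sample_steps and batch_size must be positive")
    return [
        "python", "run_eval.py",
        "--pretrained_path", pretrained_path,
        "--dataset", dataset,
        "--workspace", workspace,
        "--phase", phase,
        "--sample_steps", str(sample_steps),
        "--batch_size", str(batch_size),
        "--method", method,
        *shlex.split(eval_hparams),
    ]


if __name__ == "__main__":
    cmd = build_eval_command(
        "path/to/evaluated_model", "cifar10", "path/to/working_directory",
        phase="sample4test", sample_steps=10, batch_size=500,
        method="pred_eps_epsc_pretrained",
        eval_hparams="--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2",
    )
    print(METHODS["pred_eps_epsc_pretrained"])
    print(shlex.join(cmd))
    # import subprocess; subprocess.run(cmd, check=True)  # uncomment to run the evaluation
```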
Models and FID statistics
<a id="model"/>Below are the NPR-DPMs and SN-DPMs trained in this work. Each of these models trains only an additional prediction head in the last layer of a pretrained diffusion probabilistic model (DPM).
NPR/SN-DPM | Pretrained DPM | train_hparams |
---|---|---|
CIFAR10 (LS), NPR-DPM | CIFAR10 (LS) | "--method pred_eps_epsc_pretrained" |
CIFAR10 (LS), SN-DPM | CIFAR10 (LS) | "--method pred_eps_eps2_pretrained" |
CIFAR10 (CS), NPR-DPM | CIFAR10 (CS) | "--method pred_eps_epsc_pretrained --schedule cosine_1000" |
CIFAR10 (CS), SN-DPM | CIFAR10 (CS) | "--method pred_eps_eps2_pretrained --schedule cosine_1000" |
CIFAR10 (VP SDE), NPR-DPM | CIFAR10 (VP SDE) | "--method pred_eps_epsc_pretrained_ct --sde vpsde" |
CIFAR10 (VP SDE), SN-DPM | CIFAR10 (VP SDE) | "--method pred_eps_eps2_pretrained_ct --sde vpsde" |
CelebA 64x64, NPR-DPM | CelebA 64x64 | "--method pred_eps_epsc_pretrained" |
CelebA 64x64, SN-DPM | CelebA 64x64 | "--method pred_eps_eps2_pretrained" |
ImageNet 64x64, NPR-DPM | ImageNet 64x64 | "--method pred_eps_epsc_pretrained --mode simple" |
ImageNet 64x64, SN-DPM | ImageNet 64x64 | "--method pred_eps_eps2_pretrained --mode complex" |
LSUN Bedroom, NPR-DPM | LSUN Bedroom | "--method pred_eps_epsc_pretrained --mode simple" |
LSUN Bedroom, SN-DPM | LSUN Bedroom | "--method pred_eps_eps2_pretrained --mode complex" |
Below are the pretrained DPMs collected from prior works, converted to a format that can be used directly with this codebase. <a id="pretrained_dpm"/>
Pretrained DPM | Expected mean squared norm (`ms_eps`) <br> (used in Analytic-DPM) | From |
---|---|---|
CIFAR10 (LS) | Link | Analytic-DPM |
CIFAR10 (CS) | Link | Analytic-DPM |
CIFAR10 (VP SDE) | Link | score-sde |
CelebA 64x64 | Link | DDIM |
ImageNet 64x64 | Link | Improved DDPM |
LSUN Bedroom | Link | pytorch_diffusion |
This link provides precomputed FID statistics for CIFAR10, CelebA 64x64, ImageNet 64x64 and LSUN Bedroom, computed following Appendix F.2 of the Analytic-DPM paper.
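For reference, FID compares two Gaussians fitted to Inception features: FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}). The sketch below evaluates this formula with NumPy given two sets of statistics. pytorch-fid computes the matrix square root with `scipy.linalg.sqrtm`; here that is swapped for an eigenvalue-based trace, which gives the same value for positive semi-definite covariances. The `.npz` key names `mu` and `sigma` follow the pytorch-fid convention and are an assumption about the linked statistics files, so check the actual keys before relying on `load_stats`.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """FID between N(mu1, sigma1) and N(mu2, sigma2).

    Tr((sigma1 @ sigma2)^{1/2}) is computed as the sum of square roots of the
    eigenvalues of sigma1 @ sigma2. For PSD covariances these eigenvalues are
    real and non-negative in exact arithmetic, so round-off imaginary parts and
    tiny negative values are discarded.
    """
    mu1 = np.atleast_1d(mu1).astype(np.float64)
    mu2 = np.atleast_1d(mu2).astype(np.float64)
    sigma1 = np.atleast_2d(sigma1).astype(np.float64)
    sigma2 = np.atleast_2d(sigma2).astype(np.float64)
    diff = mu1 - mu2
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def load_stats(path: str):
    """Load (mu, sigma) from an .npz file; the 'mu'/'sigma' keys are an assumption (pytorch-fid style)."""
    with np.load(path) as f:
        return f["mu"][:], f["sigma"][:]
```

Usage would look like `frechet_distance(*load_stats("samples_stats.npz"), *load_stats("reference_stats.npz"))`, where both file names are placeholders.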
Evaluation Hyperparameters for NPR/SN-DPM and Analytic-DPM
<a id="evaluation"/>Note: Analytic-DPM requires the precomputed expected mean squared norm of the noise prediction model (`ms_eps`), which is provided <a href="#pretrained_dpm">here</a>. Specify its path with `--ms_eps_path` (replace the `ms_eps_path` placeholder in the tables below).
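To show how a precomputed `ms_eps` value enters sampling, the sketch below evaluates the Analytic-DPM optimal reverse variance for one step s <- n: sigma^2 = lambda^2 + gamma^2 * (1 - ms_eps_n), where ms_eps_n is the expected mean squared norm of the noise prediction per dimension at step n, lambda^2 = eta^2 * (1 - alpha_bar_s) / (1 - alpha_bar_n) * (1 - alpha_n), gamma = sqrt((1 - alpha_bar_n) / alpha_n) - sqrt(1 - alpha_bar_s - lambda^2), and alpha_n = alpha_bar_n / alpha_bar_s. Setting eta = 1 gives the DDPM case and eta = 0 the deterministic DDIM case (the `--eta 0` setting in the tables below). The function name and the simple clamp of `1 - ms_eps` to [0, 1] are assumptions for illustration; Analytic-DPM additionally derives tighter bounds for clipping, and no particular on-disk format of the provided `ms_eps` files is assumed here.

```python
import math


def analytic_dpm_variance(alpha_bar_s: float, alpha_bar_n: float, ms_eps_n: float,
                          eta: float = 1.0) -> float:
    """Analytic-DPM reverse variance for one step s <- n (s < n)."""
    alpha_n = alpha_bar_n / alpha_bar_s
    beta_bar_s, beta_bar_n = 1.0 - alpha_bar_s, 1.0 - alpha_bar_n
    # eta = 1: DDPM posterior variance; eta = 0: deterministic DDIM (lambda2 = 0).
    lambda2 = eta ** 2 * beta_bar_s / beta_bar_n * (1.0 - alpha_n)
    gamma = math.sqrt(beta_bar_n / alpha_n) - math.sqrt(max(beta_bar_s - lambda2, 0.0))
    # For an exact noise predictor 0 <= ms_eps <= 1 (Jensen's inequality), so clamp the estimate.
    residual = min(max(1.0 - ms_eps_n, 0.0), 1.0)
    return lambda2 + gamma ** 2 * residual


if __name__ == "__main__":
    # Hypothetical values: one step from alpha_bar = 0.81 back to alpha_bar = 0.9.
    for eta in (1.0, 0.0):
        print(eta, analytic_dpm_variance(0.9, 0.81, ms_eps_n=0.95, eta=eta))
```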
- Sampling experiments on CIFAR10 (LS) or CelebA 64x64, Table 1 in the paper:
Model | method | eval_hparams |
---|---|---|
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2" |
SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2" |
Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --ms_eps_path ms_eps_path" |
NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0" |
SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0" |
Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --ms_eps_path ms_eps_path" |
- Sampling experiments on CIFAR10 (CS), Table 1 in the paper:
Model | method | eval_hparams |
---|---|---|
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000" |
SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000" |
Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000 --ms_eps_path ms_eps_path" |
NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000" |
SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000" |
Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000 --ms_eps_path ms_eps_path" |
- Sampling experiments on CIFAR10 (VP SDE), Table 1 in the paper:
Model | method | eval_hparams |
---|---|---|
NPR-DDPM | pred_eps_epsc_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000" |
SN-DDPM | pred_eps_eps2_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000" |
Analytic-DDPM | pred_eps_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000 --ms_eps_path ms_eps_path" |
NPR-DDIM | pred_eps_epsc_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000" |
SN-DDIM | pred_eps_eps2_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000" |
Analytic-DDIM | pred_eps_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000 --ms_eps_path ms_eps_path" |
- Sampling experiments on ImageNet 64x64, Table 1 in the paper:
Model | method | eval_hparams |
---|---|---|
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode simple" |
SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode complex" |
Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --ms_eps_path ms_eps_path" |
NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --mode simple" |
SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --mode complex" |
Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --ms_eps_path ms_eps_path" |
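The sampling tables above follow a regular pattern: a shared base, `--clip_pixel 2` only for DDPM on CIFAR10 (LS), CelebA 64x64 and CIFAR10 (VP SDE), DDIM flags, a schedule for the CS and VP SDE models, `--mode` on ImageNet 64x64, and `--ms_eps_path` for Analytic-DPM. The hypothetical helper below regenerates the `method`/`eval_hparams` pairs from that pattern; the setting keys (`cifar10_ls`, `celeba64`, `cifar10_cs`, `cifar10_vpsde`, `imagenet64`) and model labels are invented for this sketch. It only encodes the rows listed above, so treat the tables as authoritative.

```python
from typing import Dict, Optional, Tuple

METHOD_BY_MODEL = {"NPR": "pred_eps_epsc_pretrained", "SN": "pred_eps_eps2_pretrained", "Analytic": "pred_eps"}
SCHEDULE_BY_SETTING: Dict[str, Optional[str]] = {
    "cifar10_ls": None, "celeba64": None, "cifar10_cs": "cosine_1000",
    "cifar10_vpsde": "vpsde_1000", "imagenet64": None,
}
MODE_BY_MODEL = {"NPR": "simple", "SN": "complex", "Analytic": None}  # used on ImageNet 64x64 only


def sampling_config(setting: str, model: str, sampler: str,
                    ms_eps_path: str = "ms_eps_path") -> Tuple[str, str]:
    """Return (method, eval_hparams) for a Table 1 sampling row."""
    if setting not in SCHEDULE_BY_SETTING or model not in METHOD_BY_MODEL or sampler not in ("ddpm", "ddim"):
        raise ValueError(f"unsupported combination: {setting}, {model}, {sampler}")
    method = METHOD_BY_MODEL[model]
    if setting == "cifar10_vpsde":
        method += "_ct2dt"  # continuous-time model evaluated with discrete timesteps
    clip_pixel = 2 if sampler == "ddpm" and setting in ("cifar10_ls", "celeba64", "cifar10_vpsde") else 1
    flags = ["--rev_var_type optimal", "--clip_sigma_idx 1", f"--clip_pixel {clip_pixel}"]
    if sampler == "ddim":
        flags.append("--forward_type ddim --eta 0")
    schedule = SCHEDULE_BY_SETTING[setting]
    if schedule is not None:
        flags.append(f"--schedule {schedule}")
    if setting == "imagenet64" and MODE_BY_MODEL[model] is not None:
        flags.append(f"--mode {MODE_BY_MODEL[model]}")
    if model == "Analytic":
        flags.append(f"--ms_eps_path {ms_eps_path}")
    return method, " ".join(flags)


if __name__ == "__main__":
    print(sampling_config("cifar10_cs", "Analytic", "ddim"))
```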
- Likelihood experiments on CIFAR10 (LS) or CelebA 64x64, Table 3 in the paper:
Model | method | eval_hparams |
---|---|---|
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal" |
Analytic-DDPM | pred_eps | "--rev_var_type optimal --ms_eps_path ms_eps_path" |
- Likelihood experiments on CIFAR10 (CS), Table 3 in the paper:
Model | method | eval_hparams |
---|---|---|
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --schedule cosine_1000" |
Analytic-DDPM | pred_eps | "--rev_var_type optimal --schedule cosine_1000 --ms_eps_path ms_eps_path" |
- Likelihood experiments on ImageNet 64x64, Table 3 in the paper:
Model | method | eval_hparams |
---|---|---|
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --mode simple" |
Analytic-DDPM | pred_eps | "--rev_var_type optimal --ms_eps_path ms_eps_path" |
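Likelihood results are expressed in bits per dimension (bpd). To compare them with a per-image negative log-likelihood in nats, the standard conversion divides by the number of dimensions and by ln 2. The sketch below is that generic conversion, not code from this repository; it does not model any dequantization or data-scaling correction an evaluation pipeline may apply, and the `NUM_DIMS` table simply multiplies out the image shapes of the datasets above.

```python
import math

# Dimensions per image (channels x height x width) for the likelihood settings above.
NUM_DIMS = {"cifar10": 3 * 32 * 32, "celeba64": 3 * 64 * 64, "imagenet64": 3 * 64 * 64}


def nats_to_bpd(nll_nats: float, num_dims: int) -> float:
    """Convert a per-image negative log-likelihood in nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2.0))


def bpd_to_nats(bpd: float, num_dims: int) -> float:
    """Inverse conversion: bits per dimension back to nats per image."""
    return bpd * num_dims * math.log(2.0)


if __name__ == "__main__":
    d = NUM_DIMS["cifar10"]
    print(bpd_to_nats(3.0, d))  # nats per CIFAR10 image at 3 bpd
    print(nats_to_bpd(bpd_to_nats(3.0, d), d))
```

For example, 3 bpd on CIFAR10 corresponds to about 6388 nats per image (3 × 3072 × ln 2).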
This implementation is based on or inspired by:
- Analytic-DPM (provides the code structure)
- pytorch_diffusion (provides the model code for CelebA 64x64 and LSUN Bedroom)
- Improved DDPM (provides the model code for CIFAR10 and ImageNet 64x64)
- score-sde (provides the model code for CIFAR10)
- pytorch-fid (provides a PyTorch port of the official FID implementation)