Extended Analytic-DPM

Dependencies

The codebase is built on PyTorch. Install the dependencies with:

pip install "torch>=1.9.0" torchvision ml-collections ninja tensorboard
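
A quick way to confirm the environment after installing is a small check like the one below. It is a sketch of our own, not part of the codebase. It reports which of the packages above cannot be imported (ml-collections is imported as ml_collections) and whether the installed torch meets the 1.9.0 minimum.

```python
import importlib.util
import re
from importlib.metadata import PackageNotFoundError, version

# Import names of the packages installed above (ml-collections imports as ml_collections).
REQUIRED_MODULES = ["torch", "torchvision", "ml_collections", "ninja", "tensorboard"]


def missing_modules(names):
    """Return the names in `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def parse_version(text):
    """Parse the leading numeric part of a version, e.g. '1.9.0+cu111' -> (1, 9, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("missing modules:", missing or "none")
    try:
        torch_version = version("torch")  # version of the installed torch distribution
        status = "OK" if parse_version(torch_version) >= (1, 9, 0) else "too old (< 1.9.0)"
        print(f"torch {torch_version}: {status}")
    except PackageNotFoundError:
        print("torch is not installed")
```

Local build suffixes such as +cu111 are ignored by parse_version, so CUDA builds compare by their release number only.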

Basic usage

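The basic usage for training is as follows, where $train_hparams is the string in the train_hparams column of the model table below: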

python run_train.py --pretrained_path path/to/pretrained_dpm --dataset dataset --workspace path/to/working_directory $train_hparams

The basic usage for evaluation is as follows, where method and $eval_hparams come from the evaluation hyperparameter tables below:

python run_eval.py --pretrained_path path/to/evaluated_model --dataset dataset --workspace path/to/working_directory \
    --phase phase --sample_steps sample_steps --batch_size batch_size --method method $eval_hparams
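
The $train_hparams and $eval_hparams placeholders stand for the quoted strings in the tables below. Because the shell splits an unquoted variable into separate words, each flag reaches the script as its own argument. The sketch below does the same in Python with shlex.split, which is handy when scripting many runs. The helper name build_command is hypothetical, and the placeholder paths and values in the usage part are illustrative only.

```python
import shlex


def build_command(script, pretrained_path, dataset, workspace, hparams="", **options):
    """Assemble the argv for run_train.py or run_eval.py (hypothetical helper).

    hparams is one of the quoted strings from the tables (train_hparams or
    eval_hparams); shlex.split turns it into separate arguments, just as the
    unquoted $train_hparams / $eval_hparams expansion does in the shell.
    Keyword options (e.g. phase, sample_steps, batch_size, method) are
    appended as --key value pairs in the order given.
    """
    argv = ["python", script,
            "--pretrained_path", pretrained_path,
            "--dataset", dataset,
            "--workspace", workspace]
    for key, value in options.items():
        argv += [f"--{key}", str(value)]
    return argv + shlex.split(hparams)


if __name__ == "__main__":
    # Placeholder paths and values mirror the commands above; replace them with real ones.
    train = build_command("run_train.py", "path/to/pretrained_dpm", "dataset",
                          "path/to/working_directory",
                          hparams="--method pred_eps_epsc_pretrained")
    evaluate = build_command("run_eval.py", "path/to/evaluated_model", "dataset",
                             "path/to/working_directory",
                             hparams="--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2",
                             phase="phase", sample_steps=10, batch_size=100,
                             method="pred_eps_epsc_pretrained")
    print(shlex.join(train))
    print(shlex.join(evaluate))
    # Pass either list to subprocess.run(..., check=True) to launch the run.
```

Keeping the command as a list rather than one string avoids shell quoting problems if a path contains spaces.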

Models and FID statistics

<a id="model"/>

Here is the list of NPR-DPMs and SN-DPMs trained in this work. These models train only an additional prediction head on the last layer of a pretrained diffusion probabilistic model (DPM).

| NPR/SN-DPM | Pretrained DPM | train_hparams |
|---|---|---|
| CIFAR10 (LS), NPR-DPM | CIFAR10 (LS) | "--method pred_eps_epsc_pretrained" |
| CIFAR10 (LS), SN-DPM | CIFAR10 (LS) | "--method pred_eps_eps2_pretrained" |
| CIFAR10 (CS), NPR-DPM | CIFAR10 (CS) | "--method pred_eps_epsc_pretrained --schedule cosine_1000" |
| CIFAR10 (CS), SN-DPM | CIFAR10 (CS) | "--method pred_eps_eps2_pretrained --schedule cosine_1000" |
| CIFAR10 (VP SDE), NPR-DPM | CIFAR10 (VP SDE) | "--method pred_eps_epsc_pretrained_ct --sde vpsde" |
| CIFAR10 (VP SDE), SN-DPM | CIFAR10 (VP SDE) | "--method pred_eps_eps2_pretrained_ct --sde vpsde" |
| CelebA 64x64, NPR-DPM | CelebA 64x64 | "--method pred_eps_epsc_pretrained" |
| CelebA 64x64, SN-DPM | CelebA 64x64 | "--method pred_eps_eps2_pretrained" |
| ImageNet 64x64, NPR-DPM | ImageNet 64x64 | "--method pred_eps_epsc_pretrained --mode simple" |
| ImageNet 64x64, SN-DPM | ImageNet 64x64 | "--method pred_eps_eps2_pretrained --mode complex" |
| LSUN Bedroom, NPR-DPM | LSUN Bedroom | "--method pred_eps_epsc_pretrained --mode simple" |
| LSUN Bedroom, SN-DPM | LSUN Bedroom | "--method pred_eps_eps2_pretrained --mode complex" |
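
The table follows a regular pattern: NPR-DPMs use pred_eps_epsc_pretrained, SN-DPMs use pred_eps_eps2_pretrained, and the extra flags depend on the pretrained DPM (and, for ImageNet 64x64 and LSUN Bedroom, also on the variant). The sketch below encodes that pattern so training runs can be scripted. The function train_hparams is a hypothetical convenience read off this table, not an interface of the codebase, so check its output against the table before relying on it.

```python
PRETRAINED_DPMS = ["CIFAR10 (LS)", "CIFAR10 (CS)", "CIFAR10 (VP SDE)",
                   "CelebA 64x64", "ImageNet 64x64", "LSUN Bedroom"]


def train_hparams(pretrained, variant):
    """Rebuild the train_hparams string for one row of the model table.

    variant is "NPR" or "SN". The rules below are read off the table;
    they are a scripting convenience, not part of the codebase.
    """
    if variant not in ("NPR", "SN"):
        raise ValueError(f"unknown variant: {variant!r}")
    if pretrained not in PRETRAINED_DPMS:
        raise ValueError(f"unknown pretrained DPM: {pretrained!r}")
    # NPR-DPMs use the *_epsc_* method, SN-DPMs the *_eps2_* method.
    method = "pred_eps_epsc_pretrained" if variant == "NPR" else "pred_eps_eps2_pretrained"
    extra = []
    if pretrained == "CIFAR10 (VP SDE)":
        method += "_ct"  # the VP SDE rows use the _ct method variant plus --sde vpsde
        extra = ["--sde", "vpsde"]
    elif pretrained == "CIFAR10 (CS)":
        extra = ["--schedule", "cosine_1000"]
    elif pretrained in ("ImageNet 64x64", "LSUN Bedroom"):
        extra = ["--mode", "simple" if variant == "NPR" else "complex"]
    return " ".join(["--method", method, *extra])


if __name__ == "__main__":
    for dpm in PRETRAINED_DPMS:
        for variant in ("NPR", "SN"):
            print(f'{dpm}, {variant}-DPM: "{train_hparams(dpm, variant)}"')
```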

Here is the list of pretrained DPMs collected from prior works, converted into a format that this codebase can use directly. <a id="pretrained_dpm"/>

| Pretrained DPM | Expected mean squared norm (ms_eps) <br> (Used in Analytic-DPM) | From |
|---|---|---|
| CIFAR10 (LS) | Link | Analytic-DPM |
| CIFAR10 (CS) | Link | Analytic-DPM |
| CIFAR10 (VP SDE) | Link | score-sde |
| CelebA 64x64 | Link | DDIM |
| ImageNet 64x64 | Link | Improved DDPM |
| LSUN Bedroom | Link | pytorch_diffusion |

This link provides precalculated FID statistics on CIFAR10, CelebA 64x64, ImageNet 64x64 and LSUN Bedroom. They are computed following Appendix F.2 in Analytic-DPM.
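
FID compares two Gaussians fitted to Inception features: FID = ||mu1 - mu2||^2 + Tr(Sigma1 + Sigma2 - 2 (Sigma1 Sigma2)^(1/2)). As a sketch of how precalculated statistics are consumed, the function below computes this distance with NumPy alone. It relies on the fact that the product of two symmetric positive semi-definite matrices has real, non-negative eigenvalues, so the trace of its square root is the sum of their square roots. The usage assumes each statistics file is an .npz archive with arrays named mu and sigma, as in the widely used pytorch-fid format; the file names are hypothetical, and the actual files and this codebase's own FID code may differ.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    FID = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr((sigma1 sigma2)^{1/2}).
    For symmetric PSD covariances the eigenvalues of sigma1 @ sigma2 are real and
    non-negative, so Tr((sigma1 sigma2)^{1/2}) is the sum of their square roots.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    # Drop round-off imaginary parts and tiny negative values before the square root.
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


if __name__ == "__main__":
    # Toy 1-D check: N(0, 1) vs N(3, 4) gives 9 + 1 + 4 - 2 * 2.
    print(frechet_distance([0.0], [[1.0]], [3.0], [[4.0]]))  # 10.0
    # With real statistics (assumed .npz layout with "mu" and "sigma" arrays):
    # ref = np.load("path/to/fid_stats.npz")
    # gen = np.load("path/to/generated_stats.npz")
    # print(frechet_distance(ref["mu"], ref["sigma"], gen["mu"], gen["sigma"]))
```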

Evaluation Hyperparameters for NPR/SN-DPM and Analytic-DPM

<a id="evaluation"/>

Note: Analytic-DPM requires the precalculated expected mean squared norm of the noise prediction model (ms_eps), which is provided <a href="#pretrained_dpm">here</a>. Specify its path with --ms_eps_path.
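
For reference, Analytic-DPM uses ms_eps in its closed-form optimal reverse variance. Writing alpha_n = 1 - beta_n, bar_alpha_n = alpha_1 * ... * alpha_n and bar_beta_n = 1 - bar_alpha_n, the Analytic-DPM paper gives sigma_n^2 = lambda_n^2 + (sqrt(bar_beta_n / alpha_n) - sqrt(bar_beta_{n-1} - lambda_n^2))^2 * (1 - E||eps_hat_n(x_n)||^2 / d), where d is the data dimension and lambda_n^2 = eta^2 * tilde_beta_n with tilde_beta_n = bar_beta_{n-1} / bar_beta_n * beta_n (eta = 1 for DDPM, eta = 0 for DDIM, matching --eta 0 in the DDIM rows below). The NumPy sketch below evaluates this formula, assuming ms_eps[n] holds E||eps_hat_n(x_n)||^2 / d for every timestep. It covers only sampling with all N steps (fewer sample_steps would require recomputing the coefficients on the chosen subsequence) and applies none of the clipping the codebase may perform, so treat it as an illustration rather than the codebase's implementation.

```python
import numpy as np


def analytic_variance(betas, ms_eps, eta=1.0):
    """Analytic-DPM optimal reverse variance sigma_n^2 for n = 1..N (full-step sketch).

    sigma_n^2 = lambda_n^2
                + (sqrt(bar_beta_n / alpha_n) - sqrt(bar_beta_{n-1} - lambda_n^2))^2
                  * (1 - ms_eps_n)

    betas:  forward schedule beta_1..beta_N, shape (N,).
    ms_eps: assumed to hold E||eps_hat_n(x_n)||^2 / d for n = 1..N, shape (N,).
    eta:    lambda_n^2 = eta^2 * tilde_beta_n (eta=1: DDPM, eta=0: DDIM).
    No clipping of ms_eps or of the result is applied here.
    """
    betas = np.asarray(betas, dtype=np.float64)
    ms_eps = np.asarray(ms_eps, dtype=np.float64)
    alphas = 1.0 - betas
    bar_betas = 1.0 - np.cumprod(alphas)                # bar_beta_n = 1 - bar_alpha_n
    bar_prev = np.concatenate([[0.0], bar_betas[:-1]])  # bar_beta_{n-1}, with bar_beta_0 = 0
    lam2 = eta ** 2 * bar_prev / bar_betas * betas      # lambda_n^2 = eta^2 * tilde_beta_n
    # The clip guards against tiny negative values caused by floating-point round-off.
    gap = np.sqrt(bar_betas / alphas) - np.sqrt(np.clip(bar_prev - lam2, 0.0, None))
    return lam2 + gap ** 2 * (1.0 - ms_eps)


if __name__ == "__main__":
    betas = np.linspace(1e-4, 0.02, 1000)  # linear DDPM schedule, used only as an example
    ms_eps = np.full(1000, 0.9)            # made-up placeholder, not real ms_eps values
    sigma2 = analytic_variance(betas, ms_eps, eta=1.0)
    print(sigma2[:3], sigma2[-3:])
```

As a sanity check, ms_eps = 1 reduces the variance to lambda_n^2, while ms_eps = 0 with eta = 1 gives beta_n / alpha_n.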

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2" |
| SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --ms_eps_path ms_eps_path" |
| NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0" |
| SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0" |
| Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --ms_eps_path ms_eps_path" |

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000" |
| SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000 --ms_eps_path ms_eps_path" |
| NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000" |
| SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000" |
| Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000 --ms_eps_path ms_eps_path" |

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000" |
| SN-DDPM | pred_eps_eps2_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000" |
| Analytic-DDPM | pred_eps_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000 --ms_eps_path ms_eps_path" |
| NPR-DDIM | pred_eps_epsc_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000" |
| SN-DDIM | pred_eps_eps2_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000" |
| Analytic-DDIM | pred_eps_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000 --ms_eps_path ms_eps_path" |

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode simple" |
| SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode complex" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --ms_eps_path ms_eps_path" |
| NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --mode simple" |
| SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --mode complex" |
| Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --ms_eps_path ms_eps_path" |

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --ms_eps_path ms_eps_path" |

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --schedule cosine_1000" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --schedule cosine_1000 --ms_eps_path ms_eps_path" |

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --mode simple" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --ms_eps_path ms_eps_path" |
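
Since each table row pairs a method with an eval_hparams string, a sweep over several values of --sample_steps can be generated mechanically. The sketch below does this for the rows of the first evaluation table. It substitutes a real path for the literal ms_eps_path placeholder that follows --ms_eps_path; a plain string replace would also rewrite the flag name itself. EVAL_ROWS and eval_commands are hypothetical names, and the checkpoint paths, dataset, phase and batch size are placeholders the caller supplies. Rows from the other tables can be swapped in the same way.

```python
import shlex

# (method, eval_hparams) per row, copied from the first evaluation table above.
EVAL_ROWS = {
    "NPR-DDPM": ("pred_eps_epsc_pretrained",
                 "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2"),
    "SN-DDPM": ("pred_eps_eps2_pretrained",
                "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2"),
    "Analytic-DDPM": ("pred_eps",
                      "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --ms_eps_path ms_eps_path"),
    "NPR-DDIM": ("pred_eps_epsc_pretrained",
                 "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0"),
    "SN-DDIM": ("pred_eps_eps2_pretrained",
                "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0"),
    "Analytic-DDIM": ("pred_eps",
                      "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --ms_eps_path ms_eps_path"),
}


def eval_commands(model_paths, ms_eps_path, sample_steps_list, dataset, workspace, phase, batch_size):
    """Yield one run_eval.py command line per (row, sample_steps) pair.

    model_paths maps each row name to the checkpoint passed as --pretrained_path.
    """
    for name, (method, hparams) in EVAL_ROWS.items():
        extra = shlex.split(hparams)
        if "--ms_eps_path" in extra:
            # Replace the value after the flag, not the flag name itself.
            extra[extra.index("--ms_eps_path") + 1] = ms_eps_path
        for steps in sample_steps_list:
            argv = ["python", "run_eval.py",
                    "--pretrained_path", model_paths[name],
                    "--dataset", dataset, "--workspace", workspace,
                    "--phase", phase, "--sample_steps", str(steps),
                    "--batch_size", str(batch_size),
                    "--method", method, *extra]
            yield shlex.join(argv)


if __name__ == "__main__":
    # Hypothetical checkpoint locations, one per row; replace with real paths.
    paths = {name: f"path/to/{name}" for name in EVAL_ROWS}
    for line in eval_commands(paths, "path/to/ms_eps", [10, 25, 50], "dataset",
                              "path/to/working_directory", "phase", batch_size=100):
        print(line)
```

Each generated line can be pasted into a shell, or launched with subprocess.run(shlex.split(line), check=True).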

This implementation is based on / inspired by