<p align="center">Fast Sampling of Diffusion Models with Exponential Integrator</p>
<div align="center"> <a href="https://qsh-zh.github.io/" target="_blank">Qinsheng Zhang</a>   <b>·</b>   <a href="https://yongxin.ae.gatech.edu/" target="_blank">Yongxin Chen</a> <br> <br> <a href="https://arxiv.org/abs/2204.13902" target="_blank">Paper</a>   <a href="https://qsh-zh.github.io/deis" target="_blank">Project Page</a> </div> <br><br>

- 2021-11-17 DEIS accelerates large-scale text-to-image eDiff-I and achieves SOTA performance.
Update
- BREAKING CHANGE: the v1.0 API changes substantially, as we add the ρRK-DEIS and ρAB-DEIS algorithms and more choices for time scheduling. If you are only interested in tAB-DEIS / iPNDM or the previous codebase, check v0.1.
Usage
# for PyTorch users
pip install "jax[cpu]"
If diffusion models are trained with continuous time
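import jax.numpy as jnp
import jax_deis as deis

# eps_model is your trained noise-prediction network
def eps_fn(x_t, scalar_t):
    # broadcast the scalar time to one entry per batch element
    vec_t = jnp.ones(x_t.shape[0]) * scalar_t
    return eps_model(x_t, vec_t)

# pytorch
# import torch as th
# import th_deis as deis
# def eps_fn(x_t, scalar_t):
#     vec_t = (th.ones(x_t.shape[0])).float().to(x_t) * scalar_t
#     with th.no_grad():
#         return eps_model(x_t, vec_t)
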
# mappings between t and alpha in VPSDE
# we provide popular linear and cos mappings
t2alpha_fn, alpha2t_fn = deis.get_linear_alpha_fns(beta_0=0.01, beta_1=20)

vpsde = deis.VPSDE(
    t2alpha_fn,
    alpha2t_fn,
    sampling_eps,  # sampling end time t_0
    sampling_T  # sampling starting time t_T
)

sampler_fn = deis.get_sampler(
    # args for diffusion model
    vpsde,
    eps_fn,
    # args for timestamp scheduling
    ts_phase="t",  # supports "rho", "t", "log"
    ts_order=2.0,
    num_step=10,
    # deis choice
    method="t_ab",  # deis sampling algorithm: supports "rho_rk", "rho_ab", "t_ab", "ipndm"
    ab_order=3,  # greater than 0; used by "rho_ab" and "t_ab", ignored by other algorithms
    rk_method="3kutta"  # used by "rho_rk", ignored by other algorithms
)

sample = sampler_fn(noise)
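
To make the role of the two mapping functions concrete, here is a minimal sketch of a linear mapping between `t` and `alpha` and its inverse. It assumes that `alpha` is the squared signal coefficient, so that `x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * eps`, and that `beta(t)` grows linearly from `beta_0` to `beta_1` with `beta_1 > beta_0`. The helper name `make_linear_alpha_fns` is hypothetical and not part of the `deis` package; the package's own `get_linear_alpha_fns` may use a different convention.

```python
import math


def make_linear_alpha_fns(beta_0, beta_1):
    """Hypothetical helper (not the deis API).

    Assumes alpha(t) = exp(-integral_0^t beta(s) ds) with
    beta(s) = beta_0 + s * (beta_1 - beta_0) and beta_1 > beta_0.
    """
    delta = beta_1 - beta_0

    def t2alpha(t):
        # Closed form of the integral of the linear beta schedule.
        return math.exp(-(beta_0 * t + 0.5 * delta * t * t))

    def alpha2t(alpha):
        # Solve 0.5 * delta * t^2 + beta_0 * t + log(alpha) = 0
        # and keep the non-negative root.
        log_alpha = math.log(alpha)
        return (-beta_0 + math.sqrt(beta_0 ** 2 - 2.0 * delta * log_alpha)) / delta

    return t2alpha, alpha2t


t2alpha, alpha2t = make_linear_alpha_fns(beta_0=0.01, beta_1=20)
# The two functions should invert each other on the sampling interval.
print(alpha2t(t2alpha(0.5)))
```

Whatever convention is used, the sampler only needs the pair to be mutually inverse, so that it can move between time and signal level in both directions.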
If diffusion models are trained with discrete time
#! by default, the example assumes sampling
#! from t=len(discrete_alpha) - 1 to t=0,
#! i.e. len(discrete_alpha) steps in total if we use delta_t = 1
vpsde = deis.DiscreteVPSDE(discrete_alpha)
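
As an illustration of what `discrete_alpha` might contain, the sketch below builds it from a DDPM-style linear beta schedule as a cumulative product of `1 - beta_i`. This is an assumption about the expected convention, not something the example above states; the helper names `linear_betas` and `cumulative_alphas`, the 1000 training steps, and the beta range are all hypothetical and should be replaced with the values your model was trained with.

```python
def linear_betas(num_train_step=1000, beta_start=1e-4, beta_end=2e-2):
    """Hypothetical DDPM-style linear beta schedule."""
    step = (beta_end - beta_start) / (num_train_step - 1)
    return [beta_start + i * step for i in range(num_train_step)]


def cumulative_alphas(betas):
    """Cumulative product of (1 - beta_i), assumed here to be the alpha of a discrete VPSDE."""
    alphas, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alphas.append(running)
    return alphas


# One alpha per training timestep, decreasing from about 1 towards 0.
discrete_alpha = cumulative_alphas(linear_betas())
print(len(discrete_alpha), discrete_alpha[0], discrete_alpha[-1])
```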
A short derivation for DEIS
<details> <summary>Exponential integrator in diffusion model</summary>The key insight of the exponential integrator is to take advantage of all the mathematical structure present in the ODE, with the goal of making the discretization error as small as possible.
In diffusion models, this structure includes the semilinear form of the ODE and the analytic formulas for the drift and diffusion coefficients.
Below we present a short derivation of the exponential integrator as applied to diffusion models.
Forward SDE
$$ dx = F_tx dt + G_td\mathbf{w} $$
Backward ODE
$$ dx = F_tx dt + 0.5 G_tG_t^T L_t^{-T} \epsilon(x, t) dt $$
where $L_t L_t^{T} = \Sigma_t$ and $\Sigma_t$ is the covariance of $p_{0t}(x_t | x_0)$.
Exponential Integrator
We can remove the linear term of the semilinear ODE with the exponential integrator by introducing a new variable $y$
$$ y_t = \Psi(t) x_t \quad \Psi(t) = \exp\left(-\int_0^{t} F_\tau d\tau\right) $$
The ODE then simplifies to
$$ \dot{y}_t = 0.5 \Psi(t) G_t G_t^T L_t^{-T} \epsilon(x(y), t) $$
where $x(y)$ maps $y_t$ to $x_t$.
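
To see why the linear term disappears, one can differentiate $y_t = \Psi(t) x_t$ directly. The step below is a worked check, assuming $F_t$ is a scalar (or commutes with $\Psi(t)$), so that $\dot{\Psi}(t) = -\Psi(t) F_t$:

$$ \dot{y}_t = \dot{\Psi}(t) x_t + \Psi(t) \dot{x}_t = -\Psi(t) F_t x_t + \Psi(t) \left( F_t x_t + 0.5 G_t G_t^T L_t^{-T} \epsilon(x_t, t) \right) = 0.5 \Psi(t) G_t G_t^T L_t^{-T} \epsilon(x_t, t) $$

The two $F_t x_t$ terms cancel exactly, so the only remaining source of discretization error is the network term.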
Time scaling
We can take one step further when $F_t, G_t$ are scalars by rescaling time
$$ \dot{v}_\rho = \epsilon(x(v), t(\rho)) $$
where $y_t = v_\rho$ and $d \rho = 0.5 \Psi(t) G_t G_t^T L_t^{-T} dt$. Here $x(v)$ maps $v_\rho$ to $x_t$, and $t(\rho)$ maps $\rho$ to $t$.
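
As a worked example of the rescaled time, consider a scalar VPSDE. The following are assumptions for illustration: $x_t = \sqrt{\alpha_t} x_0 + \sqrt{1 - \alpha_t} \epsilon$ with $\alpha_0 = 1$, which gives $F_t = 0.5 \frac{d \log \alpha_t}{dt}$, $G_t^2 = -\frac{d \log \alpha_t}{dt}$, $L_t = \sqrt{1 - \alpha_t}$ and $\Psi(t) = \alpha_t^{-1/2}$. Substituting into the definition of $d\rho$:

$$ \frac{d\rho}{dt} = 0.5 \alpha_t^{-1/2} \left( -\frac{d \log \alpha_t}{dt} \right) \frac{1}{\sqrt{1 - \alpha_t}} = -\frac{\dot{\alpha}_t}{2 \alpha_t^{3/2} \sqrt{1 - \alpha_t}} = \frac{d}{dt} \sqrt{\frac{1 - \alpha_t}{\alpha_t}} $$

so under these assumptions one valid choice is $\rho = \sqrt{(1 - \alpha_t) / \alpha_t}$, which increases monotonically from $0$ as $\alpha_t$ decays from $1$.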
High order solver
After absorbing all of this structure, we arrive at the following ODE
$$ \dot{v}_\rho = \epsilon(x(v), t(\rho)) $$
As the RHS is a neural network, we cannot simplify the ODE further unless we have additional knowledge of this black-box function. From here we can apply well-established ODE solvers, such as multistep and Runge-Kutta methods.
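
To illustrate the multistep idea on an ODE of this form, here is a generic sketch of a second-order Adams-Bashforth solver with uniform steps, compared against plain Euler. It is not the `deis` implementation: the function names are hypothetical, and the toy right-hand side `-v` stands in for the network $\epsilon$ so that the exact solution is known.

```python
import math


def euler(f, v0, rho0, rho1, num_step):
    """First-order baseline for dv/drho = f(v, rho)."""
    h = (rho1 - rho0) / num_step
    v, rho = v0, rho0
    for _ in range(num_step):
        v = v + h * f(v, rho)
        rho += h
    return v


def adams_bashforth2(f, v0, rho0, rho1, num_step):
    """Second-order multistep solver; reuses the previous evaluation of f."""
    h = (rho1 - rho0) / num_step
    v, rho = v0, rho0
    f_prev = None
    for _ in range(num_step):
        f_cur = f(v, rho)  # one function (network) evaluation per step
        if f_prev is None:
            v = v + h * f_cur  # no history yet: bootstrap with an Euler step
        else:
            v = v + h * (1.5 * f_cur - 0.5 * f_prev)
        f_prev = f_cur
        rho += h
    return v


def toy_rhs(v, rho):
    # Stand-in for the network; exact solution is v(rho) = exp(-rho).
    return -v


exact = math.exp(-1.0)
print("euler error:", abs(euler(toy_rhs, 1.0, 0.0, 1.0, 20) - exact))
print("ab2 error:  ", abs(adams_bashforth2(toy_rhs, 1.0, 0.0, 1.0, 20) - exact))
```

Both solvers call the right-hand side once per step, which is why multistep methods are attractive when each call is a network evaluation.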
</details>

Demo
- Continuous VPSDE: based on the score_sde codebase. CIFAR10 images in 7 steps.
- Discrete VPSDE: based on the PNDM codebase.
Reference
@article{zhang2022fast,
title={Fast Sampling of Diffusion Models with Exponential Integrator},
author={Zhang, Qinsheng and Chen, Yongxin},
journal={arXiv preprint arXiv:2204.13902},
year={2022}
}