# Conditional Diffusion MNIST
`script.py` is a minimal, self-contained implementation of a conditional diffusion model. It learns to generate MNIST digits conditioned on a class label. The neural network architecture is a small U-Net (pretrained weights are also available in this repo). This code is modified from this excellent repo, which does unconditional generation. The diffusion model is a Denoising Diffusion Probabilistic Model (DDPM).
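For orientation, the DDPM training objective reduces to predicting the noise added to an image at a random timestep. Below is a minimal PyTorch sketch of one such training step; the names `model`, `alphabar`, and `n_T` are illustrative placeholders, not necessarily those used in `script.py`.

```python
import torch
import torch.nn.functional as F

def ddpm_train_step(model, x0, c, alphabar, n_T):
    """One DDPM training step: noise x0 at a random timestep t, then train
    the network to predict that noise.
    `alphabar` holds cumulative products of (1 - beta_t), on the same device as x0."""
    t = torch.randint(1, n_T + 1, (x0.shape[0],), device=x0.device)  # random timesteps
    eps = torch.randn_like(x0)                                       # the noise to predict
    ab = alphabar[t].view(-1, 1, 1, 1)                               # broadcast over image dims
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps                     # noised input z_t
    eps_pred = model(x_t, c, t / n_T)                                # conditional noise prediction
    return F.mse_loss(eps_pred, eps)
```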
<p align = "center"> <img width="400" src="gif_mnist_01.gif"/img> </p> <p align = "center"> Samples generated from the model. </p>The conditioning roughly follows the method described in Classifier-Free Diffusion Guidance (also used in ImageGen). The model infuses timestep embeddings $t_e$ and context embeddings $c_e$ with the U-Net activations at a certain layer $a_L$, via,
<p align = "center"> $a_{L+1} = c_e a_L + t_e.$ </p> (Though in our experimentation, we found variants of this also work, e.g. concatenating embeddings together.)At training time, $c_e$ is randomly set to zero with probability $0.1$, so the model learns to do unconditional generation (say $\psi(z_t)$ for noise $z_t$ at timestep $t$) and also conditional generation (say $\psi(z_t, c)$ for context $c$). This is important as at generation time, we choose a weight, $w \geq 0$, to guide the model to generate examples with the following equation,
<p align = "center"> $\hat{\epsilon}_{t} = (1+w)\psi(z_t, c) - w \psi(z_t).$ </p>Increasing $w$ produces images that are more typical but less diverse.
<p align = "center"> <img width="800" src="guided_mnist.png"/img> </p> <p align = "center"> Samples produced with varying guidance strength, $w$. </p>Training for above models took around 20 epochs (~20 minutes).
`pretrained_model.zip` contains pretrained weights.