Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers (SiT)<br><sub>Official PyTorch Implementation</sub>

Paper | Project Page | Open In Colab

SiT samples

This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring interpolant models with scalable transformers (SiTs).

Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers<br> Nanye Ma, Mark Goldstein, Michael Albergo, Nicholas Boffi, Eric Vanden-Eijnden, Saining Xie <br>New York University<br>

We present Scalable Interpolant Transformers (SiT), a family of generative models built on the backbone of Diffusion Transformers (DiT). The interpolant framework, which allows for connecting two distributions in a more flexible way than standard diffusion models, makes possible a modular study of various design choices impacting generative models built on dynamical transport: using discrete vs. continuous time learning, deciding the model to learn, choosing the interpolant connecting the distributions, and deploying a deterministic or stochastic sampler. By carefully introducing the above ingredients, SiT surpasses DiT uniformly across model sizes on the conditional ImageNet 256x256 benchmark using the exact same backbone, number of parameters, and GFLOPs. By exploring various diffusion coefficients, which can be tuned separately from learning, SiT achieves an FID-50K score of 2.06.

This repository contains:

Setup

First, download and set up the repo:

git clone https://github.com/willisma/SiT.git
cd SiT

We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.

conda env create -f environment.yml
conda activate SiT

Sampling

More SiT samples

Pre-trained SiT checkpoints. You can sample from our pre-trained SiT models with sample.py. Weights for the selected pre-trained model are downloaded automatically. The script has arguments to choose the sampler configuration (ODE or SDE), set the number of sampling steps, change the classifier-free guidance scale, and so on. For example, to sample from our 256x256 SiT-XL model with the default ODE settings, you can use:

python sample.py ODE --image-size 256 --seed 1
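
Classifier-free guidance combines a class-conditional and an unconditional model output at every sampling step. The sketch below shows the standard combination rule on plain Python lists. The function name `apply_cfg` and the toy values are illustrative and are not taken from sample.py, which may apply guidance differently (for example, to only some channels).

```python
def apply_cfg(cond, uncond, scale):
    """Standard classifier-free guidance: uncond + scale * (cond - uncond).

    scale = 1.0 returns the conditional output unchanged; larger values push
    the result further in the conditional direction.
    """
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


cond = [1.0, 2.0]    # toy conditional model output (hypothetical values)
uncond = [0.5, 1.0]  # toy unconditional model output (hypothetical values)

print(apply_cfg(cond, uncond, 1.0))  # no extra guidance
print(apply_cfg(cond, uncond, 4.0))  # stronger guidance
```

A larger scale usually trades sample diversity for fidelity to the class label, so it is worth sweeping rather than fixing in advance.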

For convenience, our pre-trained SiT models can be downloaded directly here as well:

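| SiT Model | Image Resolution | FID-50K | Inception Score | Gflops |
|-----------|------------------|---------|-----------------|--------|
| XL/2 | 256x256 | 2.06 | 270.27 | 119 |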
<!-- | [XL/2](https://dl.fbaipublicfiles.com/SiT/models/SiT-XL-2-512x512.pt) | 512x512 | 2.62 | 252.21 | 525 | -->

Custom SiT checkpoints. If you've trained a new SiT model with train.py (see below), you can add the --ckpt argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 SiT-L/4 model with the ODE sampler, run:

python sample.py ODE --model SiT-L/4 --image-size 256 --ckpt /path/to/model.pt

Advanced sampler settings

| Sampler | Argument | Type | Description |
|---------|----------|------|-------------|
| ODE | --atol | float | Absolute error tolerance |
| | --rtol | float | Relative error tolerance |
| | --sampling-method | str | Sampling methods (refer to torchdiffeq) |
| SDE | --diffusion-form | str | Form of SDE's diffusion coefficient (refer to Tab. 2 in paper) |
| | --diffusion-norm | float | Magnitude of SDE's diffusion coefficient |
| | --last-step | str | Form of SDE's last step<br>None - Single SDE integration step<br>"Mean" - SDE integration step without diffusion coefficient<br>"Tweedie" - Tweedie's denoising step<br>"Euler" - Single ODE integration step |
| | --sampling-method | str | Sampling methods<br>"Euler" - First order integration<br>"Heun" - Second order integration |

There are some more options; refer to train_utils.py for details.
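
To make the difference between the "Euler" (first order) and "Heun" (second order) options concrete, here is a minimal sketch of both update rules on a toy one-dimensional ODE. It is not the repo's sampler: the helper names and the velocity field dx/dt = -x are assumptions chosen only because the exact solution is known.

```python
import math


def euler_step(f, x, t, dt):
    """First-order step: follow the slope at the current point."""
    return x + dt * f(x, t)


def heun_step(f, x, t, dt):
    """Second-order step: average the slope at the start and at an Euler-predicted end."""
    k1 = f(x, t)
    k2 = f(x + dt * k1, t + dt)
    return x + 0.5 * dt * (k1 + k2)


def integrate(step, f, x0, t0, t1, num_steps):
    """Apply `step` num_steps times on a uniform grid from t0 to t1."""
    dt = (t1 - t0) / num_steps
    x, t = x0, t0
    for _ in range(num_steps):
        x = step(f, x, t, dt)
        t += dt
    return x


def decay(x, t):
    # Toy velocity field dx/dt = -x, whose exact solution is x0 * exp(-t).
    return -x


exact = math.exp(-1.0)
euler_err = abs(integrate(euler_step, decay, 1.0, 0.0, 1.0, 10) - exact)
heun_err = abs(integrate(heun_step, decay, 1.0, 0.0, 1.0, 10) - exact)
print(f"Euler error: {euler_err:.2e}, Heun error: {heun_err:.2e}")
```

Heun evaluates the velocity field twice per step, so for a fixed budget of model evaluations it takes half as many steps as Euler. Which one gives better samples at a given budget is an empirical question.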

Training SiT

We provide a training script for SiT in train.py. To launch SiT-XL/2 (256x256) training with N GPUs on one node:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train

Logging. To enable wandb, first set WANDB_KEY, ENTITY, and PROJECT as environment variables:

export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"

Then add the --wandb flag to the training command:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --wandb
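
A long multi-GPU job that fails at logger start-up is annoying to debug, so a small pre-flight check can help. The sketch below only verifies that the three variables named above are set and non-empty. The helper `missing_wandb_vars` is hypothetical and not part of this repo.

```python
import os

# Variable names as listed in the Logging instructions.
REQUIRED_VARS = ("WANDB_KEY", "ENTITY", "PROJECT")


def missing_wandb_vars(env=None):
    """Return the names from REQUIRED_VARS that are unset or empty in `env`."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


missing = missing_wandb_vars()
if missing:
    print("Set these before passing --wandb:", ", ".join(missing))
else:
    print("wandb environment looks complete")
```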

Interpolant settings. We also support different choices of interpolant and model prediction. For example, to launch SiT-XL/2 (256x256) with the Linear interpolant and noise prediction:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --path-type Linear --prediction noise
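
For intuition, an interpolant mixes a data sample and a noise sample with time-dependent coefficients, and the prediction type decides which quantity the network regresses. The sketch below assumes a linear interpolant x_t = (1 - t) * x + t * eps, so that t = 0 is data and t = 1 is noise. The time direction and all helper names are assumptions for illustration and may differ from what --path-type Linear uses in the code.

```python
import random


def linear_interpolant(x, eps, t):
    """Return x_t = (1 - t) * x + t * eps for a scalar t in [0, 1]."""
    return [(1.0 - t) * xi + t * ei for xi, ei in zip(x, eps)]


def velocity_target(x, eps):
    """Time derivative of the linear interpolant above: d x_t / dt = eps - x."""
    return [ei - xi for xi, ei in zip(x, eps)]


def noise_target(eps):
    """With noise prediction, the regression target is the noise sample itself."""
    return list(eps)


rng = random.Random(0)
x = [0.2, -0.4, 0.9]                    # toy "data" vector (hypothetical values)
eps = [rng.gauss(0.0, 1.0) for _ in x]  # standard Gaussian noise
x_t = linear_interpolant(x, eps, t=0.3)
print(x_t, velocity_target(x, eps), noise_target(eps))
```

Under this convention the velocity and noise targets differ only by the data term, which is why the same backbone can be trained with either prediction type.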

Resume training. To resume training from a custom checkpoint:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-L/2 --data-path /path/to/imagenet/train --ckpt /path/to/model.pt

Caution. Resuming training automatically restores the model, EMA, and optimizer states, as well as the training configs, from the checkpoint.
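
The EMA state stored in a checkpoint is a slowly moving average of the model weights, kept separately from the weights the optimizer updates. A minimal sketch of the usual update rule, using plain dictionaries of floats instead of tensors, is shown below. The helper name and the default decay of 0.9999 are assumptions, not values read from train.py.

```python
def ema_update(ema_params, model_params, decay=0.9999):
    """In-place EMA: ema <- decay * ema + (1 - decay) * param."""
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


model = {"w": 1.0}  # stand-in for the current model weights
ema = {"w": 0.0}    # stand-in for the EMA copy
for _ in range(3):
    # decay=0.5 is used here only so the effect is visible after a few steps.
    ema_update(ema, model, decay=0.5)
print(ema)
```

Because the EMA copy lags behind the live weights, restoring it together with the optimizer state matters when resuming: re-initializing it from the model weights would discard the accumulated average.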

Evaluation (FID, Inception Score, etc.)

We include a sample_ddp.py script which samples a large number of images from a SiT model in parallel. This script generates a folder of samples as well as a .npz file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score and other metrics. For example, to sample 50K images from our pre-trained SiT-XL/2 model over N GPUs under default ODE sampler settings, run:

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --num-fid-samples 50000
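
When sampling in parallel, the requested number of images rarely divides evenly across GPUs and batches. One common scheme, assumed here for illustration and not verified against sample_ddp.py, is to round the request up to a multiple of the global batch size and discard the surplus afterwards. The function name, GPU count, and batch size below are hypothetical.

```python
import math


def samples_per_gpu(num_fid_samples, world_size, per_gpu_batch):
    """Pad the request up to a multiple of the global batch, then split evenly.

    Returns (total_generated, per_gpu_samples, iterations_per_gpu).
    """
    global_batch = world_size * per_gpu_batch
    total = math.ceil(num_fid_samples / global_batch) * global_batch
    per_gpu = total // world_size
    return total, per_gpu, per_gpu // per_gpu_batch


# 50K samples on 8 GPUs with 64 images per GPU per iteration.
print(samples_per_gpu(50000, world_size=8, per_gpu_batch=64))
```

Under this scheme a few extra images are generated, so the .npz file should be built from exactly the first 50,000 samples to keep FID comparable across runs.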

Likelihood. Likelihood evaluation is supported. To calculate likelihood, add the --likelihood flag to the ODE sampler:

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --likelihood

Note that likelihood can only be calculated with the ODE sampler; see sample_ddp.py for more details and settings.
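
The restriction to the ODE sampler follows from the change-of-variables formula for deterministic flows. As a reminder, stated for a generic velocity field $v$ and time interval $[0, 1]$ rather than as a description of the exact estimator in sample_ddp.py: if a sample evolves as $\dot{x}_t = v(x_t, t)$, its log-density along the trajectory changes at a rate given by the negative divergence of $v$.

```latex
\frac{d}{dt}\,\log p_t(x_t) = -\nabla \cdot v(x_t, t)
\quad\Longrightarrow\quad
\log p_{1}(x_{1}) = \log p_{0}(x_{0}) - \int_{0}^{1} \nabla \cdot v(x_t, t)\,dt
```

An SDE sampler injects fresh noise at every step, so there is no single deterministic trajectory along which the density can be tracked in this way.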

Enhancements

Training (and sampling) could likely be sped up significantly by:

Basic features that would be nice to add:

Precision in likelihood calculation could likely be improved by:

Differences from JAX

Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models. There may be minor differences in results stemming from sampling on different platforms (TPU vs. GPU). We observed that sampling on TPU performs marginally worse than GPU (2.15 FID versus 2.06 in the paper).

License

This project is under the MIT license. See LICENSE for details.