FiT: Flexible Vision Transformer for Diffusion Model

<p align="center"> šŸ“ƒ <a href="https://arxiv.org/pdf/2402.12376.pdf" target="_blank">FiT Paper</a> ā€¢ šŸ“¦ <a href="https://huggingface.co/InfImagine/FiT" target="_blank">FiT Checkpoint</a> <br> ā€¢ šŸ“ƒ <a href="https://arxiv.org/pdf/2410.13925" target="_blank">FiTv2 Paper</a> ā€¢ šŸ“¦ <a href="https://huggingface.co/InfImagine/FiTv2" target="_blank">FiTv2 Checkpoint</a> <br> </p>

This is the official repository, containing PyTorch model definitions, pre-trained weights, and sampling code for our Flexible Vision Transformer (FiT). FiT is a diffusion-transformer-based model that can generate images at unrestricted resolutions and aspect ratios.

The core features will include:

Why do we need FiT?

Stay tuned for this project! šŸ˜†

Setup

First, download and setup the repo:

git clone https://github.com/whlzy/FiT.git
cd FiT

Installation

conda create -n fit_env python=3.10
conda activate fit_env
pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu118
pip install xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
pip install -e .
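
To verify that the CUDA builds of PyTorch and xformers are picked up, an optional check (illustrative, not part of the repo's scripts):

python -c "import torch, xformers; print(torch.__version__, torch.cuda.is_available())"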

Sample

Basic Sampling

The base models are trained on images with $H\times W\leqslant 256\times256$. Our FiTv1-XL/2 and FiTv2-XL/2 models are trained with a batch size of 256 for 2000K steps, and our FiTv2-3B/2 model is trained with a batch size of 256 for 1000K steps.

The pre-trained FiT models can be downloaded directly from Hugging Face:

| FiT Model | Checkpoint | FID-256x256 | FID-320x320 | Model Size | GFLOPs |
| --- | --- | --- | --- | --- | --- |
| FiTv1-XL/2 | CKPT | 4.21 | 5.11 | 824M | 153 |
| FiTv2-XL/2 | CKPT | 2.26 | 3.55 | 671M | 147 |
| FiTv2-3B/2 | CKPT | 2.15 | 3.22 | 3B | 653 |

Downloading the checkpoints

Downloading via wget:

mkdir checkpoints

wget -c "https://huggingface.co/InfImagine/FiT/blob/main/FiTv1_xl/model_ema.bin" -O checkpoints/fitv1_xl.bin

wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_XL/model_ema.safetensors?download=true" -O checkpoints/fitv2_xl.safetensors

wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_3B/model_ema.bin?download=true" -O checkpoints/fitv2_3B.bin

Sampling 256x256 Images

Sampling with FiTv1-XL/2 for $256\times 256$ Images:

python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fit_ddp.py --num-fid-samples 50000 --cfgdir configs/fit/config_fit_xl.yaml --ckpt checkpoints/fitv1_xl.bin --image-height 256 --image-width 256 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32

Sampling with FiTv2-XL/2 for $256\times 256$ Images:

python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_xl.yaml --ckpt checkpoints/fitv2_xl.safetensors --image-height 256 --image-width 256 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32  --sampler-mode ODE

Sampling with FiTv2-3B/2 for $256\times 256$ Images:

python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_3B.yaml --ckpt checkpoints/fitv2_3B.bin --image-height 256 --image-width 256 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32  --sampler-mode ODE

Note that NUM_NODE, NUM_GPU and MASTER_PORT need to be specified.
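
For example, for a single node with 8 GPUs (values are illustrative):

export NUM_NODE=1
export NUM_GPU=8
export MASTER_PORT=29500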

Sampling Images with arbitrary resolutions

We can assign image-height and image-width any values we want, but we also need to specify the original maximum positional-embedding length (ori-max-pe-len) and the interpolation method. Some examples are shown below.

Sampling with FiTv2-XL/2 for $160\times 320$ images:

python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_xl.yaml --ckpt checkpoints/fitv2_xl.safetensors --image-height 160 --image-width 320 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32  --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple

Sampling with FiTv2-XL/2 for $320\times 320$ images:

python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_xl.yaml --ckpt checkpoints/fitv2_xl.safetensors --image-height 320 --image-width 320 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32  --sampler-mode ODE --ori-max-pe-len 16 --interpolation ntkpro2 --decouple

Note that NUM_NODE, NUM_GPU and MASTER_PORT need to be specified.

High-resolution Sampling

For high-resolution image generation, we use images with $H\times W \leqslant 512\times512$. Our FiTv2-XL/2 model is fine-tuned with a batch size of 256 for 400K steps, while FiTv2-3B/2 is fine-tuned with a batch size of 256 for 200K steps.

The high-resolution fine-tuned FiT models can be downloaded directly from Hugging Face:

| FiT Model | Checkpoint | FID-512x512 | FID-320x640 | Model Size | GFLOPs |
| --- | --- | --- | --- | --- | --- |
| FiTv2-HR-XL/2 | CKPT | 2.90 | 4.87 | 671M | 147 |
| FiTv2-HR-3B/2 | CKPT | 2.41 | 4.54 | 3B | 653 |

Downloading

Downloading via wget:

mkdir checkpoints

wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_XL_HRFT/model_ema.safetensors?download=true" -O checkpoints/fitv2_hr_xl.safetensors

wget -c "https://huggingface.co/InfImagine/FiTv2/resolve/main/FiTv2_3B_HRFT/model_ema.safetensors?download=true" -O checkpoints/fitv2_hr_3B.safetensors

Sampling 512x512 Images

Sampling with FiTv2-HR-XL/2 for $512\times 512$ Images:

python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_hr_xl.yaml --ckpt checkpoints/fitv2_hr_xl.safetensors --image-height 512 --image-width 512 --num-sampling-steps 250 --cfg-scale 1.65 --global-seed 0 --per-proc-batch-size 32  --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple

Sampling with FiTv2-HR-3B/2 for $512\times 512$ Images:

python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_hr_3B.yaml --ckpt checkpoints/fitv2_hr_3B.safetensors --image-height 512 --image-width 512 --num-sampling-steps 250 --cfg-scale 1.5 --global-seed 0 --per-proc-batch-size 32  --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple

Note that NUM_NODE, NUM_GPU and MASTER_PORT need to be specified.

Sampling Images with arbitrary resolutions

Sampling with FiTv2-HR-XL/2 for $320\times 640$ images:

python -m torch.distributed.run --nnodes=${NUM_NODE} --nproc_per_node=${NUM_GPU} --rdzv_endpoint localhost:$MASTER_PORT sample_fitv2_ddp.py --num-fid-samples 50000 --cfgdir configs/fitv2/config_fitv2_hr_xl.yaml --ckpt checkpoints/fitv2_hr_xl.safetensors --image-height 320 --image-width 640 --num-sampling-steps 250 --cfg-scale 1.65 --global-seed 0 --per-proc-batch-size 32  --sampler-mode ODE --ori-max-pe-len 16 --interpolation dynntk --decouple

Note that NUM_NODE, NUM_GPU and MASTER_PORT need to be specified.

Evaluations

The sampling script produces a folder of samples as well as a .npz file, which can be used directly with ADM's TensorFlow evaluation suite to compute FID, Inception Score, and other metrics.
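
For example, with evaluator.py and the pre-computed ImageNet 256x256 reference batch from the openai/guided-diffusion evaluations directory (the file paths below are illustrative):

python evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/generated_samples.npz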

Flexible ImageNet Latent Datasets

We use SD-VAE-FT-EMA to encode images into latent codes.

In line with our flexible training pipeline, which handles images of arbitrary resolutions and aspect ratios, we preprocess the ImageNet-1k dataset according to the original height and width of each image. Conventionally, we set the patch size $p$ to $2$, and the downsampling scale $s$ of the VAE encoder is $8$, so an image with $H=256$ and $W=256$ yields $\frac{H\times W}{p^2\times s^2} = 256$ tokens.
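
As a quick check of this token-count formula (a minimal sketch; the function name is illustrative):

def num_tokens(H, W, p=2, s=8):
    # tokens = (H * W) / (p^2 * s^2); each token covers a (p*s) x (p*s) pixel region
    return (H * W) // (p * p * s * s)

print(num_tokens(256, 256))  # 256
print(num_tokens(160, 320))  # 200
print(num_tokens(512, 512))  # 1024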

For pre-training, we set the maximum token length to $256$, which corresponds to an image resolution of $S=H=W=256$. For high-resolution fine-tuning, the maximum token length is $1024$, which corresponds to $S=H=W=512$. Given an input image $I\in \mathbb{R}^{3\times H \times W}$ and a target resolution size $S=256/512$, the preprocessing is:

if H > S and W > S:
  # Large images are saved twice: once resized (full content, aspect ratio preserved)
  img_resize = Resize(I)
  latent_resize = VAE_Encode(img_resize)
  save(latent_resize)
  # and once resized then center-cropped to S x S
  img_crop = CenterCrop(Resize(I))
  latent_crop = VAE_Encode(img_crop)
  save(latent_crop)
else:
  # Small images are only resized and encoded once
  img_resize = Resize(I)
  latent_resize = VAE_Encode(img_resize)
  save(latent_resize)
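
Below is a minimal runnable sketch of this preprocessing for a single image, assuming Resize preserves the aspect ratio so that $H\times W\leqslant S\times S$ and snaps sizes to multiples of $p\times s$; it loads the public stabilityai/sd-vae-ft-ema weights via diffusers, and all function and variable names are illustrative rather than the repo's actual API.

import math
import torch
from PIL import Image
from torchvision import transforms
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval()

def encode_flexible(path, S=256, p=2, s=8):
    img = Image.open(path).convert("RGB")
    w, h = img.size
    # Downscale (never upscale) so that H * W <= S * S, i.e. at most (S / (p*s))^2 tokens
    scale = min(1.0, math.sqrt((S * S) / (h * w)))
    # Snap the new size to multiples of p * s so the latent patchifies evenly
    unit = p * s
    new_h = max(unit, int(h * scale) // unit * unit)
    new_w = max(unit, int(w * scale) // unit * unit)
    img = img.resize((new_w, new_h), Image.BICUBIC)
    x = transforms.ToTensor()(img) * 2.0 - 1.0  # scale pixels to [-1, 1], shape (3, H, W)
    with torch.no_grad():
        latent = vae.encode(x.unsqueeze(0)).latent_dist.sample() * 0.18215  # standard SD latent scale
    return latent  # shape (1, 4, H/s, W/s)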

Dataset for Pretraining

All the image latent codes with maximum token length $256$ can be downloaded from here.

bash tools/download_in1k_latents_256.sh

Dataset for High-resolution Fine-tuning

All the image latent codes with maximum token length $1024$ can be downloaded from here.

bash tools/download_in1k_latents_1024.sh

Dataset Architecture

Train

You need to determine the number of nodes and GPUs for your training.

Train FiT and FiTv2 models:

bash tools/train_fit_xl.sh

bash tools/train_fitv2_xl.sh

bash tools/train_fitv2_3B.sh

High-resolution Fine-tuning:

bash tools/train_fitv2_hr_xl.sh

bash tools/train_fitv2_hr_3B.sh

BibTeX

@article{Lu2024FiT,
  title={FiT: Flexible Vision Transformer for Diffusion Model},
  author={Lu, Zeyu and Wang, Zidong and Huang, Di and Wu, Chengyue and Liu, Xihui and Ouyang, Wanli and Bai, Lei},
  journal={arXiv preprint arXiv:2402.12376},
  year={2024}
}
@article{wang2024fitv2,
  title={FiTv2: Scalable and Improved Flexible Vision Transformer for Diffusion Model},
  author={Wang, ZiDong and Lu, Zeyu and Huang, Di and Zhou, Cai and Ouyang, Wanli and others},
  journal={arXiv preprint arXiv:2410.13925},
  year={2024}
}