[NeurIPS 2024] U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers
<p align="left"> <a href="https://arxiv.org/abs/2405.02730" alt="arXiv"> <img src="https://img.shields.io/badge/arXiv-2405.02730-b31b1b.svg?style=flat" /></a> <a href="https://huggingface.co/yuchuantian/U-DiT/tree/main" alt="Hugging Face Models"> <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue" /></a> <a href="https://www.modelscope.cn/models/YuchuanTian/U-DiT/files" alt="ModelScope Models"> <img src="https://img.shields.io/badge/ModelScope-Models-blue" /></a> <a href="https://colab.research.google.com/drive/17ZimD7GdK2ZZHRg52_I9PNxDTs0LKd20?usp=sharing" alt="Colab Demo"> <img src="https://colab.research.google.com/assets/colab-badge.svg" /></a> </p>

This is the official implementation of "U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers".
9/30/2024: U-DiT is cited by Playground V3!
9/26/2024: U-DiT is accepted to NeurIPS 2024! See you in Vancouver!
Outline
- In this work, we rethink the question: could a U-Net architecture boost DiTs?
- Self-attention on downsampled tokens cuts the attention cost by ~3/4 while improving U-Net performance (see the sketch after this list).
- We develop a series of powerful U-DiTs.
- U-DiT-B can outcompete DiT-XL/2 with only 1/6 of its FLOPs.
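For intuition, here is a minimal sketch of self-attention on downsampled tokens. The pixel-unshuffle grouping, class name, and hyperparameters are illustrative assumptions; the actual U-DiT blocks (see the model code) may differ. The point it demonstrates: with N = H·W tokens, full attention scales as N², whereas attending within four 2x-downsampled sub-maps scales as 4·(N/4)² = N²/4, i.e. roughly a 3/4 reduction.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsampledSelfAttention(nn.Module):
    """Illustrative sketch: attend within four 2x-downsampled token sub-maps.

    Full attention over N = H*W tokens costs O(N^2); attending within each of the
    four sub-maps (N/4 tokens each) costs 4 * O((N/4)^2) = O(N^2) / 4.
    """

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) token map with even H and W
        B, C, H, W = x.shape
        # Pixel-unshuffle into 4 sub-maps of size (H/2, W/2): (B, 4, C, H/2, W/2)
        sub = F.pixel_unshuffle(x, 2).reshape(B, C, 4, H // 2, W // 2).permute(0, 2, 1, 3, 4)
        # Flatten each sub-map to a token sequence and attend within it
        tokens = sub.reshape(B * 4, C, (H // 2) * (W // 2)).transpose(1, 2)  # (B*4, N/4, C)
        out, _ = self.attn(tokens, tokens, tokens)
        # Restore the spatial layout and merge the sub-maps back via pixel-shuffle
        out = out.transpose(1, 2).reshape(B, 4, C, H // 2, W // 2).permute(0, 2, 1, 3, 4)
        return F.pixel_shuffle(out.reshape(B, C * 4, H // 2, W // 2), 2)


# Quick shape check
attn = DownsampledSelfAttention(dim=64, num_heads=4)
y = attn(torch.randn(2, 64, 16, 16))  # -> (2, 64, 16, 16)
```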
Preparation
Please run pip install -r requirements.txt to install the required packages.
(Optional) Please download the VAE from this link. The VAE can also be downloaded automatically.
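For the automatic route, DiT-style codebases typically fetch the Stable Diffusion VAE through diffusers. Below is a minimal sketch; the sd-vae-ft-ema checkpoint name is an assumption carried over from the DiT codebase this repo builds on.

```python
# Minimal sketch: download and cache the VAE automatically via diffusers.
# The checkpoint name follows the DiT codebase convention and is an assumption here.
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
```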
Training
Here we provide two ways to train a U-DiT model: 1. train on the original ImageNet dataset; 2. train on preprocessed VAE features (Recommended).
Training Data Preparation
Use the original ImageNet dataset together with the VAE encoder. First, download ImageNet and organize it as follows:
imagenet/
├── train/
│   ├── n01440764
│   │   ├── n01440764_10026.JPEG
│   │   ├── n01440764_10027.JPEG
│   │   └── ......
│   └── ......
└── val/
    ├── n01440764
    │   ├── ILSVRC2012_val_00000293.JPEG
    │   ├── ILSVRC2012_val_00002138.JPEG
    │   └── ......
    └── ......
Then run the following command:
torchrun --nnodes=1 --nproc_per_node=8 train.py --data-path={path to imagenet/train} --image-size=256 --model={model name} --epochs={iteration//5000} # fp32 Training
accelerate launch --mixed_precision fp16 train_accelerate.py --data-path {path to imagenet/train} --image-size=256 --model={model name} --epochs={iteration//5000} # fp16 Training
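When training on raw ImageNet, each image is encoded into a VAE latent on the fly. The sketch below illustrates that encoding step; the transform pipeline and the 0.18215 latent scale follow the DiT codebase convention and are assumptions here, so see train.py for the exact implementation.

```python
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from diffusers.models import AutoencoderKL

# Sketch of the on-the-fly latent encoding done when training on raw ImageNet.
# Transform details and the 0.18215 scale follow the DiT codebase and may differ in train.py.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # map pixels to [-1, 1]
])
dataset = ImageFolder("{path to imagenet/train}", transform=transform)

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval()
x, y = dataset[0]
with torch.no_grad():
    # A 3x256x256 image becomes a 4x32x32 latent that the diffusion model is trained on.
    latent = vae.encode(x.unsqueeze(0)).latent_dist.sample() * 0.18215
print(latent.shape)  # torch.Size([1, 4, 32, 32])
```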
Training Feature Preparation (RECOMMENDED)
Following Fast-DiT, we recommend loading VAE features directly for faster training. You don't need to download the enormous ImageNet dataset (>100 GB); instead, a much smaller VAE-feature dataset (~21 GB for ImageNet 256x256) is available on HuggingFace and ModelScope. Please follow these steps:
- Download imagenet_feature.tar.
- Unzip the tarball by running tar -xf imagenet_feature.tar. The extracted directory is organized as follows:

imagenet_feature/
├── imagenet256_features/ # VAE features
└── imagenet256_labels/ # labels

- Append the argument --feature-path={path to imagenet_feature} to the training command (a minimal feature-loading sketch follows this list).
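For reference, a feature-loading dataset can be as simple as the sketch below. It assumes per-sample .npy files with matching names in the two folders, following the Fast-DiT feature layout; the actual loader in this repo may differ.

```python
import os
import numpy as np
import torch
from torch.utils.data import Dataset


class FeatureDataset(Dataset):
    """Sketch: load precomputed VAE latents and class labels (Fast-DiT-style layout assumed)."""

    def __init__(self, feature_dir: str, label_dir: str):
        self.feature_dir = feature_dir
        self.label_dir = label_dir
        self.files = sorted(os.listdir(feature_dir))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        latent = np.load(os.path.join(self.feature_dir, self.files[idx]))
        label = np.load(os.path.join(self.label_dir, self.files[idx]))
        return torch.from_numpy(latent), torch.from_numpy(label)


# Usage (paths are placeholders):
# dataset = FeatureDataset("imagenet_feature/imagenet256_features",
#                          "imagenet_feature/imagenet256_labels")
```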
Inference
Weights Available
We have released our models on HuggingFace and ModelScope. Please feel free to download them!
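A quick way to fetch the released weights programmatically is sketched below using huggingface_hub; the repo id is taken from the badge above, but individual checkpoint filenames are not assumed here.

```python
# Sketch: download the released U-DiT weights from the Hugging Face model repo.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="yuchuantian/U-DiT")
print(local_dir)  # local cache directory containing the downloaded files
```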
Sampling
Run the following command for parallel sampling:
torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py --ckpt={path to checkpoint} --image-size=256 --model={model name} --cfg-scale={cfg scale}
After sampling, an .npz file that contains 50000 images is automatically generated.
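If you want to sanity-check the sample pack before evaluation, it can be inspected with NumPy. In the sketch below, "samples.npz" is a placeholder filename and the "arr_0" key follows the guided-diffusion convention, both assumptions.

```python
import numpy as np

# Sketch: inspect the generated sample pack before running the evaluator.
samples = np.load("samples.npz")["arr_0"]
print(samples.shape, samples.dtype)  # expected: (50000, 256, 256, 3) uint8 for 256x256 sampling
```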
Metric Evaluation
We borrow the FID evaluation code from here. Metrics, including FID, are computed from the .npz file. Before evaluation, make sure to download the reference batch for ImageNet 256x256. Then run the following command:
python evaluator.py {path to reference batch} {path to generated .npz}
Future work (Stay Tuned!)
- Training code for U-DiTs
- Model weights
- ImageNet features from VAE for faster training
- Colab demos
- Outcomes from longer training
BibTeX Citation
If you find this repo useful, please cite:
@misc{tian2024udits,
      title={U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers},
      author={Yuchuan Tian and Zhijun Tu and Hanting Chen and Jie Hu and Chao Xu and Yunhe Wang},
      year={2024},
      eprint={2405.02730},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Acknowledgement
We acknowledge the authors of the following repos:
https://github.com/facebookresearch/DiT (Codebase)
https://github.com/chuanyangjin/fast-DiT (FP16 training; Training on features)
https://github.com/openai/guided-diffusion (Metric evaluation)