ZigMa: A DiT-style Zigzag Mamba Diffusion Model (ECCV 2024)

Oral Talk in ICML 2024 Workshop on Long Context Foundation Models (LCFM)

ECCV 2024

This repository represents the official implementation of the paper titled "ZigMa: A DiT-style Zigzag Mamba Diffusion Model (ECCV 2024)".

Vincent Tao Hu, Stefan Andreas Baumann, Ming Gui, Olga Grebenkova, Pingchuan Ma, Johannes Fischer, Björn Ommer

We present ZigMa, a scanning scheme that follows a zigzag pattern, considering both spatial continuity and parameter efficiency. We further adapt this scheme to video, separating the reasoning between spatial and temporal dimensions, thus achieving efficient parameter utilization. Our design allows for greater incorporation of inductive bias for non-1D data and improves parameter efficiency in diffusion models.
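To make the idea of a spatially continuous scan concrete, here is a minimal sketch of a single zigzag path over a grid of patch tokens, compared against a plain row-major sweep. It is an illustration only: the function names are hypothetical and not taken from this repository, and the model itself uses several scan variants (see `scan_type="zigzagN8"` in the Quick Demo) rather than this one path.

```python
def zigzag_order(height, width):
    """Return row-major token indices visited in a zigzag (boustrophedon) order."""
    order = []
    for row in range(height):
        # Even rows go left-to-right, odd rows right-to-left.
        cols = range(width) if row % 2 == 0 else range(width - 1, -1, -1)
        order.extend(row * width + col for col in cols)
    return order


def is_spatially_continuous(order, width):
    """True if every consecutive pair of tokens is adjacent on the 2D grid."""
    for a, b in zip(order, order[1:]):
        (ra, ca), (rb, cb) = divmod(a, width), divmod(b, width)
        if abs(ra - rb) + abs(ca - cb) != 1:
            return False
    return True


zigzag = zigzag_order(3, 4)
sweep = list(range(3 * 4))  # plain row-major sweep, for comparison
print(zigzag)  # [0, 1, 2, 3, 7, 6, 5, 4, 8, 9, 10, 11]
print(is_spatially_continuous(zigzag, 4), is_spatially_continuous(sweep, 4))  # True False
```

The row-major sweep jumps from the end of one row to the start of the next, while the zigzag path always moves to a neighbouring patch, which is the spatial continuity the abstract refers to.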

🎓 Citation

Please cite our paper:

@InProceedings{hu2024zigma,
      title={ZigMa: A DiT-style Zigzag Mamba Diffusion Model},
      author={Vincent Tao Hu and Stefan Andreas Baumann and Ming Gui and Olga Grebenkova and Pingchuan Ma and Johannes Fischer and Björn Ommer},
      booktitle = {ECCV},
      year={2024}
}

✅ Updates

Quick Demo

import torch
from model_zigma import ZigMa

img_dim = 32
in_channels = 3

model = ZigMa(
    in_channels=in_channels,
    embed_dim=640,
    depth=18,
    img_dim=img_dim,
    patch_size=1,
    has_text=True,
    d_context=768,
    n_context_token=77,
    device="cuda",
    scan_type="zigzagN8",
    use_pe=2,
)

x = torch.rand(10, in_channels, img_dim, img_dim).to("cuda")  # batch of 10 images
t = torch.rand(10).to("cuda")  # one timestep per sample
_context = torch.rand(10, 77, 768).to("cuda")  # text context: (batch, n_context_token, d_context)
o = model(x, t, y=_context)
print(o.shape)

Improved Training Performance

Compared with the original implementation, we add several training-speed and memory-saving features, including torch.compile and gradient checkpointing:

| torch.compile | gradient checkpointing | training speed | memory |
|---|---|---|---|
| ❌ | ❌ | 1.05 iters/sec | 18G |
| ❌ | ✔ | 0.93 steps/sec | 9G |
| ✔ | ❌ | 1.8 iters/sec | 18G |

torch.compile is used for the indexing operations: here and here
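The trade-off in the table above can be read off as ratios against the baseline run. The short sketch below does that arithmetic; it assumes that "iters/sec" and "steps/sec" denote the same unit, which the table does not state explicitly, and the variable names are illustrative only.

```python
# Measurements copied from the table above:
# (torch.compile, gradient checkpointing) -> (training speed, memory in GB).
# Assumption: "iters/sec" and "steps/sec" are the same unit.
runs = {
    (False, False): (1.05, 18),
    (False, True): (0.93, 9),
    (True, False): (1.8, 18),
}

base_speed, base_mem = runs[(False, False)]


def relative(config):
    """Speed and memory of a run as a fraction of the baseline run."""
    speed, mem = runs[config]
    return round(speed / base_speed, 2), round(mem / base_mem, 2)


print(relative((False, True)))  # (0.89, 0.5)
print(relative((True, False)))  # (1.71, 1.0)
```

Under that assumption, gradient checkpointing halves memory for roughly an 11% slowdown, while torch.compile gives about a 1.7x speedup at unchanged memory.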

🚀 Training

CelebaMM256

Sweep-2, 1GPU

accelerate launch  --num_processes 1 --num_machines 1  --mixed_precision fp16    train_acc.py  model=sweep2_b1  use_latent=1   data=celebamm256_uncond  ckpt_every=10_000 data.sample_fid_n=5_000 data.sample_fid_bs=4 data.sample_fid_every=10_000  data.batch_size=8   note=_ 

Zigzag-8, 1GPU

CUDA_VISIBLE_DEVICES=4 accelerate launch  --num_processes 1 --num_machines 1  --mixed_precision fp16  --main_process_ip 127.0.0.1 --main_process_port 8868  train_acc.py  model=zigzag8_b1  use_latent=1   data=celebamm256_uncond  ckpt_every=10_000 data.sample_fid_n=5_000 data.sample_fid_bs=4 data.sample_fid_every=10_000  data.batch_size=4   note=_ 

UCF101

Baseline, multi-GPU

CUDA_VISIBLE_DEVICES="0,1,2,3" accelerate launch  --num_processes 4 --num_machines 1 --multi_gpu --mixed_precision fp16  --main_process_ip 127.0.0.1 --main_process_port 8868  train_acc.py  model=3d_sweep2_b2  use_latent=1 data=ucf101  ckpt_every=10_000  data.sample_fid_n=20_0 data.sample_fid_bs=4 data.sample_fid_every=10_000  data.batch_size=4   note=_ 

Factorized 3D Zigzag: sst, multi-GPU

CUDA_VISIBLE_DEVICES="0,1,2,3" accelerate launch  --num_processes 4 --num_machines 1 --multi_gpu --mixed_precision fp16  --main_process_ip 127.0.0.1 --main_process_port 8868  train_acc.py  model=3d_zigzag8sst_b2  use_latent=1 data=ucf101  ckpt_every=10_000  data.sample_fid_n=20_0 data.sample_fid_bs=4 data.sample_fid_every=10_000  data.batch_size=4   note=_ 

🚀 Sampling

FacesHQ 1024

You can download the model directly from this repository, or from a Python script:

from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="taohu/zigma",
    filename="faceshq1024_0090000.pt",
    local_dir="./checkpoints",
)

Hugging Face model repo: taohu/zigma

| Dataset | Checkpoint | Model | data |
|---|---|---|---|
| faceshq1024 | faceshq1024_0090000.pt | model=s1024_zigzag8_b2_old | data=facehq_1024 |
| landscape1024 | landscape1024_0210000.pt | model=s1024_zigzag8_b2_old | data=landscapehq_1024 |
| Churches256 | churches256_0280000.pt | model=zigzag8_b1_pe2 | data=churches256 |
| Coco256 | zigzagN8_b1_pe2_coco14_bs48_0400000.pt | model=zigzag8_b1_pe2 | data=coco14 (31.0) |

1GPU sampling

CUDA_VISIBLE_DEVICES="2" accelerate launch --num_processes 1 --num_machines 1 sample_acc.py model=s1024_zigzag8_b2_old use_latent=1 data=facehq_1024 ckpt_every=10_000 data.sample_fid_n=5_000 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=8 sample_mode=ODE likelihood=0 num_fid_samples=5_000 sample_debug=0 ckpt=checkpoints/faceshq1024_0090000.pt

Sampled images are saved both to wandb (disable with use_wandb=False) and to the samples/ directory.

🛠️ Environment Preparation

cuda==11.8, python==3.11, torch==2.2.0, gcc==11.3 (for the SSM environment)

python=3.11 # required for torch.compile support for the time being: https://github.com/pytorch/pytorch/issues/120233#issuecomment-2041472137

conda create -n zigma python=3.11
conda activate zigma
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
conda install pytorch torchvision  pytorch-cuda=11.8 -c pytorch -c nvidia
pip install  torchdiffeq  matplotlib h5py timm diffusers accelerate loguru blobfile ml_collections wandb
pip install hydra-core opencv-python torch-fidelity webdataset einops pytorch_lightning
pip install torchmetrics --upgrade
pip install opencv-python causal-conv1d
cd dis_causal_conv1d && pip install -e . && cd ..
cd dis_mamba && pip install -e . && cd ..
pip install moviepy imageio #wandb.Video() need it
pip install  scikit-learn --upgrade 
pip install transformers==4.36.2
pip install numpy-hilbert-curve # (optional) for generating the hilbert path
pip install av    # (optional)  to use the ucf101 frame extracting
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers  #for FDD metrics

Installing Mamba can take considerable effort. If you run into problems, the issues in the Mamba repository may be very helpful.

Create the file ./config/wandb/default.yaml with the following content:

key: YOUR_WANDB_KEY
entity: YOUR_ENTITY
project: YOUR_PROJECT_NAME
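If you prefer not to type the key into the file by hand, the file can be generated from a script. The sketch below is a hypothetical helper, not part of this repository; it assumes the key is available in the `WANDB_API_KEY` environment variable and falls back to the placeholder otherwise. It writes into a temporary directory for demonstration.

```python
import os
import tempfile
from pathlib import Path


def write_wandb_config(root, key, entity, project):
    """Write config/wandb/default.yaml under `root` with the three fields listed above."""
    path = Path(root) / "config" / "wandb" / "default.yaml"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(f"key: {key}\nentity: {entity}\nproject: {project}\n")
    return path


# Demo in a temporary directory; inside the repository, pass root="." instead.
with tempfile.TemporaryDirectory() as tmp:
    cfg = write_wandb_config(
        tmp,
        key=os.environ.get("WANDB_API_KEY", "YOUR_WANDB_KEY"),
        entity="YOUR_ENTITY",
        project="YOUR_PROJECT_NAME",
    )
    print(cfg.read_text())
```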

Q&A

📷 Dataset Preparation

Due to privacy issues, we cannot share the dataset here. We use the MM-CelebA-HQ-Dataset from https://github.com/IIGROUP/MM-CelebA-HQ-Dataset, organized in the webdataset format to enable scalable multi-GPU training.
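As a rough illustration of what organizing data in the webdataset format involves: a shard is a plain tar archive in which files that share a basename form one sample. The sketch below packs image/caption pairs into one shard using only the standard library. The `.jpg`/`.txt` keys, the shard name, and the helper name are assumptions made for this example; the keys this repository actually expects are not specified here.

```python
import io
import tarfile
import tempfile
from pathlib import Path


def write_shard(shard_path, samples):
    """Pack (key, image_bytes, caption) triples into one WebDataset-style tar shard."""
    with tarfile.open(shard_path, "w") as tar:
        for key, image_bytes, caption in samples:
            # Files sharing the same basename (key) are grouped into one sample.
            for suffix, payload in (("jpg", image_bytes), ("txt", caption.encode("utf-8"))):
                info = tarfile.TarInfo(name=f"{key}.{suffix}")
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))


# Placeholder bytes stand in for real JPEG data.
samples = [(f"{i:06d}", b"fake-jpeg-bytes", f"caption {i}") for i in range(3)]
with tempfile.TemporaryDirectory() as tmp:
    shard = Path(tmp) / "shard-000000.tar"
    write_shard(shard, samples)
    with tarfile.open(shard) as tar:
        print(tar.getnames())
```

Splitting a dataset into many such shards lets each GPU worker stream its own subset sequentially, which is what makes the format convenient for multi-GPU training.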

Webdataset Format:

The datasets we use include:

🎫 License

This work is licensed under the Apache License, Version 2.0 (as defined in the LICENSE).

By downloading and using the code and model you agree to the terms in the LICENSE.
