Simple at-scale training of Stable Diffusion

Purpose

Our goal for this repo is two-fold:

  1. Provide a transparent, simple implementation that supports large-scale Stable Diffusion training for research purposes
  2. Integrate with existing diffusion APIs

We accomplish this through minimal scaffolding code and use of the diffusers repo (github.com/huggingface/diffusers). In the future, we will move simple elements of the diffusers repo (e.g. the noise scheduler) into this one for better clarity and easier extensibility. We also support the WebDataset format (github.com/webdataset/webdataset) to allow for training on large-scale datasets.

Usage

Requirements

First install the requirements under Python 3.9 in a virtualenv:

pip install -r requirements.txt

You might need to install xformers directly from their GitHub repo (github.com/facebookresearch/xformers). Further, we currently support UNet from-scratch training only. The rest of the components must be pretrained. You can download a compatible model using the following command:

python scripts/download_diffusers.py -o out_pipeline_path -m model_id

You can get various model_ids from the runwayml, stabilityai, and CompVis organizations on Hugging Face. A useful one is "stabilityai/stable-diffusion-2-base".

Finally, you will need to set up your dataset in the WebDataset format with image (jpg, png, webp) and text (txt) fields corresponding to images and captions respectively. You can also include a json field if you want to apply metadata filters (described below). See the img2dataset project at https://github.com/rom1504/img2dataset for details on how to do this for large-scale datasets.
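The expected shard layout can be sketched with just the standard library. Files that share a key (here the hypothetical `000000001`) are grouped into one sample by WebDataset readers; real shards are normally produced with img2dataset rather than by hand:

```python
import io
import json
import tarfile

# Build a tiny in-memory shard with the jpg/txt/json fields described above.
# The key and payload bytes are placeholders for illustration only.
shard = io.BytesIO()
with tarfile.open(fileobj=shard, mode="w") as tar:
    key = "000000001"
    fields = [
        (".jpg", b"<jpeg bytes>"),       # image field
        (".txt", b"a photo of a dog"),   # caption field
        (".json", json.dumps({"original_width": 640,
                              "pwatermark": 0.1}).encode("utf-8")),  # optional metadata
    ]
    for suffix, payload in fields:
        info = tarfile.TarInfo(name=key + suffix)
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))

# Reading the member names back shows the three fields of the one sample.
shard.seek(0)
with tarfile.open(fileobj=shard, mode="r") as tar:
    member_names = sorted(tar.getnames())
```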

Configuration

We use OmegaConf to handle commandline configuration. See configs/template_config.yaml for a template version, reproduced here with explanations.

wandb:
    api_key: ???
    # entity: ??? # necessary for a project group

system:
    gradient_accumulation: 1
    batch_size: 32
    workers: 6
    dist_backend: ${distributed.dist_backend}
    dist_url: ${distributed.dist_url}

distributed:
    dist_backend: 'nccl'
    dist_url: 'env://'

experiment:
    log_dir: ???
    name: ???
    project: ??? 
    num_examples_to_see: ???

    save_only_most_recent: True
    save_every: 2000
    requeue: True

optimizer:
    name: adamw
    params:
        learning_rate: 0.0001
        beta1: 0.9
        beta2: 0.98 
        weight_decay: 0.01
        epsilon: 0.00000001

model:
    vae:
        pretrained: ???

    text_encoder:
        pretrained: ???

    tokenizer:
        pretrained: ???
    
    noise_scheduler_training:
        pretrained: ???

    noise_scheduler_inference:
        pretrained: ???

    unet:
        target: UNet2DConditionModel
        params:
            act_fn: "silu"
            attention_head_dim: 8
            block_out_channels: [320, 640, 1280, 1280]
            center_input_sample: False
            cross_attention_dim: 768
            down_block_types: ["CrossAttnDownBlock2D","CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
            downsample_padding: 1
            flip_sin_to_cos: true
            freq_shift: 0
            in_channels: 4
            layers_per_block: 2
            mid_block_scale_factor: 1
            norm_eps: 1e-05
            norm_num_groups: 32
            out_channels: 4
            sample_size: 32
            up_block_types: [
                "UpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D"
            ]
        
    use_ema: False
    mixed_precision: bf16
    gradient_checkpointing: True
    xformers: True

dataset:
    type: WebDataset
    params: 
        path: ???
        batch_size: ${system.batch_size}
        workers: ${system.workers}
        num_examples_to_see: ${experiment.num_examples_to_see}
        resolution: 512

        filters: # optional metadata filters
            # original_width: ">=512"
            # original_height: ">=512"
            pwatermark: "<=0.5"


lr_scheduler:
    scheduler: "ConstantWithWarmup"
    params:
        learning_rate: ${optimizer.params.learning_rate}
        warmup_length: 500

Arguments with ??? are required.
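The metadata filters in the config above compare a json field against a threshold, e.g. pwatermark: "<=0.5". The repo's exact parsing code isn't shown here; below is a minimal sketch of how such filter strings could be evaluated (the helper names are hypothetical):

```python
import operator
import re

# Map comparison prefixes to operators; longer prefixes are tried first.
_OPS = {"<=": operator.le, ">=": operator.ge, "==": operator.eq,
        "<": operator.lt, ">": operator.gt}

def parse_filter(spec):
    # Hypothetical helper: turn a string like "<=0.5" into a predicate.
    match = re.match(r"(<=|>=|==|<|>)\s*(-?\d+(?:\.\d+)?)$", spec)
    if match is None:
        raise ValueError(f"unrecognized filter: {spec!r}")
    op, threshold = _OPS[match.group(1)], float(match.group(2))
    return lambda value: op(value, threshold)

def passes_filters(metadata, filters):
    # A sample is kept only if every named field satisfies its predicate.
    return all(parse_filter(spec)(metadata[field])
               for field, spec in filters.items())

filters = {"pwatermark": "<=0.5", "original_width": ">=512"}
keep = passes_filters({"pwatermark": 0.1, "original_width": 640}, filters)
drop = passes_filters({"pwatermark": 0.9, "original_width": 640}, filters)
```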

Running a large-scale experiment

We use SLURM to support multi-node training. Below is a general template for an SBATCH script. This will likely need to be modified to work on your individual cluster. The flags in configs/batch_files/multinode-template.sbatch are necessary for running on an AWS SLURM cluster.

#!/bin/bash
#SBATCH --job-name=Job_Name
#SBATCH --partition=Partition_Name
#SBATCH --account=laion
#SBATCH --nodes=16
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=6
#SBATCH --time=96:00:00
#SBATCH --chdir=/path/to/sd-training
#SBATCH --output=/stdout/path
#SBATCH --error=/stderr/path
#SBATCH --requeue
#SBATCH --exclusive


export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=12802 # This is an arbitrary choice, change if this port is taken

srun --cpu_bind=v --accel-bind=gn \
    python -u train.py \
    config=path/to/yaml/config

--ntasks-per-node should equal the number of GPUs per node.

With OmegaConf, commandline overrides use dot notation. For example, to override the dataset path you would run python -u train.py config=path/to/config dataset.params.path=path/to/dataset. If you're using the WebDataset type, this path also accepts pipes to AWS sharded datasets.
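OmegaConf handles these overrides for you; the sketch below just illustrates the idea of a dot-notation key updating a nested config (plain dicts stand in for the real config object, and the pipe-style path is only an example):

```python
# Minimal sketch: walk the dotted path and set the leaf value.
def apply_override(config, dotted_key, value):
    *parents, leaf = dotted_key.split(".")
    node = config
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value

config = {"dataset": {"params": {"path": None, "resolution": 512}}}
apply_override(config, "dataset.params.path",
               "pipe:aws s3 cp s3://bucket/shard-{00000..00999}.tar -")
```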

Local Fine-tuning example

This requires at least one GPU with >=24 GB of VRAM (e.g. a 3090).

Below is an example config:

wandb:
    api_key: ???

system:
    gradient_accumulation: 1
    batch_size: 2
    workers: 6
    dist_backend: ${distributed.dist_backend}
    dist_url: ${distributed.dist_url}

distributed:
    dist_backend: 'nccl'
    dist_url: 'env://'

experiment:
    log_dir: logs
    name: "finetuning"
    project: "diffusion"

    num_examples_to_see: ???

    eval_caption_file: ???
    num_eval_images: 1000
    save_every: 1000
    requeue: True

optimizer:
    name: adamw
    params:
        learning_rate: 0.0001
        beta1: 0.9
        beta2: 0.98 # changed from initial sd value for training stability
        weight_decay: 0.01
        epsilon: 0.00000001

model:
    pretrained: ???
    use_ema: False
    mixed_precision: bf16
    gradient_checkpointing: True
    xformers: True

dataset:
    type: WebDataset
    params: 
        path: ???
        batch_size: ${system.batch_size}
        workers: ${system.workers}
        num_examples_to_see: ${experiment.num_examples_to_see}
        resolution: 512

lr_scheduler:
    scheduler: "ConstantWithWarmup"
    params:
        learning_rate: ${optimizer.params.learning_rate}
        warmup_length: 500

Fill in every field marked with ??? (these variables are explained in the previous section) and save the config. Next, run the experiment using torchrun (PyTorch's distributed launcher), as below:

torchrun --nproc_per_node "# of gpus" -m train config=path/to/config

We currently support datasets in the WebDataset or Parquet formats. These datasets must have text (txt) and image (jpg, png, webp) fields. We also support classification datasets, where the text field is provided by a class index (cls). In this case you must also provide dataset.params.idx_to_prompt, a json mapping from class index to prompt. See scripts/metadata/imagenet_idx_to_prompt.json for an example. In the future we will also support the CSV dataset format.
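The idx_to_prompt file is a plain json object keyed by class index. A hypothetical two-class miniature (the real ImageNet version is the file referenced above) and the lookup it implies:

```python
import json

# Hypothetical miniature idx_to_prompt file contents; keys are string
# class indices, values are the prompts substituted for the cls field.
idx_to_prompt_json = '{"0": "a photo of a tench", "1": "a photo of a goldfish"}'
idx_to_prompt = json.loads(idx_to_prompt_json)

def prompt_for_class(cls):
    # cls fields in the dataset are integers; json keys are strings.
    return idx_to_prompt[str(cls)]

prompt = prompt_for_class(0)
```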

Inference

Given our integration with the diffusers library, you can use their API for generation with any saved pipeline. Your latest model will be saved at experiment.log_dir / experiment.name / current_pipeline.

Example inference code snippet:

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch

pt_path = "path/to/pretrained/pipeline"

# Swap in the Euler scheduler saved with the pipeline
scheduler = EulerDiscreteScheduler.from_pretrained(pt_path, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(pt_path, scheduler=scheduler, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

image.save("output.png")

Examples

Below are some random (non-cherry picked) examples from training from scratch on ImageNet. See configs/templates/imagenet.yaml for the exact setup of this experiment.

ImageNet Examples

More examples from a large-scale training run to come. See configs/templates for stage 1 and stage 2 of training SD v2 base. We also include a script scripts/generate_examples.py to help for large-scale generation of images from a prompt list.

Performance and System Requirements

Mixed precision currently only supports bfloat16, which requires NVIDIA 30xx-series GPUs or A40/A100 GPUs. With all optimizations enabled (gradient checkpointing, mixed precision, and xformers) and EMA, we achieve the following throughputs (including distributed-training overhead) on 80GB A100s through an AWS SLURM cluster:

Note that with aggressive WebDataset filtering, throughput decreases due to the data-loading bottleneck. Also note that GPU memory is not saturated (~60% usage) in either of these benchmarks.

We use the following architecture parameters for this performance benchmark:

  unet:
      target: diffusers.UNet2DConditionModel
      params:
          act_fn: "silu"
          attention_head_dim: 8
          block_out_channels: [320, 640, 1280, 1280]
          center_input_sample: False
          cross_attention_dim: 768
          down_block_types: ["CrossAttnDownBlock2D","CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
          downsample_padding: 1
          flip_sin_to_cos: true
          freq_shift: 0
          in_channels: 4
          layers_per_block: 2
          mid_block_scale_factor: 1
          norm_eps: 1e-05
          norm_num_groups: 32
          out_channels: 4
          sample_size: 32
          up_block_types: [
              "UpBlock2D",
              "CrossAttnUpBlock2D",
              "CrossAttnUpBlock2D",
              "CrossAttnUpBlock2D"
          ]

Recommendations for Large-Scale Training

Future features

In this order:

Contributors

The current version of the repo was authored by Vivek Ramanujan. Thank you to Robin Rombach, Romain Beaumont, and Jon Hayase for feedback and assistance. If you would like to contribute, open a PR or issue!

Acknowledgments

Thank you to Stability AI for providing the resources to test this code. A further thanks to the excellent open-source repos this one depends on, including diffusers, xformers, webdataset, and of course PyTorch.