Stepwise Diffusion Policy Optimization (SDPO)

This is a PyTorch implementation of Stepwise Diffusion Policy Optimization (SDPO) from our paper Aligning Few-Step Diffusion Models with Dense Reward Difference Learning.

Aligning text-to-image diffusion models with downstream objectives (e.g., aesthetic quality or user preferences) is essential for their practical applications. However, standard alignment methods often struggle with step generalization when directly applied to few-step diffusion models, leading to inconsistent performance across different denoising step scenarios. To address this, we introduce SDPO, which facilitates stepwise optimization of few-step diffusion models through dense reward difference learning, consistently exhibiting superior performance in reward-based alignment across all sampling steps.
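As a rough schematic of this idea (a simplified sketch, not the paper's exact objective; the per-step log-probabilities and the plain reward-difference weighting below are assumptions), dense reward difference learning can be pictured as a pairwise loss applied at every denoising step rather than only at the final image:

# Schematic sketch of a dense (per-step) reward-difference loss; illustration only
import torch

def dense_reward_difference_loss(logp_a, logp_b, reward_a, reward_b):
    # logp_a, logp_b: per-step log-probabilities of two sampled denoising
    # trajectories under the current policy, shape [num_steps]
    # reward_a, reward_b: scalar rewards of the corresponding final images
    reward_diff = reward_a - reward_b
    # Weight every denoising step by the reward difference, so the
    # higher-reward trajectory becomes more likely at each step.
    per_step_loss = -reward_diff * (logp_a - logp_b)
    return per_step_loss.mean()

# Toy usage with random placeholders (4 denoising steps)
logp_a, logp_b = torch.randn(4), torch.randn(4)
loss = dense_reward_difference_loss(logp_a, logp_b, torch.tensor(0.8), torch.tensor(0.3))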

Figure: SDPO overview

Figure: reward curves

Installation

# Create a new conda environment
conda create -n sdpo python=3.10.12 -y

# Activate the newly created conda environment
conda activate sdpo

# Navigate to the project’s root directory (replace with the actual path)
cd /path/to/project

# Install the project dependencies
pip install -e .

Usage

We use accelerate to enable distributed training. Before running the code, ensure accelerate is properly configured for your system:

accelerate config

Use the following commands to run SDPO with different reward functions:
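As an illustrative sketch only (the training script path below is a placeholder, and selecting a particular reward function is assumed to be done through the configuration files), a run would typically be launched as:

# Illustrative placeholder: the actual script name and the way a reward
# function is selected may differ in this repository
accelerate launch scripts/train.py --config config/config_sdpo.py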

For detailed explanations of the hyperparameters, please refer to the two configuration files: the base configuration and config/config_sdpo.py.

These files are pre-configured for training on 4 GPUs, each with at least 24GB of memory. If a hyperparameter is defined in both configuration files, the value in config/config_sdpo.py will take precedence.
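As a purely hypothetical sketch of this precedence (the module and field names below are assumptions, not the repository's actual configuration layout), config/config_sdpo.py would typically import the base configuration and override individual fields:

# Hypothetical sketch of config/config_sdpo.py; module and field names are assumptions
from config import base

def get_config():
    config = base.get_config()  # start from the base configuration
    # Values set here override the base configuration, i.e., they take precedence.
    config.num_gpus = 4
    config.sample.num_steps = 4
    return config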

Citation

If you find this work useful in your research, please consider citing:

@article{zhang2024sdpo,
  title={Aligning Few-Step Diffusion Models with Dense Reward Difference Learning},
  author={Ziyi Zhang and Li Shen and Sen Zhang and Deheng Ye and Yong Luo and Miaojing Shi and Bo Du and Dacheng Tao},
  journal={arXiv preprint arXiv:2411.11727},
  year={2024}
}

Acknowledgement