
Reinforcement Learning for Consistency Models (RLCM)

This is the official implementation of the paper RL for Consistency Models: Faster Reward Guided Text-to-Image Generation. We support low rank adaptation (LoRA) for finetuning a latent consistency model (LCM). Much of this code was inspired by the repository for Denoising Diffusion Policy Optimization (DDPO).

UPDATE: Now with REBEL support! Check out the REBEL paper for more information, or simply enable it via training.algorithm in config.yaml. You may have to increase the learning rate (and decrease the train batch size depending on your GPU), but otherwise all hyperparameters should stay the same.
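
If the config overrides work like the task override shown under RLCM Training, REBEL can likely also be switched on from the command line. The value name below is an assumption, so check the training.algorithm entry in config.yaml for the exact spelling:

accelerate launch main.py training.algorithm=rebel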


Installation

To install the required dependencies, clone this repository and install the package with pip:

git clone https://github.com/Owen-Oertell/rlcm.git
cd rlcm
pip install -e . 

You must have Python >= 3.10 installed.

RLCM Training

RLCM is run by navigating into the scripts folder and then running the main.py file with accelerate. By default, we use the compression task, but other tasks can be used.

To run the aesthetic task, for example, use the following command:

accelerate launch main.py task=aesthetic

There are four tasks available: prompt_image_alignment, aesthetic, compression, and incompression. For more discussion about the tasks, please see the paper.

RLCM Inference

We also provide a sample inference script for use once you have saved your models to disk. This script is located in the scripts folder and is called inference.py. You can run it with the following command (after editing it to point to your saved model and updating it if you changed the config):

python inference.py
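
For reference, below is a minimal sketch of loading a LoRA-finetuned latent consistency model for sampling with the diffusers library. It is not a copy of inference.py: the base model name, checkpoint path, and prompt are placeholders, and a reasonably recent diffusers release is assumed.

import torch
from diffusers import DiffusionPipeline

# Placeholder base model and LoRA path -- substitute your own.
pipe = DiffusionPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/saved/lora")  # hypothetical checkpoint path

# Latent consistency models only need a few sampling steps.
image = pipe(prompt="a photo of a corgi", num_inference_steps=8, guidance_scale=8.0).images[0]
image.save("sample.png")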

Summary of Hyperparameters

Below is a summary of the hyperparameters that can be used for training. Each task has its own hyperparameters. The defaults are given in lcm_rl_pytorch/configs and can be overridden by passing them as arguments to the main.py script.
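
For example, an override might look like the following; the hyperparameter key here is hypothetical, so match it to whatever appears in the config files:

accelerate launch main.py task=aesthetic train.learning_rate=1e-4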

Reproducing Results

We reproduce our results from the paper below. Please see the appendix of the paper for full details of the hyperparameters and the number of GPUs used. At a high level, we used 4 RTX A6000 GPUs for each of the tasks, except that the prompt image alignment task was run with 3 GPUs for training and 1 GPU for the server (from Kevin's repo; make sure to use the 13B-parameter version of LLaVA, otherwise we've experienced empty outputs).
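
As a concrete sketch, launching the prompt image alignment run on three GPUs might look like the following, assuming accelerate's standard --num_processes flag (the reward server is started separately on the remaining GPU):

accelerate launch --num_processes 3 main.py task=prompt_image_alignment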


Plots of performance by runtime, measured in GPU hours. We report the runtime on four NVIDIA RTX A6000 GPUs across three random seeds and plot the mean and standard deviation. We observe that in all tasks RLCM noticeably reduces the training time while achieving comparable or better reward score performance.


Training curves for RLCM and DDPO by number of reward queries on compressibility, incompressibility, aesthetic, and prompt image alignment. We run three random seeds for each algorithm and plot the mean and standard deviation across those seeds. RLCM generally achieves comparable or better reward optimization performance across these tasks.

Citation

@misc{oertell2024rl,
      title={RL for Consistency Models: Faster Reward Guided Text-to-Image Generation}, 
      author={Owen Oertell and Jonathan D. Chang and Yiyi Zhang and Kianté Brantley and Wen Sun},
      year={2024},
      eprint={2404.03673},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}