# Dynamic Prompt Optimizing for Text-to-Image Generation (CVPR 2024)
<a href='https://arxiv.org/abs/2404.04095'><img src='https://img.shields.io/badge/ArXiv-2404.04095-red'></a>
## Abstract
<details><summary>CLICK for the full abstract</summary>
Text-to-image generative models, specifically those based on diffusion models like Imagen and Stable Diffusion, have made substantial advancements. Recently, there has been a surge of interest in the delicate refinement of text prompts. Users assign weights or alter the injection time steps of certain words in the text prompts to improve the quality of generated images. However, the success of fine-control prompts depends on the accuracy of the text prompts and the careful selection of weights and time steps, which requires significant manual intervention. To address this, we introduce the Prompt Auto-Editing (PAE) method. Besides refining the original prompts for image generation, we further employ an online reinforcement learning strategy to explore the weights and injection time steps of each word, leading to dynamic fine-control prompts. The reward function during training encourages the model to consider aesthetic score, semantic consistency, and user preferences. Experimental results demonstrate that our proposed method effectively improves the original prompts, generating visually more appealing images while maintaining semantic alignment.
</details>
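As a rough illustration of the reward design described above (not the repository's actual implementation; the scorer callables and weights here are hypothetical placeholders), the three signals could be combined as a weighted sum:

```python
# Illustrative sketch only: combine the three reward signals named in the
# abstract. The scorer callables and weights are hypothetical placeholders.
def total_reward(image, prompt,
                 aesthetic_score, semantic_consistency, user_preference,
                 w_aes=1.0, w_sem=1.0, w_pref=1.0):
    """Weighted sum of aesthetic, semantic-consistency, and preference scores."""
    return (w_aes * aesthetic_score(image)
            + w_sem * semantic_consistency(image, prompt)
            + w_pref * user_preference(image, prompt))
```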
## Setup Environment
```bash
conda create -n PAE python=3.8
conda activate PAE
# First, install PyTorch 1.12.1 (or later) and torchvision, then install the other dependencies.
pip install -r requirements.txt
```
`xformers` is recommended on A800 GPUs to save memory and running time.

Our environment is similar to the official minChatGPT setup; see that repository for more details.
## Weight Download
To use our model, please download the following weight file and follow the instructions for setup:
```bash
cd ckpt/PAE
bash download.sh
```
## Demo
```bash
# generate DF-prompts only
python demo_DF_prompt.py

# generate the images according to DF-prompts only
# ["a red horse on the yellow grass, anime style",
#  "a red horse on the yellow grass, [anime:0-1:1.5] style",
#  "a red horse on the yellow grass, detailed",
#  "a red horse on the yellow grass, [detailed:0.85-1:1]"]
python demo_Image.py
```
<img src="docs/DF_example.png" width="100%"/>
## Data
### Training Data
You can download the training data from the following link: Download

### Evaluation Data
The evaluation datasets are provided as NumPy files in the `data` directory of this repository. The following datasets are available:

- `COCO_test_1k.npy`: a test set from the COCO dataset, consisting of 1,000 prompts.
- `diffusiondb_test_1k.npy`: a test set from DiffusionDB, containing 1,000 test prompts.
- `lexica_test_1k.npy`: the Lexica test set, comprising 1,000 prompts for evaluation.
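Assuming each file stores a NumPy array of prompt strings, the sets can be inspected directly:

```python
import numpy as np

# Load one of the evaluation prompt sets; allow_pickle=True is needed if
# the prompts are stored as Python string objects.
prompts = np.load("data/COCO_test_1k.npy", allow_pickle=True)
print(len(prompts))  # expected: 1000
print(prompts[0])    # first evaluation prompt
```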
## Train
<img src="docs/method.png" width="100%"/>

### Stage 1
To execute the training script `train_stage_1.py`, use the following command-line arguments:

- `-n`: Experiment name.
- `-card`: Specify the GPU to use, in `cuda:{card}` format.
- `-b`: Batch size.
- `-t`: Number of iterations.

Example command:
```bash
python train_stage_1.py -n your_experiment_name -card 0 -b 64 -t 1e6
```
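For reference, the flags above suggest a command-line interface roughly like the following `argparse` sketch (an assumption about the script's interface, not its actual source):

```python
import argparse

# Hypothetical sketch of the Stage-1 CLI, matching the flags documented above.
parser = argparse.ArgumentParser(description="PAE Stage-1 training")
parser.add_argument("-n", type=str, required=True, help="experiment name")
parser.add_argument("-card", type=int, default=0, help="GPU index, used as cuda:{card}")
parser.add_argument("-b", type=int, default=64, help="batch size")
parser.add_argument("-t", type=float, default=1e6, help="number of iterations")
args = parser.parse_args()

device = f"cuda:{args.card}"  # e.g. "cuda:0"
```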
### Stage 2
To run the training script `train_stage_2.py`, use the following command-line options:

- `-n`: Experiment name.
- `-card`: Specifies the GPU card to use.
- `-e`: Number of epochs.
- `-b`: Batch size.
- `-a`: Path to the policy model to load, which is the model trained during the first stage.
- `-c`: Path to the value model to load, which is also the model trained during the first stage.

Output: the trained model is saved under the `runs` directory after the specified number of steps.

Example command:
```bash
python -u train_stage_2.py -n your_experiment_name --card 0 -e 500 -b 32 -a your_ckpt_path -c your_ckpt_path
```
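Since `-a` and `-c` both take the Stage-1 checkpoint, the two stages can be chained; here is a sketch, with `your_ckpt_path` left as a placeholder for whatever Stage 1 writes under `runs/`:

```python
import subprocess

# Run Stage 1, then feed its checkpoint to both the policy (-a) and
# value (-c) models in Stage 2. The checkpoint path is a placeholder.
subprocess.run(["python", "train_stage_1.py", "-n", "your_experiment_name",
                "-card", "0", "-b", "64", "-t", "1e6"], check=True)

stage1_ckpt = "your_ckpt_path"  # placeholder: checkpoint saved by Stage 1
subprocess.run(["python", "-u", "train_stage_2.py", "-n", "your_experiment_name",
                "--card", "0", "-e", "500", "-b", "32",
                "-a", stage1_ckpt, "-c", stage1_ckpt], check=True)
```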
## Evaluation
### CMMD (Rethinking FID: Towards a Better Evaluation Metric for Image Generation, official)

1. Install the dependencies:
```bash
pip install git+https://github.com/google-research/scenic.git
cd cmmd
pip install -r requirements.txt
```

2. Download the PAE and Promptist results and the COCO reference images, and put them in a directory named `save`:
```
your_project_root/
├── save/
│   ├── coco_PAE/
│   ├── coco_promptist/
│   ├── coco_image_data/
...
```

| Type | Name | Download Link |
|---|---|---|
| eval images | coco_PAE | Download |
| eval images | coco_promptist | Download |
| reference images | coco_image_data | Download |

3. Compute CMMD. Example commands:
```bash
# evaluate PAE
python -m cmmd.main save/coco_image_data save/coco_PAE --batch_size=32 --max_count=1000
# evaluate Promptist
python -m cmmd.main save/coco_image_data save/coco_promptist --batch_size=32 --max_count=1000
```
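The two comparisons can also be scripted in one pass, e.g.:

```python
import subprocess

# Run the CMMD comparison for each method's generated images against the
# COCO reference images, mirroring the two commands above.
for gen_dir in ["save/coco_PAE", "save/coco_promptist"]:
    subprocess.run(
        ["python", "-m", "cmmd.main", "save/coco_image_data", gen_dir,
         "--batch_size=32", "--max_count=1000"],
        check=True,
    )
```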
### Other metrics
To generate dynamic prompts and their corresponding images, and subsequently compute various metrics, use the script `evaluate.py`:
```bash
python evaluate.py
```

Command-line options:

- `--card`: Specifies the GPU card to use.
- `--data`: The dataset to use for evaluation.
- `--save`: The path to save the generated images.
- `--ckpt`: The path to the trained policy model checkpoint.

Example command:
```bash
python evaluate.py --card 0 --data "coco" --save "save/coco/" --ckpt "ckpt/PAE/actor_step3000.pt"
```
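To cover all three evaluation sets in one run, a simple wrapper could loop over the datasets; the `diffusiondb` and `lexica` keys below are assumptions based on the `.npy` file names (only `coco` appears explicitly above):

```python
import subprocess

# Evaluate each prompt set in turn. The "diffusiondb" and "lexica" dataset
# keys are assumed from the evaluation file names; only "coco" is documented.
for data in ["coco", "diffusiondb", "lexica"]:
    subprocess.run(
        ["python", "evaluate.py", "--card", "0", "--data", data,
         "--save", f"save/{data}/", "--ckpt", "ckpt/PAE/actor_step3000.pt"],
        check=True,
    )
```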
## Results
<img src="docs/PAE_result1.png" width="100%"/>

## 📍 Citation
```bibtex
@inproceedings{mo2024dynamic,
  title={Dynamic Prompt Optimizing for Text-to-Image Generation},
  author={Mo, Wenyi and Zhang, Tianyu and Bai, Yalong and Su, Bing and Wen, Ji-Rong and Yang, Qing},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={26627--26636},
  year={2024}
}
```