Awesome
Energy-Based Cross-Attention
[NeurIPS 2023] This repository is the official implementation of Energy-Based Cross Attention for Bayesian Context Update in Text-to-Image Diffusion Models .<br> Geon Yeong Park*, Jeongsol Kim*, Beomsu Kim, Sang Wan Lee, Jong Chul Ye
Abstract
Despite the remarkable performance of text-to-image diffusion models in image generation tasks, recent studies have raised the issue that generated images sometimes cannot capture the intended semantic contents of the text prompts, which phenomenon is often called semantic misalignment. To address this, here we present a novel energy-based model (EBM) framework. Specifically, we first formulate EBMs of latent image representations and text embeddings in each cross-attention layer of the denoising autoencoder. Then, we obtain the gradient of the log posterior of context vectors, which can be updated and transferred to the subsequent cross-attention layer, thereby implicitly minimizing a nested hierarchy of energy functions. Our latent EBMs further allow zero-shot compositional generation as a linear combination of cross-attention outputs from different contexts. Using extensive experiments, we demonstrate that the proposed method is highly effective in handling various image generation tasks, including multi-concept generation, text-guided image inpainting, and real and synthetic image editing.
- Latest update: 1 April, 2024
Prerequisites
- python 3.9
- diffusers 0.16.1
- pytorch 2.0.1
- CUDA 11.7
It is okay to use lower version of CUDA with proper pytorch version. We highly recommend to use a conda environment as below.
Getting started
1. Clone the repository
git clone https://github.com/jeongsol-kim/energy-attention.git
cd energy_attention
2. Set environment
conda env create --name EBCA --file simple_env.yaml
conda activate EBCA
Main tasks
Overview
This repo follows the style of diffusers. Specifically, the core modules
are consists of modules/models
, modules/pipelines
, and modules/utils
.
-
models
: This is a self-contained folder forclass EnergyUNet2DConditionModel(UNet2DConditionModel)
which replaces cross-attention in
UNet2DConditionModel
with the proposed Energy-based Cross-attention (EBCA). -
pipelines
: Each pipeline corresponds to a specific task, e.g.energy_realedit_stable_diffusion.py
for real-image editing. The denoiser usesEnergyUNet2DConditionModel
as their neural architecture.Stable_diffusion_repaint.py
pipeline is included for inheritance and comes from diffusers/examples/community -
utils
:attention_hook.py
: contains hook functions for monitoring the progress of energy values.gamma_scheduler.py
: contains generalized schedulers for hyperparameters in BCU, CACAO, etc.
We provide four exemplary main scripts, i.e realedit_txt2img.py
. These scripts may share some common options as follow:
--gamma_attn, --gamma_norm
: $\gamma_{attn}, \gamma_{norm}$ in BCU.--gamma_tau
: $\tau$ in learning rate scheduling (either for both main and editorial contexts. Please see appendix C for more details).--alpha
: $\alpha_{s}$ in CACAO.--alpha_tau
is similar togamma_tau
. Usually both $\tau$ are set in same value.--debug
: The results and configurations are saved in weight & biases if you are not debugging. Set this option if you don't want w&b logging.
More details per each task are provided in below.
1. Real-image editing
Cat → Dog
python realedit_txt2img.py --gamma_attn 0. --gamma_norm 0. --img_file assets/samples/realedit/cat_1.jpg \
--editing_prompt dog cat --editing_direction 1 0 --alpha 0.75 0.65 --alpha_tau 0.5 0.5 --gamma_attn_compose 0.0006 0.0005 --gamma_norm_compose 0.0006 0.0005 --gamma_tau 0.5 0.5
Horse → Zebra
python realedit_txt2img.py --gamma_attn 0. --gamma_norm 0. --img_file assets/samples/realedit/horse_1.jpg \
--editing_prompt "zebra" "brown horse" --editing_direction 1 0 --alpha 0.6 0.5 --alpha_tau 0.3 0.3 --gamma_attn_compose 0.0005 0.0004 --gamma_norm_compose 0.0005 0.0004 --gamma_tau 0.3 0.3
AFHQ
python realedit_txt2img.py --img_file assets/samples/realedit/afhq_1.jpg --editing_prompt "a goat, a photography of a goat" \
--alpha 0.6 --alpha_tau 0.35 --gamma_attn_compose 0.001 --gamma_norm_compose 0.001 --gamma_tau 0.3 --editing_direction 1 --seed 0
- BCU is not applied to the main prompt for the sake of img-to-img translation (more discussions in appendix C).
--editing_direction
implies whether the--editing_prompt
, e.g. dog, cat, corresponds to conjunction (1) or negation (0) composition.- Recommend to set
--gamma_{attn, norm}_compose
in a scale of $1e-4$. - While the hyper-parameters, e.g. $\gamma, \alpha$, are set in common for every images, each hyper-parameter can be individually fine-tuned for a better performance. For example, smaller
--alpha_tau
leads to more drastic changes as CACAO starts to be applied from early time steps.
2. Synthetic-image editing
Stylization
python synedit_txt2img.py --gamma_attn 0. --gamma_norm 0. --prompt "a house at a lake" --editing_prompt "Water colors, Watercolor painting, Watercolor" \
--alpha 1.1 --alpha_tau 0.2 --seed 12 --gamma_attn_compose 0.0004 --gamma_norm_compose 0.0004 --gamma_tau 0.2
Editing
python synedit_txt2img.py --gamma_attn 0.01 --gamma_norm 0.01 --prompt "a castle next to a river" --seed 48 \\
--editing_prompt "monet painting, impression, sunrise" "boat on a river, boat" --editing_direction 1 1 \\
--alpha 1.3 1.3 --alpha_tau 0.2 0.2 --gamma_attn_compose 0. 0. --gamma_norm_compose 0. 0. --gamma_tau 0. 0.
--editing_direction
implies whether the--editing_prompt
, e.g. monet, boat, corresponds to conjunction (1) or negation (0) composition.- We refer some of the prompt examples from SEGA: Instructing Diffusion using Semantic Dimensions.
- While BCU is not applied to editorial prompts in this example, i.e.
--gamma_{attn, norm}_compose=0
, they can be also readily applied to editorial prompts as well. Here, BCU is applied to the main prompt in contrast to the real-image editing, i.e. img-to-img translation, to preserve the intended semantic contents of main prompts, a castle next to a river.
3. Multi-concept generation
python inference_txt2img.py --prompt "A lion with a crown" --gamma_attn 0.01 --gamma_norm 0.02 --seed 1
- We recommend to use $\gamma_{attn}, \gamma_{attn} \in {0.01, 0.02}$.
- Also, we recommend to use
--token_indices
and--token_upweight
for token-wise step size tuning, i.e. $\gamma_{attn, s}$, given a $s$-th token. For example,
python inference_txt2img.py --prompt "A cat wearing a shirt" --gamma_attn 0.01 --gamma_norm 0.02 --seed 1 --token_upweight 2.5 --token_indices 5
4. Text-guided inpainting
python inpaint_txt2img.py --gamma_attn 0.025 --gamma_norm 0.025 --prompt "teddy bear" --img_file assets/samples/inpaint/starry_night_512.png --mask_file assets/samples/inpaint/starry_night_512_mask.png
- We recommend to use $\gamma_{attn}, \gamma_{attn} \in {0.01, 0.025}$.
- Stable Inpaint is a default backbone model. It can be switched to a Stable Repaint with
--repaint
option.
5. Citation
If you find our work interesting, please cite our paper.
@article{park2024energy,
title={Energy-based cross attention for bayesian context update in text-to-image diffusion models},
author={Park, Geon Yeong and Kim, Jeongsol and Kim, Beomsu and Lee, Sang Wan and Ye, Jong Chul},
journal={Advances in Neural Information Processing Systems},
volume={36},
year={2024}
}