Prompt-Diffusion: In-Context Learning Unlocked for Diffusion Models
In-Context Learning Unlocked for Diffusion Models
Zhendong Wang, Yifan Jiang, Yadong Lu, Yelong Shen, Pengcheng He, Weizhu Chen, Zhangyang Wang and Mingyuan Zhou
Abstract: We present Prompt Diffusion, a framework for enabling in-context learning in diffusion-based generative models. Given a pair of task-specific example images, such as depth from/to image and scribble from/to image, and text guidance, our model automatically understands the underlying task and performs the same task on a new query image following the text guidance. To achieve this, we propose a vision-language prompt that can model a wide range of vision-language tasks and a diffusion model that takes it as input. The diffusion model is trained jointly on six different tasks using these prompts. The resulting Prompt Diffusion model becomes the first diffusion-based vision-language foundation model capable of in-context learning. It demonstrates high-quality in-context generation for the trained tasks and effectively generalizes to new, unseen vision tasks using their respective prompts. Our model also shows compelling text-guided image editing results. Our framework aims to facilitate research into in-context learning for computer vision, with code publicly available in this repository.
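As a rough illustration of what one prompt bundles together, the Python sketch below models a vision-language prompt as a small data structure. `VisionLanguagePrompt` and `make_prompt` are hypothetical names, not code from this repository; the sketch assumes only what the abstract states (an example pair for one task, text guidance, and a query on which the same task is performed), plus an `inverse` flag to express the "from/to" directions.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class VisionLanguagePrompt:
    """Hypothetical container for one in-context prompt (illustrative names)."""
    example_source: Any  # e.g. a depth map
    example_target: Any  # the image paired with that depth map
    text: str            # text guidance
    query: Any           # new input of the same kind as example_source


def make_prompt(pair_a: Tuple[Any, Any], pair_b: Tuple[Any, Any],
                text: str, inverse: bool = False):
    """Build a prompt from two (condition, image) pairs of the same task.

    pair_a supplies the in-context example and pair_b supplies the query;
    the expected output (the training target) is returned with the prompt.
    With inverse=True the task direction flips from condition -> image
    to image -> condition.
    """
    cond_a, img_a = pair_a
    cond_b, img_b = pair_b
    if inverse:
        prompt = VisionLanguagePrompt(img_a, cond_a, text, img_b)
        target = cond_b
    else:
        prompt = VisionLanguagePrompt(cond_a, img_a, text, cond_b)
        target = img_b
    return prompt, target
```

For instance, `make_prompt((depth_1, img_1), (depth_2, img_2), "a house")` asks for an image generated from `depth_2`, with `img_2` as the target; passing `inverse=True` turns the same data into an image-to-depth prompt.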
Hugging Face Diffusers Support
We thank iczaw for this contribution. Prompt Diffusion is now supported through the diffusers package. Use the code below for a quick try:
import torch
from diffusers import DDIMScheduler, UniPCMultistepScheduler
from diffusers.utils import load_image
from PIL import ImageOps
# promptdiffusioncontrolnet.py and pipeline_prompt_diffusion.py are local modules
# (not imported from diffusers itself), so they must be on your Python path
from promptdiffusioncontrolnet import PromptDiffusionControlNetModel
from pipeline_prompt_diffusion import PromptDiffusionPipeline
# in-context example pair (house_line.png -> house.png) and a new query;
# the line-drawing inputs are inverted with ImageOps.invert
image_a = ImageOps.invert(load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house_line.png?raw=true"))
image_b = load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house.png?raw=true")
query = ImageOps.invert(load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/new_01.png?raw=true"))
# load the Prompt Diffusion ControlNet and the full pipeline in half precision
controlnet = PromptDiffusionControlNetModel.from_pretrained("zhendongw/prompt-diffusion-diffusers", subfolder="controlnet", torch_dtype=torch.float16)
pipe = PromptDiffusionPipeline.from_pretrained("zhendongw/prompt-diffusion-diffusers", controlnet=controlnet, torch_dtype=torch.float16).to("cuda")
# speed up the diffusion process with a faster scheduler (UniPC is an alternative)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# optional memory optimizations: uncomment the first line only if xformers is installed;
# use CPU offloading instead of the .to("cuda") call above
# pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_model_cpu_offload()
# generate an image for the query, following the example pair and the text prompt
generator = torch.manual_seed(2023)
image = pipe("a tortoise", num_inference_steps=50, generator=generator, image_pair=[image_a, image_b], image=query).images[0]
image.save("./test.png")
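To try your own inputs instead of the hosted sample images, a small preprocessing helper can normalize them first. `prepare_inputs` below is a hypothetical helper, not part of this repository or of diffusers. It assumes 512x512 RGB inputs (the default Stable Diffusion v1.5 resolution) and optionally reproduces the ImageOps.invert step from the snippet above for the example source and the query; whether your own condition images need inverting depends on how they were drawn.

```python
from PIL import Image, ImageOps


def prepare_inputs(example_src: Image.Image, example_tgt: Image.Image,
                   query: Image.Image, size: int = 512,
                   invert_conditions: bool = False):
    """Hypothetical helper: convert to RGB, resize, optionally invert.

    Returns ([example_src, example_tgt], query), ready to be passed to the
    pipeline as image_pair=... and image=....
    """
    def to_rgb(img: Image.Image) -> Image.Image:
        # convert first so that ImageOps.invert (RGB/L only) always works
        return img.convert("RGB").resize((size, size))

    src, tgt, q = (to_rgb(im) for im in (example_src, example_tgt, query))
    if invert_conditions:
        # mirror the ImageOps.invert calls used for image_a and query above
        src, q = ImageOps.invert(src), ImageOps.invert(q)
    return [src, tgt], q
```

Usage with hypothetical file names: `pair, query = prepare_inputs(Image.open("my_lines.png"), Image.open("my_photo.png"), Image.open("my_query.png"), invert_conditions=True)`, then call the pipeline with `image_pair=pair, image=query`.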
Prepare Dataset
We use the public dataset proposed by InstructPix2Pix as our base dataset, which consists of around 310k image-caption pairs. We then apply the ControlNet annotators to collect image conditions, such as HED, depth, and segmentation maps. The code for collecting image conditions is provided in annotate_data.py.
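The ControlNet annotators used by annotate_data.py (HED, depth, segmentation) are neural models. As a dependency-light illustration of what "collecting an image condition" means, the sketch below swaps in a plain gradient-magnitude edge detector in place of HED. It is not the annotator used for the released models; `edge_condition` and its `threshold` default are assumptions for illustration only.

```python
import numpy as np
from PIL import Image


def edge_condition(image: Image.Image, threshold: float = 0.1) -> Image.Image:
    """Crude edge map, a stand-in for the HED annotator (illustrative only).

    Computes finite-difference gradient magnitude on the grayscale image,
    normalizes it to [0, 1], and marks pixels above `threshold` as white
    edges on a black background.
    """
    gray = np.asarray(image.convert("L"), dtype=np.float32) / 255.0
    gy, gx = np.gradient(gray)          # derivatives along rows and columns
    mag = np.hypot(gx, gy)
    if mag.max() > 0:
        mag = mag / mag.max()
    edges = (mag > threshold).astype(np.uint8) * 255
    return Image.fromarray(edges).convert("RGB")  # 3-channel condition map
```

Each (condition, image) pair produced this way, plus the caption, is the kind of unit from which the in-context prompts are assembled.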
Training
Training Prompt Diffusion takes two steps: create the initial checkpoint from your Stable Diffusion weights, then launch training:
python tool_add_control.py 'path to your stable diffusion checkpoint, e.g., /.../v1-5-pruned-emaonly.ckpt' ./models/control_sd15_ini.ckpt
python train.py --name 'experiment name' --gpus=8 --num_nodes=1 \
--logdir 'your logdir path' \
--data_config './models/dataset.yaml' --base './models/cldm_v15.yaml' \
--sd_locked
We also provide a job script, scripts/train_v1-5.sh, for an easy run.
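The first training command builds ./models/control_sd15_ini.ckpt from Stable Diffusion weights. In the ControlNet code base this repository builds on, tool_add_control.py does this by copying each Stable Diffusion U-Net weight into the control-branch parameter with the same suffix and leaving parameters without a counterpart at their freshly initialized values. The sketch below reproduces that name mapping on plain dictionaries; `init_control_weights` is a hypothetical name, and we assume this repository's copy of the script follows the same logic.

```python
def init_control_weights(sd_weights: dict, scratch_weights: dict) -> dict:
    """Sketch of ControlNet-style initialization (hypothetical helper).

    scratch_weights: state dict of the freshly built model (name -> value).
    A parameter named 'control_model.X' is initialized from the Stable
    Diffusion parameter 'model.diffusion_model.X' if it exists; parameters
    with no SD counterpart keep their scratch value. All other parameters
    are copied from the SD checkpoint under the same name when present.
    """
    target = {}
    for name, scratch_value in scratch_weights.items():
        if name.startswith("control_"):
            # 'control_model.X' -> 'model.diffusion_model.X'
            source = "model.diffusion_" + name[len("control_"):]
        else:
            source = name
        target[name] = sd_weights.get(source, scratch_value)
    return target
```

The real script operates on PyTorch state dicts and clones each tensor; plain floats keep this sketch dependency-free.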
Run Prompt Diffusion from our Checkpoints
We release our trained model checkpoints on our Hugging Face page, along with a direct download link.
We provide a Jupyter notebook, run_prompt_diffusion.ipynb, for trying the inference code of Prompt Diffusion, and a few sample images in the folder images_to_try. We are preparing a Gradio demo and will release it soon.
Results
Multi-Task Learning
Generalization to New Tasks
Image Editing Ability
More Examples
Citation
@article{wang2023promptdiffusion,
title = {In-Context Learning Unlocked for Diffusion Models},
author = {Wang, Zhendong and Jiang, Yifan and Lu, Yadong and Shen, Yelong and He, Pengcheng and Chen, Weizhu and Wang, Zhangyang and Zhou, Mingyuan},
journal = {arXiv preprint arXiv:2305.01115},
year = {2023},
url = {https://arxiv.org/abs/2305.01115}
}
Acknowledgements
We thank Brooks et al. for sharing the dataset for fine-tuning Stable Diffusion. We also thank Lvmin Zhang and Maneesh Agrawala for providing the excellent ControlNet code base.