HIVE: Harnessing Human Feedback for Instructional Visual Editing
Shu Zhang*<sup>1</sup>, Xinyi Yang*<sup>1</sup>, Yihao Feng*<sup>1</sup>, Can Qin<sup>3</sup>, Chia-Chih Chen<sup>1</sup>, Ning Yu<sup>1</sup>, Zeyuan Chen<sup>1</sup>, Huan Wang<sup>1</sup>, Silvio Savarese<sup>1,2</sup>, Stefano Ermon<sup>2</sup>, Caiming Xiong<sup>1</sup>, and Ran Xu<sup>1</sup><br> <sup>1</sup>Salesforce AI, <sup>2</sup>Stanford University, <sup>3</sup>Northeastern University<br> *denotes equal contribution<br> arXiv 2023
paper | project page
<img src='imgs/results.png' width=700>
This is a PyTorch implementation of HIVE: Harnessing Human Feedback for Instructional Visual Editing. The code largely follows InstructPix2Pix. In this repo, both Stable Diffusion v1.5-base and Stable Diffusion v2.1-base are implemented as backbones.
Updates
- 07/08/23: Training code and training data are public. :blush:
- 03/26/24: HIVE will appear in CVPR 2024. :blush:
Usage
Preparation
First, set up the hive environment and download the pretrained models as below. This setup has only been verified with CUDA 11.0 and CUDA 11.3 on NVIDIA A100 GPUs.
conda env create -f environment.yaml
conda activate hive
bash scripts/download_checkpoints.sh
To fine-tune a Stable Diffusion model, you first need the pre-trained Stable Diffusion weights, obtained by following their instructions. For SD v1.5, download the weights from HuggingFace SD 1.5; for SD v2.1, download them from HuggingFace SD 2.1. You can choose which checkpoint version to use; we use v2-1_512-ema-pruned.ckpt. Download the model to checkpoints/.
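For example, the SD v2.1 checkpoint can be fetched directly into checkpoints/ (a sketch; the URL assumes the standard Hugging Face layout of the stabilityai/stable-diffusion-2-1-base repository):

```bash
# Sketch: download the SD v2.1-base checkpoint into checkpoints/
# (URL assumes the usual Hugging Face layout for stabilityai/stable-diffusion-2-1-base)
mkdir -p checkpoints
wget -P checkpoints https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt
```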
Data
We suggest installing the gcloud CLI following Gcloud download. To obtain both training and evaluation data, run
bash scripts/download_hive_data.sh
Alternatively, you can download the data directly through Evaluation data and Evaluation instructions.
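After downloading, the batch-inference commands later in this README expect the instruction list at data/test.jsonl and the evaluation images under data/evaluation/ (paths taken from the `--jsonl_file` and `--image_dir` flags used below); a quick sanity check:

```bash
# Sanity check the data layout assumed by the batch-inference commands below
head -n 1 data/test.jsonl
ls data/evaluation/ | head
```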
Step-1 Training
For SD v2.1, we run
python main.py --name step1 --base configs/train_v21_base.yaml --train --gpus 0,1,2,3,4,5,6,7
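The `--gpus` flag takes a comma-separated list of device ids, so a run on fewer GPUs follows the same pattern (a sketch; `step1_small` is only an illustrative run name, and the effective batch size changes with the GPU count):

```bash
# Example: the same step-1 fine-tuning on 2 GPUs instead of 8 (illustrative run name)
python main.py --name step1_small --base configs/train_v21_base.yaml --train --gpus 0,1
```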
Inference
Samples can be obtained by running the commands below.
For SD v2.1, if we use the conditional reward, we run
python edit_cli_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" --ckpt checkpoints/hive_v2_rw_condition.ckpt \
--config configs/generate_v21_base.yaml
or run batch inference on our inference data:
python edit_cli_batch_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv21_rw_label/ --ckpt checkpoints/hive_v2_rw_condition.ckpt \
--config configs/generate_v21_base.yaml --image_dir data/evaluation/
For SD v2.1, if we use the weighted reward, we can run
python edit_cli.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" \
--ckpt checkpoints/hive_v2_rw.ckpt --config configs/generate_v21_base.yaml
or run batch inference on our inference data:
python edit_cli_batch.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv21/ --ckpt checkpoints/hive_v2_rw.ckpt \
--config configs/generate_v21_base.yaml --image_dir data/evaluation/
For SD v1.5, if we use the conditional reward, we can run
python edit_cli_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" \
--ckpt checkpoints/hive_rw_condition.ckpt --config configs/generate.yaml
or run batch inference on our inference data:
python edit_cli_batch_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv15_rw_label/ \
--ckpt checkpoints/hive_rw_condition.ckpt --config configs/generate.yaml \
--image_dir data/evaluation/
For SD v1.5, if we use the weighted reward, we run
python edit_cli.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 --input imgs/example1.jpg \
--output imgs/output.jpg --edit "move it to Mars" \
--ckpt checkpoints/hive_rw.ckpt --config configs/generate.yaml
or run batch inference on our inference data:
python edit_cli_batch.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
--jsonl_file data/test.jsonl --output_dir imgs/sdv15/ \
--ckpt checkpoints/hive_rw.ckpt --config configs/generate.yaml \
--image_dir data/evaluation/
Citation
@article{zhang2023hive,
  title={HIVE: Harnessing Human Feedback for Instructional Visual Editing},
  author={Zhang, Shu and Yang, Xinyi and Feng, Yihao and Qin, Can and Chen, Chia-Chih and Yu, Ning and Chen, Zeyuan and Wang, Huan and Savarese, Silvio and Ermon, Stefano and Xiong, Caiming and Xu, Ran},
  journal={arXiv preprint arXiv:2303.09618},
  year={2023}
}