Home

Awesome

Intelligent Grimm - Open-ended Visual Storytelling via Latent Diffusion Models (CVPR 2024)

This repository contains the official PyTorch implementation of StoryGen: https://arxiv.org/abs/2306.00973/

<div align="center"> <img src="./teaser.png"> </div>

Some Information

Project Page $\cdot$ Paper $\cdot$ Dataset $\cdot$ Checkpoint

Requirements

A suitable conda environment named storygen can be created and activated with:

conda env create -f environment.yaml
conda activate storygen

Meta Data Preparation

Data from YouTube

We provide the metadata of our StorySalon dataset in ./data/metadata.json. It includes the id, name, url, duration and the keyframe list after filtering the videos.

To download these videos, we recommend to use youtube-dl via:

youtube-dl --write-auto-sub -o 'file\%(title)s.%(ext)s' -f 135 [url]

The keyframes extracted with the following data processing pipeline (step 1) can be filtered according to the keyframe list provided in the metadata to avoid manual selection.

The corresponding masks, story-level description and visual description can be extracted with the following data processing pipeline or downloaded from here.

Data from Open-source Libraries

For the open-source PDF data, you can directly download the frames, corresponding masks, description and narrative from StorySalon.

Data Processing Pipeline

The data processing pipeline includes several necessary steps:

The keyframes and their corresponding subtitles can be extracted via:

python ./data_process/extract.py

The duplicate frames can be detected and removed via:

CUDA_VISIBLE_DEVICES=0 python ./data_process/dup_remove.py

The text, people and headshots can be segmented, and the frames that only contain real people are then removed via:

python ./data_process/yolov7/human_ocr_mask.py

The text, headshots and real hands in the frames can be inpainted with SDM-Inpainting, according to the segmentation mask via:

CUDA_VISIBLE_DEVICES=0 python ./data_process/SDM/inpaint.py

Besides, we also provide the code to get story-level paired image-text samples. We can align the subtitles with visual frames by using Dynamic Time Warping(DTW) algorithm via:

CUDA_VISIBLE_DEVICES=0 python ./data_process/align.py

(Optional) You can use TextBind or MiniGPT-v2 to obtain the caption of each image via:

CUDA_VISIBLE_DEVICES=0 python ./data_process/TextBind/main_caption.py
CUDA_VISIBLE_DEVICES=0 python ./data_process/MiniGPT-v2/main_caption.py

(Discarded) Previous method: You can also use ChatCaptioner to obtain the caption of each image via:

CUDA_VISIBLE_DEVICES=0 python ./data_process/ChatCaptioner/main_caption.py

For a more detailed introduction to the data processing pipeline, please refer to ./data_process/README.md and our paper.

Training

Before training, please download pre-trained StableDiffusion-1.5 from SDM (including vae, scheduler, tokenizer and unet). Then, all the pre-trained checkpoints should be placed into the corresponding location in the folder ./ckpt/stable-diffusion-v1-5/

For Stage 1, pre-train the self-attention layers in SDM for StyleTransfer via:

CUDA_VISIBLE_DEVICES=0 accelerate launch train_StorySalon_stage1.py

For Stage 2, train the Visual-Language Context Module via:

CUDA_VISIBLE_DEVICES=0 accelerate launch train_StorySalon_stage2.py

For replicating the experiments on MS-COCO, train via:

CUDA_VISIBLE_DEVICES=0 accelerate launch train_COCO.py

If you have multiple GPUs to accelerate the training process, you can use:

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --multi_gpu train_StorySalon_stage2.py

Inference

CUDA_VISIBLE_DEVICES=0 accelerate launch inference.py

About the Testset

Regarding the test set in our StorySalon dataset, some researchers have reported that they are unable to obtain corresponding test data through our data processing pipeline due to certain YouTube videos being taken down.

To address this issue, we have provided pre-processed test datasets for download via Google Drive or HuggingFace, which can be used for testing and benchmarking purposes.

In consideration of potential copyright risks associated with YouTube videos, we have not directly included the original video content.

Please feel free to contact us if you have any questions regarding this part.

TODO

License

The code and checkpoints in this repository are under MIT license. The open-source books in the StorySalon dataset come from multiple online open-source libraries (please refer to the Appendix of our paper for more details), and they are all under CC-BY 4.0 license. It should be noted that, for the data extracted from the video data, we only provide YouTube URLs and data processing pipelines, if you wish to use them for commercial purposes, we recommend that you obey the relevant regulations of YouTube.

Citation

If you use this code for your research or project, please cite:

  @InProceedings{Liu_2024_CVPR,
        author    = {Liu, Chang and Wu, Haoning and Zhong, Yujie and Zhang, Xiaoyun and Wang, Yanfeng and Xie, Weidi},
        title     = {Intelligent Grimm - Open-ended Visual Storytelling via Latent Diffusion Models},
        booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
        month     = {June},
        year      = {2024},
        pages     = {6190-6200}
  }

Acknowledgements

Many thanks to the code bases from diffusers and SimpleSDM.

Contact

If you have any questions, please feel free to contact haoningwu3639@gmail.com or liuchang666@sjtu.edu.cn.