Exploring Discrete Diffusion Models for Image Captioning.
Official implementation for the paper "Exploring Discrete Diffusion Models for Image Captioning"
Training prerequisites
You can use Docker, or create a conda environment and install the dependencies:
conda env create -f environment.yml
or
bash install_req.sh
or
pip install -r requirements.txt
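After installing, a quick sanity check can confirm that PyTorch, CUDA, and CLIP are visible. This is a minimal sketch, assuming the environment files above install PyTorch and the OpenAI CLIP package:

import torch
import clip  # OpenAI CLIP package; assumed to be provided by the environment files above

# Print the PyTorch version, whether CUDA is available, and the CLIP models on offer.
print(torch.__version__, torch.cuda.is_available())
print(clip.available_models())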
COCO training
Download train_captions.
Download the training images and validation images and unzip them (we use the Karpathy et al. split); a download sketch follows the directory layout below.
Download oscar_split_ViT-B_32_train_512.pkl and place it in ./data/coco/.
Microsoft COCO
│MSCOCO_Caption/
├──annotations/
│ ├── captions_train2014.json
│ ├── captions_val2014.json
├──train2014/
│ ├── COCO_train2014_000000000009.jpg
│ ├── ......
├──val2014/
│ ├── COCO_val2014_000000000042.jpg
│ ├── ......
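For reference, the images and annotations can be fetched from the official COCO site and unpacked into the layout above. The following is a minimal sketch, assuming the standard COCO 2014 download URLs and sufficient disk space; adjust the target directory to your setup:

import os
import urllib.request
import zipfile

# Standard COCO 2014 download links (train/val images and captions).
root = "MSCOCO_Caption"
os.makedirs(root, exist_ok=True)
urls = [
    "http://images.cocodataset.org/zips/train2014.zip",
    "http://images.cocodataset.org/zips/val2014.zip",
    "http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
]
for url in urls:
    archive = os.path.join(root, os.path.basename(url))
    if not os.path.exists(archive):
        urllib.request.urlretrieve(url, archive)
    # train2014.zip / val2014.zip unpack into train2014/ and val2014/,
    # annotations_trainval2014.zip unpacks into annotations/.
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(root)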
Prepare evaluation
Change the working directory and set up the evaluation code:
cd ./captioneval/coco_caption
bash ./get_stanford_models.sh
Run
MKL_THREADING_LAYER=GNU python -m torch.distributed.launch --nproc_per_node 8 train.py --out_dir /results_diff --tag caption_diff_vitb16
If you want to train the model with a trainable CLIP image encoder, use the command:
MKL_THREADING_LAYER=GNU python -m torch.distributed.launch --nproc_per_node 8 train_tclip.py --out_dir /results_diff --tag caption_diff_vitb16
Please note that we detach the gradient of the [CLS] token while training the CLIP model, because we observe that when the image encoder (CLIP) is trainable, backpropagating gradients through the [CLS] token damages the training of the image encoder.
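In other words, the backward path is cut through the [CLS] token only, while the patch tokens still receive gradients. The following is a minimal sketch rather than the repository's exact code; the tensor layout (with [CLS] at index 0) and the function names are assumptions:

import torch

def encode_image(clip_visual, images):
    # clip_visual is assumed to return token embeddings of shape
    # (batch, 1 + num_patches, dim) with the [CLS] token at index 0.
    tokens = clip_visual(images)
    cls_token = tokens[:, 0].detach()   # stop gradients flowing back through [CLS]
    patch_tokens = tokens[:, 1:]        # patch tokens keep their gradients
    return cls_token, patch_tokens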
Citation
If you use this code for your research, please cite:
@article{zhu2022exploring,
title={Exploring Discrete Diffusion Models for Image Captioning},
author={Zhu, Zixin and Wei, Yixuan and Wang, Jianfeng and Gan, Zhe and Zhang, Zheng and Wang, Le and Hua, Gang and Wang, Lijuan and Liu, Zicheng and Hu, Han},
journal={arXiv preprint arXiv:2211.11694},
year={2022}
}
Acknowledgments
This repository is heavily based on the CLIP, CLIP_prefix_caption, and Hugging Face repositories. For training we used data from the COCO dataset.