Revisiting Multimodal Representation in Contrastive Learning: From Patch and Token Embeddings to Finite Discrete Tokens
This repository is a re-implementation of the paper Revisiting Multimodal Representation in Contrastive Learning: From Patch and Token Embeddings to Finite Discrete Tokens. It includes PyTorch code for pretraining and zero-shot evaluation.
<p align="center"><img src="figures/method.png" width="1000"/></p>

Citation

@inproceedings{chen2023revisiting,
title={Revisiting multimodal representation in contrastive learning: from patch and token embeddings to finite discrete tokens},
author={Chen, Yuxiao and Yuan, Jianbo and Tian, Yu and Geng, Shijie and Li, Xinyu and Zhou, Ding and Metaxas, Dimitris N and Yang, Hongxia},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={15095--15104},
year={2023}
}
Installation
Please run the following commands in your shell:
conda create -n fdt python=3.8
conda activate fdt
pip install -r requirements.txt
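Optionally, you can check that PyTorch was installed with GPU support before moving on (a quick sanity check, not part of the original setup):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"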
Dataset Preparation
Pre-training
The paper pre-trains on four publicly available datasets: YFCC-15M V2, Conceptual Captions (CC3M), Conceptual 12M (CC12M), and LAION115M.
In this repository, we pre-train the model using the CC3M dataset. The data preparation process is outlined as follows:
- Download the dataset using the script provided by img2dataset (see the download sketch after this list). The downloaded dataset is saved in the webdataset format: a series of .tar files, where each file is a "shard" containing two files per training sample, one for the image and one for the corresponding text.
- We further process the shards using tarproc so that each shard contains 1,000 samples:
git clone https://github.com/tmbdev-archive/tarproc #install tarproc
python data_process/wds/process_download_data.py \
--src_fold ${path/to/downloaded/data} \
--res_fold ${path/to/save/results} \
--tarproc_fold ${path/to/tarproc/folder}
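For reference, the download step might look like the sketch below. This is only a sketch: the TSV file name and the url/caption column names are assumptions, so adjust them to your copy of the CC3M annotations.

pip install img2dataset  # downloader referenced in the first step above
img2dataset --url_list cc3m.tsv --input_format "tsv" \
    --url_col "url" --caption_col "caption" \
    --output_format webdataset \
    --output_folder ${path/to/downloaded/data} \
    --processes_count 16 --thread_count 64 --image_size 256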
Evaluation
ImageNet
- Create a folder named ImageNet-1K/ under $DATA.
- Download the validation set from the official website and extract it.
- Download the label file provided by DeCLIP - val_official.json.
- Download the list of image paths - val_datalist.pkl.

The resulting directory structure is:

ImageNet-1K/
|–– val/ # contains 1,000 folders like n04357314, n03976657, etc.
|–– val_official.json
|–– val_datalist.pkl
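As an optional sanity check (assuming $DATA is set in your shell), you can confirm the layout:

ls $DATA/ImageNet-1K/val | wc -l   # expect 1,000 synset folders
ls $DATA/ImageNet-1K/val_official.json $DATA/ImageNet-1K/val_datalist.pkl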
MSCOCO
- Create a folder named coco/ under $DATA.
- Download the validation set from the official website and extract it under $DATA/coco/.
- Download the list of image names and their captions - testall.pkl - and place it under $DATA/coco/, resulting in the following directory structure:

coco/
|–– val2014/
|–– testall.pkl
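A similar optional check for MSCOCO (again assuming $DATA is set; the expected count below is an approximation):

ls $DATA/coco/val2014 | wc -l   # expect roughly 40k validation images
ls $DATA/coco/testall.pkl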
Pre-training
- Modify the following parameters in example/clip_fdt/config_cc3m.yaml:

data.train.data_path --> the path of the preprocessed CC3M dataset.
data.train.num_samples --> the total number of training samples.
data.train.num_shards --> the number of shards.
saver.save_freq --> how often (in iterations) checkpoints are saved.
lr_scheduler.max_iter --> the total number of pre-training iterations, i.e., data.train.data.epoch multiplied by the number of iterations per epoch (see the calculation sketch after this list).
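For example, lr_scheduler.max_iter can be derived as follows. The sample count here is a placeholder; use the data.train.num_samples of your own CC3M download and your actual global batch size.

NUM_SAMPLES=2800000   # placeholder: your data.train.num_samples
GLOBAL_BATCH=1024     # per-GPU batch size x number of GPUs
EPOCHS=32             # data.train.data.epoch
python -c "import math; print($EPOCHS * math.ceil($NUM_SAMPLES / $GLOBAL_BATCH))"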
- Run the following command to start pre-training:
bash run.sh example/clip_fdt/train_solver.py \
--config example/clip_fdt/config_cc3m.yaml \
--output_path ${path/to/save/log_and_ckpt} \
--batch_size 128 # batch size for each GPU
Testing
- Change the following parameters in example/clip_fdt/test.sh:

MODEL_FOLD: the path to the folder where the pre-trained models and config files are saved.
DATA_FOLD: the path to the folder where the downstream datasets are saved.
- Run the following command:
bash example/clip_fdt/test.sh
Our pre-trained models
Method | Dataset | Model | Total batch size | Epochs | 0-shot IN (acc.) | 0-shot COCO (rsum) | Weights |
---|---|---|---|---|---|---|---|
CLIP | CC3M | ViT-B32 | 1024 | 32 | 15.4 | 145.1 | Google Drive |
CLIP+FDT | CC3M | ViT-B32 | 1024 | 32 | 18.4 | 189.5 | Google Drive |
Acknowledgement
Part of our code is borrowed from the following repositories. We thank the authors for releasing their code.