Shaking Up VLMs: Comparing Transformers and Structured State Space Models for Vision & Language Modeling (EMNLP 2024)

[Paper][Model Checkpoints][Data][Training]

Requirements

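# Create and activate the conda environment
conda create -p ~/sharedscratch/conda_envs/athnlp2024-cu12 python=3.10
conda activate ~/sharedscratch/conda_envs/athnlp2024-cu12

# Install dependencies
poetry install

# Install flash attention / mamba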
poe install-flash-attn
poe install-mamba-ssm
poe install-causal-conv1d
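
After installation, a quick import check can confirm that the optional extensions are visible to the active environment. This is a minimal sketch, not part of the repository; it assumes the three packages expose the import names `flash_attn`, `mamba_ssm`, and `causal_conv1d`, and the helper name `check_optional_packages` is made up for illustration.

```python
import importlib.util

# Import names assumed for the three optional extensions installed above.
OPTIONAL_PACKAGES = ("flash_attn", "mamba_ssm", "causal_conv1d")


def check_optional_packages(names=OPTIONAL_PACKAGES):
    """Return {import_name: bool}, True when the package can be located."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_optional_packages().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```

The check only locates the packages; it does not verify that their CUDA kernels were compiled against the right toolkit.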

Model Checkpoints 🤗

Pretrained Checkpoints

Model
Pythia-VL-1B
Mamba-VL-790M
Pythia-VL-1.4B
Mamba-VL-1.4B
Pythia-VL-2.8B
Mamba-VL-2.8B

Instruction-tuned Checkpoints

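| Model | COCO | NoCaps | VQAv2 | GQA | V7W (test-T) | VSR | POPE | RefCOCO (testA) | RefCOCO (testB) | RefCOCO+ (testA) | RefCOCO+ (testB) | RefCOCOg | V7W (test-P) | TextCaps | TextVQA | AI2D |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Pythia-VL-1B | 132.89 | 97.61 | 72.26 | 53.79 | 81.96 | 72.43 | 86.77 | 76.00 | 62.48 | 45.36 | 47.44 | 67.58 | 83.78 | 92.73 | 35.22 | 77.62 |
| Mamba-VL-790M | 133.81 | 99.00 | 71.67 | 54.95 | 81.82 | 75.39 | 86.77 | 67.84 | 56.35 | 57.97 | 41.43 | 59.16 | 74.01 | 94.30 | 40.72 | 79.27 |
| Pythia-VL-1.4B | 134.06 | 100.72 | 73.57 | 57.05 | 83.06 | 77.72 | 86.40 | 82.43 | 68.39 | 72.35 | 55.16 | 72.56 | 86.13 | 94.60 | 37.54 | 79.27 |
| Mamba-VL-1.4B | 134.76 | 100.56 | 74.46 | 58.44 | 83.78 | 80.18 | 85.32 | 76.60 | 63.48 | 68.40 | 52.11 | 68.82 | 80.18 | 98.68 | 41.30 | 80.86 |
| Pythia-VL-2.8B | 134.97 | 101.27 | 75.08 | 59.76 | 84.34 | 80.86 | 86.87 | 85.39 | 70.82 | 75.39 | 58.62 | 76.24 | 86.61 | 99.74 | 39.14 | 81.57 |
| Mamba-VL-2.8B | 135.53 | 102.00 | 76.08 | 60.41 | 85.31 | 81.45 | 87.33 | 79.29 | 64.97 | 71.64 | 53.94 | 71.27 | 82.50 | 100.47 | 42.14 | 83.71 |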
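
To make the comparison between the two 2.8B models easier to read, the scores can be differenced per benchmark. The sketch below copies the two 2.8B rows from the table above; the helper name `score_deltas` is introduced here for illustration only.

```python
BENCHMARKS = [
    "COCO", "NoCaps", "VQAv2", "GQA", "V7W (test-T)", "VSR", "POPE",
    "RefCOCO (testA)", "RefCOCO (testB)", "RefCOCO+ (testA)", "RefCOCO+ (testB)",
    "RefCOCOg", "V7W (test-P)", "TextCaps", "TextVQA", "AI2D",
]

# Scores copied from the 2.8B rows of the instruction-tuned table.
PYTHIA_2_8B = [134.97, 101.27, 75.08, 59.76, 84.34, 80.86, 86.87, 85.39,
               70.82, 75.39, 58.62, 76.24, 86.61, 99.74, 39.14, 81.57]
MAMBA_2_8B = [135.53, 102.00, 76.08, 60.41, 85.31, 81.45, 87.33, 79.29,
              64.97, 71.64, 53.94, 71.27, 82.50, 100.47, 42.14, 83.71]


def score_deltas(a, b, names=BENCHMARKS):
    """Return {benchmark: a - b}, rounded to two decimals."""
    return {n: round(x - y, 2) for n, x, y in zip(names, a, b)}


# Positive values mean Mamba-VL-2.8B scores higher on that benchmark.
deltas = score_deltas(MAMBA_2_8B, PYTHIA_2_8B)
mamba_ahead = [n for n, d in deltas.items() if d > 0]
pythia_ahead = [n for n, d in deltas.items() if d < 0]

for name, delta in deltas.items():
    print(f"{name:18s} {delta:+.2f}")
```

With these values, Pythia-VL-2.8B is ahead on the six RefCOCO, RefCOCO+, RefCOCOg, and V7W (test-P) columns, and Mamba-VL-2.8B is ahead on the remaining ten.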

Data

Instructions for specific datasets

GRIT

GRIT is downloaded using img2dataset. Note that some of the URLs may no longer be available by the time you download them.

./scripts/download_grit.sh storage/datasets/grit_url_folder storage/datasets/grit

To avoid training on the full dataset, downsample GRIT by filtering on noun_phrases (see the appendix of the paper for full details).

python prepare_grit_dataset.py \
	--cache_dir /path/to/downloaded/grit \
	--output_folder /path/to/downsampled/grit \
	--downsample_images \
	--check_overlap 
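
The actual filtering criteria are described in the paper's appendix and implemented in prepare_grit_dataset.py; they are not reproduced here. Purely to illustrate the general idea of downsampling by noun phrase, the sketch below caps how many examples are kept per phrase. The record layout (`noun_phrases` as a list of strings), the function name, and the cap value are all hypothetical.

```python
from collections import Counter


def downsample_by_noun_phrase(examples, max_per_phrase=2):
    """Keep an example only if one of its noun phrases is still under the cap.

    Hypothetical illustration; this is not the repository's filtering logic.
    """
    kept_counts = Counter()
    kept = []
    for example in examples:
        phrases = [p.lower() for p in example["noun_phrases"]]
        if any(kept_counts[p] < max_per_phrase for p in phrases):
            kept.append(example)
            kept_counts.update(phrases)
    return kept


# Toy records, invented for this example.
toy = [
    {"id": 0, "noun_phrases": ["a dog"]},
    {"id": 1, "noun_phrases": ["a dog"]},
    {"id": 2, "noun_phrases": ["a dog"]},
    {"id": 3, "noun_phrases": ["a dog", "a red kite"]},
]
kept_ids = [e["id"] for e in downsample_by_noun_phrase(toy)]
print(kept_ids)
```

On the toy input, the third "a dog" example is dropped because that phrase has already reached the cap, while the last example survives thanks to its rarer second phrase.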

OCRVQA

We also filter out examples from OCRVQA (see the appendix of the paper for details).

python filter_out_ocrvqa_images.py \
	--cache_dir /path/to/downloaded/ocrvqa \
	--output_json /path/to/filtered/ocrvqa/examples

Prepare pretraining dataset

python prepare_dataset.py \
	--dataset_subset llava_pretrain \
	--root_dataset_path storage/datasets \
	--cache_dir storage/datasets/vl_mamba

Prepare instruction tuning dataset

python prepare_dataset.py \
	--dataset_subset instruction_tuning \
	--root_dataset_path storage/datasets \
	--cache_dir storage/datasets/vl_mamba

Prepare a single dataset

python prepare_dataset.py \
	--dataset_subset coco \
	--root_dataset_path storage/datasets \
	--cache_dir storage/datasets/vl_mamba

See DatasetNames in src/vl_mamba/datamodels/datamodels.py for the names of the available datasets.
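
When several subsets need preparing, the three invocations above differ only in `--dataset_subset`, so the command can be generated in a loop. This is a minimal sketch using only the flags and paths shown above; the helper name `build_prepare_command` is hypothetical.

```python
import shlex


def build_prepare_command(subset, root="storage/datasets",
                          cache="storage/datasets/vl_mamba"):
    """Build the argv list for one prepare_dataset.py invocation."""
    return [
        "python", "prepare_dataset.py",
        "--dataset_subset", subset,
        "--root_dataset_path", root,
        "--cache_dir", cache,
    ]


# Subset names taken from the commands above; DatasetNames lists the others.
for subset in ("llava_pretrain", "instruction_tuning", "coco"):
    print(shlex.join(build_prepare_command(subset)))
```

The sketch only prints the commands. To run one, pass the list to `subprocess.run(cmd, check=True)` from the repository root.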

Training

Pretraining

Pythia

./scripts/pretrain_pythia.sh

Mamba

./scripts/pretrain_mamba.sh

Instruction-tuning

Pythia

./scripts/finetune_pythia.sh path/to/pretrained/pythia/model /path/to/dataset/cache /path/to/root/dataset/path /output/model/directory wandb_run_name

Mamba

./scripts/finetune_mamba.sh path/to/pretrained/mamba/model /path/to/dataset/cache /path/to/root/dataset/path /output/model/directory wandb_run_name
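
Both fine-tuning scripts take five positional arguments, which are easy to pass in the wrong order. The sketch below wraps them in named parameters that follow the order shown in the commands above; the function and parameter names are chosen here for illustration and do not come from the repository.

```python
def build_finetune_command(model_family, pretrained_model, dataset_cache,
                           root_dataset_path, output_dir, wandb_run_name):
    """Return the argv list for finetune_pythia.sh or finetune_mamba.sh."""
    if model_family not in ("pythia", "mamba"):
        raise ValueError(f"unknown model family: {model_family!r}")
    # Positional order mirrors the example invocations above.
    return [
        f"./scripts/finetune_{model_family}.sh",
        pretrained_model,
        dataset_cache,
        root_dataset_path,
        output_dir,
        wandb_run_name,
    ]


cmd = build_finetune_command(
    "mamba",
    pretrained_model="path/to/pretrained/mamba/model",
    dataset_cache="/path/to/dataset/cache",
    root_dataset_path="/path/to/root/dataset/path",
    output_dir="/output/model/directory",
    wandb_run_name="wandb_run_name",
)
print(" ".join(cmd))
```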

Training logs

All pretraining and fine-tuning logs can be found on wandb. Note that some of the runs were resumed from a previous checkpoint.