Multimodal Graph Learning

Most multimodal learning algorithms focus on modeling simple one-to-one pairs of data from two modalities, such as image-caption or audio-text pairs. In most real-world settings, however, entities of different modalities interact with each other in more complex and multifaceted ways that go beyond one-to-one mappings.

We propose Multimodal Graph Learning (MMGL), a systematic framework for capturing information from multiple multimodal neighbors with relational structures among them. In particular, we focus on MMGL for generative tasks, building on pretrained Language Models (LMs) to augment their text generation with multimodal neighbor contexts.

The original paper can be found at MMGL. The initial version of the implementation, including graph encodings, can be found at research-MMHG.

Setup

Create a new conda environment, then install the requirements and PyTorch:

conda create -n mmgl python=3.7
conda activate mmgl
pip install -r requirements.txt
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117

The code is implemented with PyTorch DistributedDataParallel (DDP) and supports the WikiWeb2M dataset.
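For reference, the standard DDP setup (launched with torchrun) looks like the minimal sketch below; the repo's actual initialization lives in language_modelling/run_generation.py and may differ in detail.

```python
# Minimal sketch of the standard PyTorch DDP setup, not the repo's exact code.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(model: torch.nn.Module) -> DDP:
    # torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK for each process.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model = model.cuda(local_rank)
    # Gradients are synchronized across processes at each backward pass.
    return DDP(model, device_ids=[local_rank])
```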

Data preprocessing

First, make a folder for the dataset: mkdir wikiweb2m/raw. Then download all train/validation/test files from WikiWeb2M into wikiweb2m/raw. Next, make a folder for images: mkdir wikiweb2m/raw/images. Finally, run preprocess_data.py to convert the WikiWeb2M dataset into PyTorch format.

python preprocess_data.py
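WikiWeb2M is released as TFRecord files, so the conversion roughly amounts to parsing each record and re-saving it in a PyTorch-friendly format. Below is a simplified sketch, assuming gzip-compressed records; the feature key "page_title" is a placeholder, not the dataset's actual schema.

```python
# Simplified sketch of the TFRecord-to-PyTorch conversion done by
# preprocess_data.py; "page_title" is a placeholder feature key.
import glob
import tensorflow as tf
import torch

samples = []
for path in glob.glob("wikiweb2m/raw/*train*.tfrecord*"):
    for record in tf.data.TFRecordDataset(path, compression_type="GZIP"):
        example = tf.train.Example()
        example.ParseFromString(record.numpy())
        feature = example.features.feature["page_title"]
        samples.append(feature.bytes_list.value[0].decode("utf-8"))

torch.save(samples, "wikiweb2m/train.pt")
```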

The resulting training/validation/test set sizes for section summarization are as follows:

| Number of | Train | Validation | Test |
| --- | --- | --- | --- |
| Sections | 680K | 170K | 170K |

Training

Script

In script/train_generation.sh, you can specify the base model (MODEL_NAME), the task (TASK; currently only section summarization, 'section', is supported), and the neighbor context (CONTEXT). CONTEXT has four options:

| CONTEXT | description |
| --- | --- |
| section_only | use only text in the target section |
| section_all | use text and images in the target section |
| text_only | use only text in the whole page |
| all | use text and images in the whole page |
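Conceptually, the four options select which parts of the page are fed to the model. The sketch below illustrates this; the names and data model are hypothetical, and the real logic lives in wikiweb2m/data.py.

```python
# Illustrative CONTEXT selection; the data model here is hypothetical.
from dataclasses import dataclass, field
from typing import List

@dataclass
class Section:
    text: str
    images: List[str] = field(default_factory=list)  # image paths

@dataclass
class Page:
    sections: List[Section]
    target: int  # index of the target section

def select_inputs(page: Page, context: str) -> list:
    target = page.sections[page.target]
    if context == "section_only":
        return [target.text]
    if context == "section_all":
        return [target.text, *target.images]
    if context == "text_only":
        return [s.text for s in page.sections]
    if context == "all":
        return [x for s in page.sections for x in (s.text, *s.images)]
    raise ValueError(f"unknown context: {context}")
```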

You can set how text neighbors are encoded using NEIGHBOR_MODE. There are two options:

| NEIGHBOR_MODE | description |
| --- | --- |
| raw | concatenate text neighbors as raw text into the input text |
| embedding | embed text neighbors using text_model and concatenate the embeddings into the input sequence |
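The two modes differ in where neighbor information enters the model: raw mode spends input tokens on neighbor text, while embedding mode injects one precomputed vector per neighbor. A simplified sketch, with illustrative names (the real logic lives in wikiweb2m/data.py and the model classes):

```python
# Simplified illustration of the two neighbor modes; names are illustrative.
def build_input(section_text, neighbor_texts, mode, tokenizer, text_encoder=None):
    if mode == "raw":
        # Neighbor text is concatenated as plain tokens, consuming part of
        # the max_input_length budget.
        joined = " ".join(neighbor_texts) + " " + section_text
        return tokenizer(joined, truncation=True, return_tensors="pt")
    elif mode == "embedding":
        # Each neighbor is encoded once; its embedding is later concatenated
        # with the token embeddings of the target section.
        inputs = tokenizer(section_text, truncation=True, return_tensors="pt")
        neighbor_embeds = [
            text_encoder(**tokenizer(t, truncation=True, return_tensors="pt"))
            .last_hidden_state[:, 0]  # e.g., the [CLS] embedding
            for t in neighbor_texts
        ]
        return inputs, neighbor_embeds
```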

You can set the parameter-efficient fine-tuning (PEFT) option in the script using PEFT_TYPE. The following options are supported:

| PEFT_TYPE | description |
| --- | --- |
| none | full fine-tuning |
| prompt | prompt tuning |
| prefix | prefix tuning |
| lora | LoRA |
| flamingo | fine-tune only the newly added cross-attention layers; usable with decoder-only models and neighbor_mode = embedding |
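As a rough illustration of the lora option, LoRA adapters can be attached with the Hugging Face peft library; the repo's own PEFT wiring may differ in target modules and hyperparameters.

```python
# Rough LoRA illustration using Hugging Face peft, not the repo's exact setup.
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model

model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
lora_config = LoraConfig(
    r=8,                        # rank of the low-rank update matrices
    lora_alpha=32,              # scaling factor for the updates
    target_modules=["q", "v"],  # T5 attention projections
    lora_dropout=0.05,
    task_type="SEQ_2_SEQ_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters are trainable
```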

In the script, you can also change max_input_length and max_output_length, along with other optimization hyperparameters (e.g., epochs, learning_rate, per_device_train_batch_size). You can choose the models used to encode text and image neighbors via text_model and visual_model. All configurable arguments are defined in the Argument class in language_modelling/run_generation.py.
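For orientation, such an Argument class typically looks like the hypothetical sketch below; the field names mirror the script variables above, but the actual definitions are in language_modelling/run_generation.py.

```python
# Hypothetical sketch of an Argument dataclass; see
# language_modelling/run_generation.py for the real definitions.
from dataclasses import dataclass

@dataclass
class Argument:
    model_name: str = "t5-base"
    task: str = "section"
    context: str = "section_all"
    neighbor_mode: str = "embedding"
    peft_type: str = "none"
    text_model: str = "roberta-base"   # encoder for text neighbors
    visual_model: str = "clip-vit"     # encoder for image neighbors
    max_input_length: int = 512
    max_output_length: int = 128
    epochs: int = 10
    learning_rate: float = 1e-4
    per_device_train_batch_size: int = 8
```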

File description

We provide a brief description of each directory and file below:

| Directory/File | description |
| --- | --- |
| wikiweb2m/ | code related to the WikiWeb2M dataset |
| wikiweb2m/cider | computes CIDEr scores |
| wikiweb2m/data.py | prepares each training point based on context and neighbor_mode |
| wikiweb2m/preprocess_data.py | preprocesses the WikiWeb2M dataset and downloads images |
| script/ | scripts to run MMGL |
| script/train_generation.sh | sets hyperparameters |
| language_modelling/ | main directory |
| language_modelling/run_generation.py | prepares models, reads datasets, runs train/validation loops |
| language_modelling/utils.py | utility functions |
| model/ | language models |
| model/modelling_self_attention.py | LMs with self-attention only, including encoder-decoder and decoder-only models |
| model/modelling_cross_attention.py | LMs with cross-attention to encode neighbor information; decoder-only models |

Citation

If you find this work or our code useful, please consider citing:

@article{yoon2023multimodal,
  title={Multimodal Graph Learning for Generative Tasks},
  author={Yoon, Minji and Koh, Jing Yu and Hooi, Bryan and Salakhutdinov, Ruslan},
  journal={arXiv preprint arXiv:2310.07478},
  year={2023}
}