TRAnsformer Routing Networks (TRAR)

This is the official implementation of the ICCV 2021 paper "TRAR: Routing the Attention Spans in Transformers for Visual Question Answering". It currently includes the code for training TRAR on the VQA2.0 and CLEVR datasets. Our TRAR model for the REC task is coming soon.

Updates

Introduction

TRAR vs Standard Transformer

<p align="center"> <img src="misc/trar_block.png" width="550"> </p>

TRAR Overall

<p align="center"> <img src="misc/trar_overall.png" width="550"> </p>

Table of Contents

  1. Installation
  2. Dataset setup
  3. Config Introduction
  4. Training
  5. Validation and Testing
  6. Models

Installation

git clone https://github.com/rentainhe/TRAR-VQA.git
cd TRAR-VQA
conda create -n trar python=3.7 -y
conda activate trar
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
pip install -r requirements.txt
wget https://github.com/explosion/spacy-models/releases/download/en_vectors_web_lg-2.1.0/en_vectors_web_lg-2.1.0.tar.gz -O en_vectors_web_lg-2.1.0.tar.gz
pip install en_vectors_web_lg-2.1.0.tar.gz

Dataset setup

See DATA.md for dataset preparation.

Config Introduction

The trar.yml config contains the following TRAR-specific settings:

ORDERS: [0, 1, 2, 3]
IMG_SCALE: 8 
ROUTING: 'hard' # {'soft', 'hard'}
POOLING: 'attention' # {'attention', 'avg', 'fc'}
TAU_POLICY: 1 # {0: 'SLOW', 1: 'FAST', 2: 'FINETUNE'}
TAU_MAX: 10
TAU_MIN: 0.1
BINARIZE: False
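ORDERS lists the attention spans the router can choose from, and IMG_SCALE is the side length of the image feature grid (8 x 8 here). The NumPy sketch below shows one way such span masks could be built, assuming that order 0 means global attention and that order k > 0 lets each grid cell attend only to the (2k+1) x (2k+1) window around it. The function name span_mask is hypothetical and is not taken from the repository.

```python
import numpy as np


def span_mask(scale: int, order: int) -> np.ndarray:
    """Boolean mask of shape (scale*scale, scale*scale).

    True marks key positions a query cell may NOT attend to.
    Assumption: order 0 means global attention (nothing masked), and
    order k > 0 keeps only the (2k+1) x (2k+1) window around each cell.
    """
    n = scale * scale
    if order == 0:
        return np.zeros((n, n), dtype=bool)
    rows, cols = np.divmod(np.arange(n), scale)
    # Chebyshev distance between every pair of grid cells.
    dist = np.maximum(np.abs(rows[:, None] - rows[None, :]),
                      np.abs(cols[:, None] - cols[None, :]))
    return dist > order


# One mask per routing path, using the ORDERS and IMG_SCALE values from trar.yml.
ORDERS = [0, 1, 2, 3]
IMG_SCALE = 8
masks = [span_mask(IMG_SCALE, k) for k in ORDERS]
for k, m in zip(ORDERS, masks):
    # Visible keys for the top-left cell (index 0) and a central cell (row 4, col 4).
    print(k, (~m[0]).sum(), (~m[IMG_SCALE * 4 + 4]).sum())
```

Larger orders widen the receptive field of each cell, so routing over ORDERS amounts to choosing between local and global attention.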

Please set BINARIZE=False when ROUTING='soft': there is no need to binarize the path probabilities in the soft routing block.
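The two routing modes differ in how they combine the span paths. The NumPy sketch below contrasts them under stated assumptions: soft routing mixes the path outputs with softmax weights, hard routing is modeled as Gumbel-softmax sampling with temperature tau (the quantity that TAU_MAX and TAU_MIN appear to bound), and BINARIZE turns the sampled probabilities into a one-hot path choice. The route function and its arguments are hypothetical and do not mirror the repository's API.

```python
import numpy as np


def route(path_logits, path_outputs, routing="hard", tau=1.0, binarize=False, rng=None):
    """Mix the outputs of the span paths (one row per entry in ORDERS).

    path_logits:  (P,) router scores for the P paths (hypothetical input).
    path_outputs: (P, D) features produced by each path.
    Returns the mixed (D,) feature and the (P,) path weights.
    """
    if routing == "soft":
        # Soft routing: deterministic softmax weights, never binarized.
        z = path_logits - path_logits.max()
    elif routing == "hard":
        # Hard routing, assumed here to be Gumbel-softmax with temperature tau.
        if rng is None:
            rng = np.random.default_rng()
        u = rng.uniform(1e-12, 1.0, size=path_logits.shape)
        z = (path_logits - np.log(-np.log(u))) / tau  # add Gumbel noise, then scale
        z = z - z.max()
    else:
        raise ValueError(f"unknown routing: {routing}")
    weights = np.exp(z) / np.exp(z).sum()
    if routing == "hard" and binarize:
        # BINARIZE: keep only the most probable path as a one-hot choice.
        weights = np.eye(len(weights))[np.argmax(weights)]
    return weights @ path_outputs, weights


logits = np.array([2.0, 0.5, 0.1, -1.0])
outputs = np.arange(8, dtype=float).reshape(4, 2)
mixed_soft, w_soft = route(logits, outputs, routing="soft")
mixed_hard, w_hard = route(logits, outputs, routing="hard", tau=0.1,
                           binarize=True, rng=np.random.default_rng(0))
print(w_soft.round(3), w_hard)
```

In the soft branch the weights are already a smooth mixture, which is why binarizing them there would only discard information.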

TAU_POLICY visualization

With MAX_EPOCH=13 and WARMUP_EPOCH=3, the three TAU policies produce the following temperature schedules:

<p align="center"> <img src="misc/policy_visualization.png" width="550"> </p>
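The exact SLOW, FAST, and FINETUNE curves are defined in the repository and plotted above. As a hedged illustration of how TAU_MAX, TAU_MIN, MAX_EPOCH, and WARMUP_EPOCH can interact, the sketch below implements one generic schedule: hold tau at TAU_MAX during warmup, then decay it geometrically so it reaches TAU_MIN at the final epoch. It is not claimed to match any of the three policies, and anneal_tau is a hypothetical name. A lower tau makes Gumbel-softmax samples closer to one-hot.

```python
def anneal_tau(epoch, max_epoch=13, warmup_epoch=3, tau_max=10.0, tau_min=0.1):
    """Hypothetical schedule (epochs counted from 0): keep tau_max during
    warmup, then decay geometrically so tau equals tau_min at the last epoch."""
    if epoch < warmup_epoch:
        return tau_max
    span = max(max_epoch - 1 - warmup_epoch, 1)
    progress = min((epoch - warmup_epoch) / span, 1.0)
    return tau_max * (tau_min / tau_max) ** progress


for e in range(13):
    print(e, round(anneal_tau(e), 3))
```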

Training

To train the model on VQA-v2 with the default hyperparameters, run:

python3 run.py --RUN='train' --DATASET='vqa' --MODEL='trar'

The training log will be saved to:

results/log/log_run_<VERSION>.txt

Args:

Resume Training

To resume training from specific saved model weights, set --RESUME=True and pass the checkpoint version (--CKPT_V) and epoch (--CKPT_E):

python3 run.py --RUN='train' --DATASET='vqa' --MODEL='trar' --RESUME=True --CKPT_V=str --CKPT_E=int

Multi-GPU Training and Gradient Accumulation

  1. Multi-GPU Training: append --GPU='0,1,2,3...' to the training command.
python3 run.py --RUN='train' --DATASET='vqa' --MODEL='trar' --GPU='0,1,2,3'

The batch is split automatically, so each GPU processes BATCH_SIZE/GPUs samples.

  2. Gradient Accumulation: append --ACCU=n to the training command.
python3 run.py --RUN='train' --DATASET='vqa' --MODEL='trar' --ACCU=2

The optimizer then accumulates gradients over n mini-batches before each weight update. BATCH_SIZE must be divisible by n.
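To see why BATCH_SIZE must be divisible by n, the pure-Python sketch below splits one batch into n equal mini-batches, as --ACCU=n does, and averages their gradients on a toy one-parameter least-squares model. With equal-sized mini-batches, the averaged gradient equals the full-batch gradient, so one accumulated update matches one full-batch update. The toy model and function names are hypothetical, and whether the repository rescales the loss exactly this way is an assumption.

```python
def grad_mse(w, xs, ys):
    """Gradient of the mean squared error for the 1-D model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accu):
    """Split one batch into `accu` mini-batches and average their gradients,
    mimicking --ACCU=n (the batch size must be divisible by n)."""
    batch_size = len(xs)
    if batch_size % accu != 0:
        raise ValueError("BATCH_SIZE must be divisible by ACCU")
    step = batch_size // accu
    total = 0.0
    for i in range(0, batch_size, step):
        total += grad_mse(w, xs[i:i + step], ys[i:i + step])
    # Divide by accu so the accumulated step matches the full-batch gradient.
    return total / accu


xs = [float(i) for i in range(8)]  # BATCH_SIZE = 8 in this toy example
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5
full = grad_mse(w, xs, ys)
accumulated = accumulated_grad(w, xs, ys, accu=2)
print(full, accumulated)
```

Trading one large batch for n smaller ones lowers peak memory while keeping the effective batch size at BATCH_SIZE.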

Validation and Testing

Warning: The arguments --MODEL and --DATASET must be set to the same values as those used in the training stage.

Validate on Local Machine

Offline evaluation currently supports only the coco_2014_val dataset.

  1. Use a saved checkpoint:
python3 run.py --RUN='val' --MODEL='trar' --DATASET='{vqa, clevr}' --CKPT_V=str --CKPT_E=int
  2. Use an absolute checkpoint path:
python3 run.py --RUN='val' --MODEL='trar' --DATASET='{vqa, clevr}' --CKPT_PATH=str

Online Testing

To evaluate on the test splits of the VQA-v2 and CLEVR benchmarks, run:

python3 run.py --RUN='test' --MODEL='trar' --DATASET='{vqa, clevr}' --CKPT_V=str --CKPT_E=int

The result file is saved at:

results/result_test/result_run_<CKPT_V>_<CKPT_E>.json

You can then upload the resulting JSON file to EvalAI to obtain the scores.
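Before uploading, it can help to sanity-check the file. The sketch below rebuilds the result path from CKPT_V and CKPT_E and validates a loaded result list, assuming the VQA-v2 entries follow the standard VQA results format (a JSON list of objects with an integer question_id and a string answer). The CLEVR output format is not covered, and result_path and validate_vqa_results are hypothetical helpers.

```python
import json
from pathlib import Path


def result_path(ckpt_v: str, ckpt_e: int, root: str = "results/result_test") -> Path:
    """Path of the test-result file written for checkpoint (ckpt_v, ckpt_e)."""
    return Path(root) / f"result_run_{ckpt_v}_{ckpt_e}.json"


def validate_vqa_results(results: list) -> int:
    """Check that each entry has an int question_id and a str answer,
    with no duplicate question_id; return the number of entries."""
    seen = set()
    for entry in results:
        qid, answer = entry.get("question_id"), entry.get("answer")
        if not isinstance(qid, int) or not isinstance(answer, str):
            raise ValueError(f"malformed entry: {entry}")
        if qid in seen:
            raise ValueError(f"duplicate question_id: {qid}")
        seen.add(qid)
    return len(results)


path = result_path("your_version", 13)  # placeholder CKPT_V / CKPT_E values
if path.exists():
    with open(path) as f:
        print(validate_vqa_results(json.load(f)), "answers in", path)
```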

Models

Our pretrained models and training logs are listed in MODEL.md.

Acknowledgements

Citation

If TRAR is helpful for your research or you wish to refer to the baseline results published here, we'd really appreciate it if you could cite this paper:

@InProceedings{Zhou_2021_ICCV,
    author    = {Zhou, Yiyi and Ren, Tianhe and Zhu, Chaoyang and Sun, Xiaoshuai and Liu, Jianzhuang and Ding, Xinghao and Xu, Mingliang and Ji, Rongrong},
    title     = {TRAR: Routing the Attention Spans in Transformer for Visual Question Answering},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {2074-2084}
}