Can Transformers Capture Spatial Relations between Objects?

Chuan Wen, Dinesh Jayaraman, Yang Gao <br/> International Conference on Learning Representations (ICLR) 2024

This is the official codebase for the ICLR 2024 paper Can Transformers Capture Spatial Relations between Objects?. The code is based on the Rel3D repository (many thanks to the authors).

Getting Started

First, clone our repository:

git clone git@github.com:AlvinWen428/spatial-relation-benchmark.git
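
Then enter the repository root; the rest of this README assumes you are inside it:

cd spatial-relation-benchmark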

Installation

Use your own environment

This codebase is tested on Ubuntu 20.04 and CUDA 11.3.

We recommend Miniforge instead of Anaconda for faster installation:

mamba create -n srp python=3.8
mamba activate srp
pip install -r requirements.txt
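
As a quick sanity check (assuming requirements.txt installs a CUDA-enabled PyTorch build), you can verify that the GPU is visible:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"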

Use with Docker

First, build the Docker image and tag it as spatial-relation:latest.

docker build . -t spatial-relation:latest

To start the container, you can use the following command:

docker run --gpus all --shm-size=8g -it \
-v $(pwd):/spatial-relation-benchmark \
spatial-relation:latest /bin/bash
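
To confirm the container can see your GPUs before launching experiments, a minimal check (assuming the NVIDIA Container Toolkit is set up on the host and the image exposes the driver utilities) is:

docker run --gpus all --rm spatial-relation:latest nvidia-smi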

Download the SpatialSense+ and Rel3D datasets, and the IBOT-pretrained ViT backbone

Make sure you are in the spatial-relation-benchmark directory. download.py can then be used to download SpatialSense+, Rel3D, and the IBOT-pretrained backbone:

python download.py --data-key spatialsense+
python download.py --data-key rel3d
python download.py --data-key ibot
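
If you want all three in one go, the same flag can be looped over (this is just a shell convenience around the commands above):

for key in spatialsense+ rel3d ibot; do
    python download.py --data-key "$key"
done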

Experiments

Training

main.py is the entry point for all experiments. The configs for training and evaluating all models on SpatialSense+ and Rel3D can be found in configs/. Training can be launched with:

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port 10000 --nproc_per_node 4 main.py --exp-config ${CONFIG_PATH} [Arg1 Arg2 ...]
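
For example, to train RelatiViT on SpatialSense+ on 4 GPUs (the config filename below is illustrative, not the actual name; check configs/ for the exact files):

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port 10000 --nproc_per_node 4 main.py --exp-config configs/spatialsense_relativit.yaml  # hypothetical config path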

We provide example scripts in scripts/ to reproduce our results for RegionViT, CNNTransformer, CrossAttnViT, and RelatiViT. We trained each model 5 times with different random seeds and report the average accuracy.

# Rel3D
python scripts/rel3d_regionvit.py
python scripts/rel3d_cnntransformer.py
python scripts/rel3d_crossattnvit.py
python scripts/rel3d_relativit.py

# SpatialSense+
python scripts/spatialsense_regionvit.py
python scripts/spatialsense_cnntransformer.py
python scripts/spatialsense_crossattnvit.py
python scripts/spatialsense_relativit.py
<img src="doc/architectures.png"/>

Testing

If you have trained the models with the provided scripts in scripts/, the checkpoints and TensorBoard files can be found in results/. Then, taking RelatiViT on SpatialSense+ as an example, the average test results over the different random seeds can be computed by:

python main.py --entry batch-test --model-path results/spatialsenseplus_RelatiViT_seed*
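
The same entry point should work for the other benchmark, e.g. RelatiViT on Rel3D (the results directory prefix below is an assumption; match the glob to whatever the training scripts actually produced under results/):

python main.py --entry batch-test --model-path results/rel3d_RelatiViT_seed*  # adjust the glob to your run directories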

Citation

If you find our project useful, please consider citing our paper:

@inproceedings{
  wen2024can,
  title={Can Transformers Capture Spatial Relations between Objects?},
  author={Chuan Wen and Dinesh Jayaraman and Yang Gao},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024},
  url={https://openreview.net/forum?id=HgZUcwFhjr}
}