<p align="center"> <img src="demo_images/logo.png" width="15%"/> </p>
SpatialRGPT: Grounded Spatial Reasoning in Vision-Language Models (NeurIPS'24)
💡 Introduction
SpatialRGPT: Grounded Spatial Reasoning in Vision-Language Models <br> An-Chieh Cheng, Hongxu (Danny) Yin, Yang Fu, Qiushan Guo, Ruihan Yang, Jan Kautz, Xiaolong Wang, Sifei Liu <br>
SpatialRGPT is a vision-language model that understands both 2D and 3D spatial arrangements. It accepts any region proposal, such as boxes or masks, and answers complex spatial reasoning questions about those regions.
📢 News
- Oct-07-24: SpatialRGPT code, dataset, and benchmark released! 🔥
- Sep-25-24: We're thrilled to share that SpatialRGPT has been accepted to NeurIPS 2024! 🎊
Installation
To build the environment for training SpatialRGPT, run the following:
./environment_setup.sh srgpt
conda activate srgpt
Gradio Demo
To run the Gradio demo for SpatialRGPT, follow these steps. Due to pydantic version conflicts, the demo environment is not compatible with the training environment, so you will need to create a separate environment for the Gradio demo.
- Build the environment.
./environment_setup.sh srgpt-demo
conda activate srgpt-demo
pip install gradio==4.27 deepspeed==0.13.0 gradio_box_promptable_image segment_anything_hq
pip install -U 'git+https://github.com/facebookresearch/detectron2.git@ff53992b1985b63bd3262b5a36167098e3dada02'
If you run into an error with the detectron2 installation, it may be because `CUDA_HOME` is not set. To fix this, export `CUDA_HOME` to your local CUDA path. See details in this issue.
- Clone the Depth-Anything repository and download the necessary checkpoint:
git clone https://github.com/LiheYoung/Depth-Anything.git
wget https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth
Place `depth_anything_vitl14.pth` under `Depth-Anything/checkpoints`, and set the `DEPTH_ANYTHING_PATH` environment variable to the repository path. For example:
export DEPTH_ANYTHING_PATH=/YOUR_OWN_PATH/Depth-Anything
- Download the SAM-HQ checkpoint from here, and set the `SAM_CKPT_PATH` environment variable to its path. For example:
export SAM_CKPT_PATH=/YOUR_OWN_PATH/sam_hq_vit_h.pth
- Launch the Gradio server. You can use your own checkpoint or `a8cheng/SpatialRGPT-VILA1.5-8B`:
cd demo
python gradio_web_server_multi.py --model-path PATH_TO_CHECKPOINT
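Before launching, it can help to confirm that the two environment variables point where the steps above expect. The sketch below is a hypothetical helper (`check_demo_paths` is not part of the repository); it only assumes the layout described in those steps: the Depth-Anything checkpoint at `$DEPTH_ANYTHING_PATH/checkpoints/depth_anything_vitl14.pth` and the SAM-HQ checkpoint file at `$SAM_CKPT_PATH`.

```python
import os
from pathlib import Path


def check_demo_paths(env=None):
    """Return a list of problems with the demo's checkpoint paths (empty if all good).

    Hypothetical helper: it only checks the file layout described in this README.
    """
    env = os.environ if env is None else env
    problems = []

    # DEPTH_ANYTHING_PATH should be the Depth-Anything repo root containing checkpoints/
    depth_root = env.get("DEPTH_ANYTHING_PATH")
    if not depth_root:
        problems.append("DEPTH_ANYTHING_PATH is not set")
    else:
        ckpt = Path(depth_root) / "checkpoints" / "depth_anything_vitl14.pth"
        if not ckpt.is_file():
            problems.append(f"missing Depth-Anything checkpoint: {ckpt}")

    # SAM_CKPT_PATH should point directly at the SAM-HQ checkpoint file
    sam_ckpt = env.get("SAM_CKPT_PATH")
    if not sam_ckpt:
        problems.append("SAM_CKPT_PATH is not set")
    elif not Path(sam_ckpt).is_file():
        problems.append(f"missing SAM-HQ checkpoint: {sam_ckpt}")

    return problems


if __name__ == "__main__":
    for problem in check_demo_paths():
        print("WARNING:", problem)
```

Running it in the `srgpt-demo` environment prints nothing when both checkpoints are in place.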
Training
SpatialRGPT follows the VILA training recipe, which consists of three stages. We provide training scripts for three different LLMs: `sheared_3b`, `llama2_7b`, and `llama3_8b`. The training scripts for each stage are in the `scripts/srgpt` folder.
Open Spatial Dataset
Download the Open Spatial Dataset from Hugging Face, and update the dataset path in `llava/data/dataset_mixture.py` accordingly.
For the raw images, download OpenImages from OpenImagesV7. To convert the RGB images into depth maps, we use DepthAnythingV2 and save the depth with the following function:
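<details> <summary>Click to expand</summary>
import numpy as np
import torch.nn.functional as F
from PIL import Image

def save_raw_16bit(depth, fpath, height, width):
    # depth: 2D float torch tensor; resize it to the original image size (height, width)
    depth = F.interpolate(depth[None, None], (height, width), mode='bilinear', align_corners=False)[0, 0]
    # Min-max normalize to [0, 255]; despite the function name, the saved image is 8-bit
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.cpu().numpy().astype(np.uint8)
    # Replicate the single channel into a 3-channel image and save it
    colorized_depth = np.stack([depth, depth, depth], axis=-1)
    depth_image = Image.fromarray(colorized_depth)
    depth_image.save(fpath)
</details>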
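Because `save_raw_16bit` min-max normalizes each depth map before casting to `uint8`, the saved images store only relative depth at 256 levels: the per-image minimum maps to 0, the maximum maps to 255, and the absolute scale is discarded unless the min and max are recorded separately. The pure-Python sketch below mirrors that arithmetic on a flat list to make this concrete. `encode_depth` and `decode_depth` are illustrative names, not functions from the repository; `decode_depth` assumes you kept the original min/max yourself, and unlike the original, `encode_depth` returns zeros for a constant map instead of dividing by zero.

```python
def encode_depth(depths):
    """Mirror save_raw_16bit's arithmetic: min-max normalize, scale to 255, truncate to int.

    Illustrative only; operates on a flat list of floats instead of a torch tensor.
    """
    dmin, dmax = min(depths), max(depths)
    if dmax == dmin:
        # The original would divide by zero here; this sketch returns zeros instead.
        return [0] * len(depths)
    # int() truncates toward zero, like astype(np.uint8) for values in [0, 255]
    return [int((d - dmin) / (dmax - dmin) * 255.0) for d in depths]


def decode_depth(codes, dmin, dmax):
    """Approximately invert encode_depth; needs the min/max, which are NOT saved."""
    return [dmin + c / 255.0 * (dmax - dmin) for c in codes]


print(encode_depth([2.0, 3.0, 4.0]))  # [0, 127, 255]
# Scale is lost: maps that differ only by an affine transform encode identically.
print(encode_depth([10.0, 20.0]) == encode_depth([1.0, 2.0]))  # True
```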
Dataset Synthesis Pipeline
We've also made the dataset synthesis pipeline available; the code and instructions are in the `dataset_pipeline` folder. Note that some of the packages we use have since been updated and we've migrated to their latest versions, which may introduce bugs. Please report any issues or unexpected results you encounter.
<p align="center"> <img src="dataset_pipeline/asssets/wis3d-demo.gif" alt="Wis3D Demo"> </p>
Evaluations
Our evaluation scripts take three arguments: `PATH_TO_CKPT`, `CKPT_NAME`, and `CONV_TYPE`.
- `PATH_TO_CKPT` is the location of the checkpoint you want to evaluate.
- `CKPT_NAME` specifies the folder that will be created in the `eval_out` directory, where the evaluation results will be stored.
- Make sure that `CONV_TYPE` matches the conversation type used in the checkpoint. For `llama3_8b`, use `llama_3`.
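The argument conventions above can be captured in a small launcher that fails early on a mismatched conversation type. This is a hypothetical sketch, not part of the repository: `build_eval_command` and `CONV_TYPE_HINTS` are illustrative names, invoking the scripts through `bash` is an assumption, the only model-to-conversation-type mapping it encodes is the one stated in this README (`llama3_8b` uses `llama_3`), and `eval_out` is assumed to be relative to the directory you launch from.

```python
from pathlib import Path

# Only mapping documented in this README; other models are not checked (assumption).
CONV_TYPE_HINTS = {"llama3_8b": "llama_3"}


def build_eval_command(script, ckpt_path, ckpt_name, conv_type, llm=None):
    """Return (argv, expected_output_dir) for one of the scripts/srgpt/eval/*.sh scripts.

    Hypothetical helper. If `llm` has a known conversation type, raise ValueError
    when `conv_type` does not match it.
    """
    expected = CONV_TYPE_HINTS.get(llm)
    if expected is not None and conv_type != expected:
        raise ValueError(f"{llm} expects CONV_TYPE={expected!r}, got {conv_type!r}")
    # Positional order matches the README: PATH_TO_CKPT CKPT_NAME CONV_TYPE
    argv = ["bash", script, str(ckpt_path), ckpt_name, conv_type]
    # CKPT_NAME names the results folder created under eval_out
    return argv, Path("eval_out") / ckpt_name
```

For example, `build_eval_command("scripts/srgpt/eval/coco_cls.sh", "/ckpts/srgpt-8b", "srgpt_8b", "llama_3", llm="llama3_8b")` returns an argument list you could pass to `subprocess.run(argv, check=True)`, with results expected under `eval_out/srgpt_8b`.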
Region Classification
First, prepare the evaluation annotations following RegionCLIP.
Then run:
scripts/srgpt/eval/coco_cls.sh PATH_TO_CKPT CKPT_NAME CONV_TYPE
SpatialRGPT-Bench Evaluation
First, download the images from Omni3D, following their instructions. Then download the annotations from https://huggingface.co/datasets/a8cheng/SpatialRGPT-Bench. Update the paths in `scripts/srgpt/eval/srgpt_bench.sh` to point to these locations.
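If you prefer to fetch the annotations programmatically, `snapshot_download` from the `huggingface_hub` library can download a whole dataset repository. This is a minimal sketch assuming `huggingface_hub` is installed and that you want the full `a8cheng/SpatialRGPT-Bench` repository; `bench_download_kwargs` and `download_bench` are illustrative names, and `local_dir` is whatever path you then put into `scripts/srgpt/eval/srgpt_bench.sh`. The Omni3D images still have to be obtained separately.

```python
from pathlib import Path

BENCH_REPO_ID = "a8cheng/SpatialRGPT-Bench"


def bench_download_kwargs(local_dir):
    """Arguments for huggingface_hub.snapshot_download (a dataset repo, not a model repo)."""
    return {
        "repo_id": BENCH_REPO_ID,
        "repo_type": "dataset",
        "local_dir": str(Path(local_dir)),
    }


def download_bench(local_dir):
    """Download the SpatialRGPT-Bench annotations and return the local folder path."""
    # Imported lazily so this module can be loaded without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(**bench_download_kwargs(local_dir))
```

Usage: `download_bench("/YOUR_OWN_PATH/SpatialRGPT-Bench")`.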
Note that for SpatialRGPT-Bench, you need to clone the Depth-Anything repository and download the necessary checkpoint:
git clone https://github.com/LiheYoung/Depth-Anything.git
wget https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth
Place `depth_anything_vitl14.pth` under `Depth-Anything/checkpoints`, and set the `DEPTH_ANYTHING_PATH` environment variable to the repository path:
export DEPTH_ANYTHING_PATH="PATH_TO_DEPTHANYTHING"
Then run:
scripts/srgpt/eval/srgpt_bench.sh PATH_TO_CKPT CKPT_NAME CONV_TYPE
General VLM Benchmarks
Our code is compatible with VILA's evaluation scripts. See VILA/evaluations for details.
📜 Citation
@inproceedings{cheng2024spatialrgpt,
title={SpatialRGPT: Grounded Spatial Reasoning in Vision-Language Models},
author={Cheng, An-Chieh and Yin, Hongxu and Fu, Yang and Guo, Qiushan and Yang, Ruihan and Kautz, Jan and Wang, Xiaolong and Liu, Sifei},
booktitle={NeurIPS},
year={2024}
}
🙏 Acknowledgement
We have used code snippets from different repositories, especially from: VILA, Omni3D, GLaMM, VQASynth, and ConceptGraphs. We would like to acknowledge and thank the authors of these repositories for their excellent work.