# TriCoLo
<a href="https://pytorch.org/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-EE4C2C?style=for-the-badge&logo=pytorch&logoColor=white"></a> <a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/Lightning-792DE4?style=for-the-badge&logo=pytorch-lightning&logoColor=white"></a> <a href="https://wandb.ai/site"><img alt="WandB" src="https://img.shields.io/badge/Weights_&_Biases-FFBE00?style=for-the-badge&logo=WeightsAndBiases&logoColor=white"></a>
This repository is the official implementation of **TriCoLo: Trimodal Contrastive Loss for Text to Shape Retrieval**.
(Paper) (Project Page)
## Setup

### Conda (recommended)

We recommend using miniconda to manage system dependencies.
```bash
# create and activate the conda environment
conda create -n tricolo python=3.10
conda activate tricolo

# install PyTorch 2.0.1
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

# install Python libraries
pip install .
```
### Pip (without conda)
```bash
# create and activate the virtual environment
virtualenv --no-download env
source env/bin/activate

# install PyTorch 2.0.1
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2

# install Python libraries
pip install .
```
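After either install route, a quick sanity check (not part of the original instructions, just a convenience) confirms that the expected PyTorch version is installed and CUDA is visible:

```bash
# optional: verify the PyTorch version and CUDA availability
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```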
## Data Preparation

### ShapeNet

Download ShapeNet, and place `ShapeNetCore.v2` in the `data/text2shape-data` folder.
### Text2Shape (Chair & Table)

- Download Text2Shape and place `shapenet.json` and `processed_captions_{train/val/test}.p` in the `text2shape-data/chair_table` folder.
- Download ShapeNet solid voxels (Chair & Table):

  ```bash
  cd text2shape-data
  mkdir chair_table
  cd chair_table
  wget http://text2shape.stanford.edu/dataset/shapenet/nrrd_256_filter_div_32_solid.zip
  wget http://text2shape.stanford.edu/dataset/shapenet/nrrd_256_filter_div_64_solid.zip
  wget http://text2shape.stanford.edu/dataset/shapenet/nrrd_256_filter_div_128_solid.zip
  unzip nrrd_256_filter_div_32_solid.zip
  unzip nrrd_256_filter_div_64_solid.zip
  unzip nrrd_256_filter_div_128_solid.zip
  ```
Finally, the dataset files should be organized as follows:
```
tricolo
├── data
│   ├── preprocess_all_data.py
│   ├── text2shape-data
│   │   ├── ShapeNetCore.v2
│   │   ├── chair_table
│   │   │   ├── nrrd_256_filter_div_32_solid
│   │   │   ├── nrrd_256_filter_div_64_solid
│   │   │   ├── nrrd_256_filter_div_128_solid
│   │   │   ├── processed_captions_train.p
│   │   │   ├── processed_captions_val.p
│   │   │   ├── processed_captions_test.p
│   │   │   ├── shapenet.json
```
- Preprocess the dataset (a concrete example follows this list):

  ```bash
  python data/preprocess_all_data.py data=text2shape_chair_table +cpu_workers={num_processes}
  ```
- Precache the CLIP embeddings (optional):

  ```bash
  python extract_clip_feats.py data=text2shape_chair_table data.image_size=224
  ```
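As referenced above, a concrete preprocessing invocation looks like the following; the worker count of 8 is only an example, so pick whatever suits your machine:

```bash
# preprocess the Chair & Table data with 8 worker processes (8 is an arbitrary example)
python data/preprocess_all_data.py data=text2shape_chair_table +cpu_workers=8
```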
### Text2Shape (C13)
- Download Text2Shape C13.
## Training, Inference and Evaluation

Note: configuration files are managed by Hydra; you can add or override any configuration attribute by passing it as a command-line argument.
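For example, Hydra's standard override syntax applies: existing attributes are overridden with `key=value`, and attributes missing from the config are added with a leading `+`. The specific keys below are borrowed from the commands elsewhere in this README and may need adjusting for your config:

```bash
# override an existing attribute (data.image_size also appears in the CLIP precaching step)
python train.py data=text2shape_chair_table data.image_size=224

# add an attribute that is not already in the config by prefixing it with "+"
python test.py data=text2shape_chair_table +ckpt_path=outputs/last.ckpt
```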
```bash
# log in to WandB
wandb login

# train a model from scratch
# available voxel_encoder_name: SparseCNNEncoder, null
# available image_encoder_name: MVCNNEncoder, CLIPImageEncoder, null
# available text_encoder_name: BiGRUEncoder, CLIPTextEncoder
# available dataset_name: text2shape_chair_table, text2shape_c13
python train.py data={dataset_name} model.voxel_encoder={voxel_encoder_name} \
  model.image_encoder={image_encoder_name} model.text_encoder={text_encoder_name} \
  experiment_name={any_string}

# train a model from a checkpoint
python train.py data={dataset_name} model.voxel_encoder={voxel_encoder_name} \
  model.image_encoder={image_encoder_name} model.text_encoder={text_encoder_name} \
  experiment_name={checkpoint_experiment_name} ckpt_name={checkpoint_file_name}

# test a pretrained model
python test.py data={dataset_name} model.voxel_encoder={voxel_encoder_name} \
  model.image_encoder={image_encoder_name} model.text_encoder={text_encoder_name} \
  experiment_name={checkpoint_experiment_name} +ckpt_path={checkpoint_file_path}

# evaluate inference results
# currently unavailable
```
## Checkpoints
| Modality | Dataset | Split | RR@1 | RR@5 | NDCG@5 | Download |
|---|---|---|---|---|---|---|
| Tri(I+V) | Text2Shape (Chair & Table) | Val | 12.60 | 33.34 | 23.30 | chair_table_tri.ckpt |
| Bi(I) | Text2Shape (Chair & Table) | Val | 11.67 | 30.63 | 21.49 | chair_table_bi_i.ckpt |
| Bi(V) | Text2Shape (Chair & Table) | Val | 9.33 | 27.52 | 18.62 | chair_table_bi_v.ckpt |
| Tri(I+V) | Text2Shape (C13) | Val | 12.96 | 34.87 | 24.19 | c13_tri.ckpt |
| Bi(I) | Text2Shape (C13) | Val | 11.89 | 33.48 | 22.96 | c13_bi_i.ckpt |
| Bi(V) | Text2Shape (C13) | Val | 9.73 | 29.24 | 19.69 | c13_bi_v.ckpt |
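A downloaded checkpoint can be evaluated with `test.py` via `+ckpt_path`, as described above. The sketch below assumes the Tri(I+V) Chair & Table model uses the sparse voxel, MVCNN, and BiGRU encoders and that the file was saved to a local `checkpoints/` folder; adjust both to your setup:

```bash
# evaluate a downloaded checkpoint (encoder choices and path are assumptions)
python test.py data=text2shape_chair_table model.voxel_encoder=SparseCNNEncoder \
  model.image_encoder=MVCNNEncoder model.text_encoder=BiGRUEncoder \
  experiment_name=chair_table_tri +ckpt_path=checkpoints/chair_table_tri.ckpt
```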
## Acknowledgements
- ConVIRT: Our overall training framework is heavily based on the ConVIRT implementation. Paper
- MVCNN: Our MVCNN encoder is adapted from this implementation. Paper
- Text2Shape: We use the dataset and modify the evaluation code from the original Text2Shape release. Paper
We thank the authors for their work and the implementations.