TriCoLo

<a href="https://pytorch.org/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-EE4C2C?style=for-the-badge&logo=pytorch&logoColor=white"></a> <a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/Lightning-792DE4?style=for-the-badge&logo=pytorch-lightning&logoColor=white"></a> <a href="https://wandb.ai/site"><img alt="WandB" src="https://img.shields.io/badge/Weights_&_Biases-FFBE00?style=for-the-badge&logo=WeightsAndBiases&logoColor=white"></a>

This repository is the official implementation of TriCoLo: Trimodal Contrastive Loss for Text to Shape Retrieval.

(Paper) (Project Page)

Setup

Conda (recommended)

We recommend using Miniconda to manage system dependencies.

# create and activate the conda environment
conda create -n tricolo python=3.10
conda activate tricolo

# install PyTorch 2.0.1
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

# install Python libraries
pip install .

Pip (without conda)

# create and activate the virtual environment
virtualenv --no-download env
source env/bin/activate

# install PyTorch 2.0.1
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2

# install Python libraries
pip install .
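
With either setup, an optional sanity check can confirm that the expected PyTorch version is installed and that CUDA is visible:

# optional: verify the PyTorch install and CUDA availability
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"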

Data Preparation

ShapeNet

Download ShapeNet, and place ShapeNetCore.v2 in the data/text2shape-data folder.

Text2Shape (Chair & Table)

  1. Download Text2Shape and place shapenet.json and processed_captions_{train/val/test}.p in the data/text2shape-data/chair_table folder.

  2. Download ShapeNet solid voxels (Chair & Table):

    cd data/text2shape-data
    mkdir -p chair_table
    cd chair_table
    wget http://text2shape.stanford.edu/dataset/shapenet/nrrd_256_filter_div_32_solid.zip
    wget http://text2shape.stanford.edu/dataset/shapenet/nrrd_256_filter_div_64_solid.zip
    wget http://text2shape.stanford.edu/dataset/shapenet/nrrd_256_filter_div_128_solid.zip
    unzip nrrd_256_filter_div_32_solid.zip
    unzip nrrd_256_filter_div_64_solid.zip
    unzip nrrd_256_filter_div_128_solid.zip
    

    Finally, the dataset files should be organized as follows (a quick layout check is sketched after this list):

    tricolo
    ├── data
    │   ├── preprocess_all_data.py
    │   ├── text2shape-data
    │   │   ├── ShapeNetCore.v2
    │   │   ├── chair_table
    │   │   │   ├── nrrd_256_filter_div_32_solid
    │   │   │   ├── nrrd_256_filter_div_64_solid
    │   │   │   ├── nrrd_256_filter_div_128_solid
    │   │   │   ├── processed_captions_train.p
    │   │   │   ├── processed_captions_val.p
    │   │   │   ├── processed_captions_test.p
    │   │   │   ├── shapenet.json
    
  3. Preprocess the dataset

    python data/preprocess_all_data.py data=text2shape_chair_table +cpu_workers={num_processes}
    
  4. Precache the CLIP embeddings (optional)

    python extract_clip_feats.py data=text2shape_chair_table data.image_size=224
    
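As an optional layout check (assuming the directory structure shown above and a shell at the repository root), you can list the expected Chair & Table files before preprocessing:

# optional: confirm the raw files are in place before running the preprocessing script
ls data/text2shape-data/chair_table/shapenet.json
ls data/text2shape-data/chair_table/processed_captions_train.p
ls -d data/text2shape-data/chair_table/nrrd_256_filter_div_*_solid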

Text2Shape (C13)

  1. Download Text2Shape C13.

Training, Inference and Evaluation

Note: Configuration files are managed by Hydra; you can add or override any configuration attribute by passing it as a command-line argument.

# log in to WandB
wandb login

# train a model from scratch
# available voxel_encoder_name: SparseCNNEncoder, null
# available image_encoder_name: MVCNNEncoder, CLIPImageEncoder, null
# available text_encoder_name: BiGRUEncoder, CLIPTextEncoder
# available dataset_name: text2shape_chair_table, text2shape_c13
python train.py data={dataset_name} model.voxel_encoder={voxel_encoder_name} \
model.image_encoder={image_encoder_name} model.text_encoder={text_encoder_name} \
experiment_name={any_string}

# train a model from a checkpoint
python train.py data={dataset_name} model.voxel_encoder={voxel_encoder_name} \
model.image_encoder={image_encoder_name} model.text_encoder={text_encoder_name} \
experiment_name={checkpoint_experiment_name} ckpt_name={checkpoint_file_name}

# test a pretrained model
python test.py data={dataset_name} model.voxel_encoder={voxel_encoder_name} \
model.image_encoder={image_encoder_name} model.text_encoder={text_encoder_name} \
experiment_name={checkpoint_experiment_name} +ckpt_path={checkpoint_file_path}

# evaluate inference results
# currently unavailable
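
As a concrete example (using encoder and dataset names from the lists above; the experiment name here is just a placeholder), a trimodal training run on Text2Shape (Chair & Table) could look like:

# example: trimodal (image + voxel + text) training run on Text2Shape (Chair & Table)
python train.py data=text2shape_chair_table model.voxel_encoder=SparseCNNEncoder \
model.image_encoder=MVCNNEncoder model.text_encoder=BiGRUEncoder \
experiment_name=tricolo_trimodal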

Checkpoints

| Modality | Dataset | Split | RR@1 | RR@5 | NDCG@5 | Download |
|----------|---------|-------|------|------|--------|----------|
| Tri(I+V) | Text2Shape (Chair & Table) | Val | 12.60 | 33.34 | 23.30 | chair_table_tri.ckpt |
| Bi(I) | Text2Shape (Chair & Table) | Val | 11.67 | 30.63 | 21.49 | chair_table_bi_i.ckpt |
| Bi(V) | Text2Shape (Chair & Table) | Val | 9.33 | 27.52 | 18.62 | chair_table_bi_v.ckpt |
| Tri(I+V) | Text2Shape (C13) | Val | 12.96 | 34.87 | 24.19 | c13_tri.ckpt |
| Bi(I) | Text2Shape (C13) | Val | 11.89 | 33.48 | 22.96 | c13_bi_i.ckpt |
| Bi(V) | Text2Shape (C13) | Val | 9.73 | 29.24 | 19.69 | c13_bi_v.ckpt |
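
To evaluate one of these checkpoints, pass its downloaded path to test.py via +ckpt_path as shown above. For example (the encoder choices below are illustrative and must match how the checkpoint was trained; the path checkpoints/chair_table_tri.ckpt is simply wherever you saved the download, and the experiment name is arbitrary):

# example: evaluate a downloaded checkpoint
python test.py data=text2shape_chair_table model.voxel_encoder=SparseCNNEncoder \
model.image_encoder=MVCNNEncoder model.text_encoder=BiGRUEncoder \
experiment_name=eval_chair_table_tri +ckpt_path=checkpoints/chair_table_tri.ckpt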

Acknowledgements

  1. ConVIRT: Our overall training framework is heavily based on the ConVIRT implementation. Paper
  2. MVCNN: Our MVCNN image encoder is adapted from this implementation. Paper
  3. Text2Shape: We use the dataset and adapt the evaluation code from the original Text2Shape project. Paper

We thank the authors for their work and the implementations.