Learning Embeddings with Centroid Triplet Loss for Object Identification in Robotic Grasping

This repository contains the code for the paper "Learning Embeddings with Centroid Triplet Loss for Object Identification in Robotic Grasping"

[Arxiv] [Code]

Anas Gouda, Max Schwarz, Christopher Reining, Sven Behnke, Alice Kirchheim

This repository contains the training and evaluation code for the paper. If you only want to use the pre-trained identification model or our whole pipeline in your project, see our DoUnseen Python package.

Installation

TODO

Training

We use a ViT model pre-trained on ImageNet, with the pre-trained weights taken from PyTorch Hub. Our final ViT model was trained on 8 NVIDIA A100 GPUs. To train on a single node with 8 GPUs, run:

cd training
torchrun --standalone --nnodes=1 --nproc-per-node=8 train_ctl_model.py --dataset /PATH/TO/ARMBENCH_OBJECT_IDENT_DATASET --batch-size 128 --num-workers 4 --freeze-first-layer --name augment --lr 0.05 --epochs 200 --augment

To continue training from a checkpoint, pass the argument --continue-from PATH/TO/CHECKPOINT.model. The checkpoint contains the model weights, the training args, and the last epoch number.
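
As a rough illustration of how such a checkpoint can drive a resumed run, here is a minimal sketch. The key names (`model_state_dict`, `args`, `epoch`) and the helper `resume_state` are hypothetical and are not taken from train_ctl_model.py. The sketch also assumes that epochs are counted from zero and that the stored number is the last finished epoch. A plain dict stands in for the object that `torch.load` would return.

```python
# Hypothetical checkpoint layout: the key names are assumptions for
# illustration, not the ones used by train_ctl_model.py.
def resume_state(checkpoint):
    """Return (weights, args, remaining_epochs) from a checkpoint dict."""
    weights = checkpoint["model_state_dict"]
    args = checkpoint["args"]
    # Assumption: "epoch" is the last finished, zero-based epoch,
    # so training resumes at the following one.
    start_epoch = checkpoint["epoch"] + 1
    return weights, args, range(start_epoch, args["epochs"])


# With PyTorch the dict would typically come from
# torch.load(path, map_location="cpu"); a plain dict is used here instead.
checkpoint = {
    "model_state_dict": {"head.weight": [0.1, 0.2]},
    "args": {"lr": 0.05, "epochs": 200, "augment": True},
    "epoch": 119,
}

weights, args, remaining = resume_state(checkpoint)
print(f"resuming at epoch {remaining[0]}, {len(remaining)} epochs left")
```

In a real run the returned weights would be passed to the model's `load_state_dict` before the training loop continues over the remaining epochs.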

Evaluation

Evaluation on the ARMBENCH object identification task

To evaluate our pretrained models on the ARMBENCH object identification task:

  1. Create a new directory named models in the repository root:
mkdir models
  2. Download our pre-trained model (ViT/ResNet) from here into the models directory.
  3. Run the following command to evaluate the model:
python evaluation/evaluate_ctl.py --pretrained_model_type 'ViT-B_16' --pretrained_model_path 'models/vit_b_16_ctl_augment_epoch_199.pth' --armbench_test_batch_size 5 --num_workers 8 --armbench_dataset_path /PATH/TO/ARMBENCH_OBJECT_IDENT_DATASET

Citation

If you find this work useful, please consider citing:

@misc{gouda2024learning,
      title={Learning Embeddings with Centroid Triplet Loss for Object Identification in Robotic Grasping}, 
      author={Anas Gouda and Max Schwarz and Christopher Reining and Sven Behnke and Alice Kirchheim},
      year={2024},
      eprint={2404.06277},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}