

Deep Metric Learning Research in PyTorch

What can I find here?

This repository contains all code and implementations used in:

Revisiting Training Strategies and Generalization Performance in Deep Metric Learning

accepted to ICML 2020.

Link: https://arxiv.org/abs/2002.08473

The code is meant to serve as a research starting point in Deep Metric Learning. By implementing key baselines under a consistent setting and logging a vast set of metrics, it should be easier to ensure that method gains are not due to implementation variations, and to better understand the factors driving performance.

It is set up in a modular way to allow for fast and detailed prototyping, with key elements written so that they can be copied directly into other pipelines. In addition, multiple training and test metrics are logged to W&B to allow for easy, large-scale evaluation.

Finally, please find a public W&B repo with key runs performed in the paper here: https://app.wandb.ai/confusezius/RevisitDML.

Some Notes:

    title={Revisiting Training Strategies and Generalization Performance in Deep Metric Learning},
    author={Karsten Roth and Timo Milbich and Samarth Sinha and Prateek Gupta and Björn Ommer and Joseph Paul Cohen},

Paper-related Information

Reproduce results from our paper Revisiting Training Strategies and Generalization Performance in Deep Metric Learning

Note: There may be small deviations in results depending on the hardware (e.g. P100 vs. RTX GPUs) and software (different PyTorch/CUDA versions) used to run these experiments, but they should fall within the standard deviations reported in the paper.

How to use this Repo


An example setup of a virtual environment containing everything needed:

(1) wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
(2) bash Miniconda3-latest-Linux-x86_64.sh (say yes when asked to append the path to ~/.bashrc)
(3) source ~/.bashrc
(4) conda create -n DL python=3.6
(5) conda activate DL
(6) conda install matplotlib scipy scikit-learn scikit-image tqdm pandas pillow
(7) conda install pytorch torchvision faiss-gpu cudatoolkit=10.0 -c pytorch
(8) pip install wandb pretrainedmodels
(9) Run the scripts!
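Before running anything, a quick way to confirm that steps (6) to (8) succeeded is to look up each package by its import name. The snippet below is a hypothetical helper, not part of this repo; it only checks that the modules can be found, not that faiss-gpu or CUDA actually work on your machine.

```python
import importlib.util

# Import names for the packages installed in steps (6)-(8).
# Several differ from the install names: scikit-learn -> sklearn,
# scikit-image -> skimage, pillow -> PIL, faiss-gpu -> faiss.
REQUIRED_MODULES = [
    "matplotlib", "scipy", "sklearn", "skimage", "tqdm", "pandas", "PIL",
    "torch", "torchvision", "faiss", "wandb", "pretrainedmodels",
]


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be found in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules found.")
```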


Data for

can be downloaded either from the respective project sites or directly via Dropbox:

The latter ensures that the folder structure is already consistent with this pipeline and the dataloaders.

Otherwise, please make sure that the datasets have the following internal structure:

|    └───001.Black_footed_Albatross
|           │   Black_Footed_Albatross_0001_796111
|           │   ...
|    ...
|    └───bicycle_final
|           │   111085122871_0.jpg
|    ...
|    │   bicycle.txt
|    │   ...

Assuming your folder is placed in e.g. <$datapath/cub200>, pass $datapath as input to --source.
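If you download the data manually, one way to sanity-check the layout is to index the class folders yourself. The sketch below is a hypothetical helper, not one of the repo's dataloaders. It assumes one sub-folder per class, as in the tree above, and takes the directory that directly contains those class folders. Filtering by file extension is also an assumption, since the tree above omits the extensions of the CUB200 image files.

```python
import os


def scan_class_folders(image_root, extensions=(".jpg", ".jpeg", ".png")):
    """Map every image below `image_root` to an integer class label.

    Assumes one sub-folder per class (e.g. 001.Black_footed_Albatross),
    each containing that class's image files.
    """
    # Sorted folder names give a deterministic class -> label assignment.
    class_names = sorted(
        entry for entry in os.listdir(image_root)
        if os.path.isdir(os.path.join(image_root, entry))
    )
    samples = []  # list of (image_path, label) pairs
    for label, class_name in enumerate(class_names):
        class_dir = os.path.join(image_root, class_name)
        for file_name in sorted(os.listdir(class_dir)):
            # Skip non-image files such as notes or hidden files.
            if file_name.lower().endswith(extensions):
                samples.append((os.path.join(class_dir, file_name), label))
    return class_names, samples
```

A mismatch between the number of returned classes and the number you expect for a dataset is usually the first sign of a misplaced folder level.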


Training is done by running main.py and setting the respective flags, all of which are listed and explained in parameters.py. A large set of example runs is provided in Revisit_Runs.sh.

[I.] A basic sample run using default parameters would look like this:

python main.py --loss margin --batch_mining distance --log_online \
              --project DML_Project --group Margin_with_Distance --seed 0 \
              --gpu 0 --bs 112 --data_sampler class_random --samples_per_class 2 \
              --arch resnet50_frozen_normalize --source $datapath --n_epochs 150 \
              --lr 0.00001 --embed_dim 128 --evaluate_on_gpu
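To make the batch-composition flags in this command concrete: with --bs 112 and --samples_per_class 2, a class-random batch would plausibly hold 112 / 2 = 56 classes with two images each. The sketch below assumes that reading of --data_sampler class_random. ClassRandomBatchSampler and its arguments are hypothetical and not the implementation in datasampler/class_random_sampler.py; following the extension notes further down, it inherits from torch.utils.data.sampler.Sampler and provides __iter__() and __len__().

```python
import random
from collections import defaultdict

from torch.utils.data.sampler import Sampler


class ClassRandomBatchSampler(Sampler):
    """Hypothetical sketch: each batch holds batch_size // samples_per_class
    randomly chosen classes, with samples_per_class images per class."""

    def __init__(self, labels, batch_size=112, samples_per_class=2, num_batches=None, seed=0):
        if batch_size % samples_per_class != 0:
            raise ValueError("batch_size must be divisible by samples_per_class")
        # Group dataset indices by their class label.
        self.indices_per_class = defaultdict(list)
        for index, label in enumerate(labels):
            self.indices_per_class[label].append(index)
        self.classes = sorted(self.indices_per_class)
        self.classes_per_batch = batch_size // samples_per_class
        if self.classes_per_batch > len(self.classes):
            raise ValueError("not enough classes to fill one batch")
        self.samples_per_class = samples_per_class
        # Default: roughly one pass over the data per epoch.
        self.num_batches = num_batches if num_batches is not None else len(labels) // batch_size
        self.rng = random.Random(seed)

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            for label in self.rng.sample(self.classes, self.classes_per_batch):
                pool = self.indices_per_class[label]
                # Draw without replacement unless the class has too few images.
                if len(pool) >= self.samples_per_class:
                    batch.extend(self.rng.sample(pool, self.samples_per_class))
                else:
                    batch.extend(self.rng.choices(pool, k=self.samples_per_class))
            yield batch

    def __len__(self):
        return self.num_batches
```

Because it yields whole lists of indices rather than single indices, it would be passed to torch.utils.data.DataLoader via batch_sampler=ClassRandomBatchSampler(labels).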

The purpose of each flag explained:

Some Notes:

[II.] Advanced Runs:

python main.py --loss margin --batch_mining distance --loss_margin_beta 0.6 --miner_distance_lower_cutoff 0.5 ... (basic parameters)
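Assuming these two flags follow the margin loss and distance-weighted sampling of Wu et al. (2017), "Sampling Matters in Deep Embedding Learning", the quantities they act on can be summarized as below. Here D_ij is the Euclidean distance between L2-normalized embeddings i and j, n is the embedding dimension (--embed_dim), α is the margin, β is the learnable class boundary presumably initialized by --loss_margin_beta, and c stands for the lower distance cutoff presumably set by --miner_distance_lower_cutoff. This mapping of flags to symbols is an assumption, not taken from parameters.py. q(d) is the density of pairwise distances for points spread uniformly on the unit hypersphere in n dimensions.

```latex
\ell_{\mathrm{margin}}(i,j) = \big[\,\alpha + y_{ij}\,(D_{ij} - \beta)\,\big]_+ ,
\qquad
y_{ij} =
\begin{cases}
+1 & \text{if } i \text{ and } j \text{ share a class},\\
-1 & \text{otherwise,}
\end{cases}

q(d) \propto d^{\,n-2}\left[1 - \tfrac{1}{4}d^{2}\right]^{\frac{n-3}{2}},
\qquad
\Pr(j \text{ chosen as negative for anchor } i) \propto q\big(\max(D_{ij},\, c)\big)^{-1}.
```

Clipping distances from below at c caps the weight given to very close negatives, since q(d)^{-1} grows without bound as d approaches 0 for n > 2.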

Evaluating Results with W&B

  1. Create custom objectives: Simply take a look at e.g. criteria/margin.py, and ensure that the method used has the following properties:
  2. Create custom batchminers: Simply take a look at e.g. batch_mining/distance.py. The miner needs to be a class with a defined __call__() function that takes in a batch and labels and returns e.g. a list of triplets.

  3. Create custom datasamplers: Simply take a look at e.g. datasampler/class_random_sampler.py. The sampler needs to inherit from torch.utils.data.sampler.Sampler and has to provide an __iter__() and a __len__() function. It has to yield a set of indices that are used to create the batch.
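As a minimal sketch of the batchminer interface described above, the hypothetical HardestNegativeMiner below takes a batch of embeddings and their labels and returns (anchor, positive, negative) index triplets: a random positive per anchor, and the closest differently-labelled sample as the negative. It is not the distance-based miner in batch_mining/distance.py, and it assumes the training loop passes exactly the batch and the labels, as described in item 2.

```python
import random


class HardestNegativeMiner:
    """Hypothetical example miner returning one (anchor, positive, negative) triplet per anchor.

    The positive is a random same-label sample; the negative is the closest
    differently-labelled sample under squared Euclidean distance.
    """

    def __init__(self, seed=0):
        self.rng = random.Random(seed)

    def __call__(self, batch, labels):
        # Works on plain lists as well as objects with .tolist() (e.g. tensors, NumPy arrays).
        batch = batch.tolist() if hasattr(batch, "tolist") else [list(x) for x in batch]
        labels = labels.tolist() if hasattr(labels, "tolist") else list(labels)

        def sq_dist(a, b):
            return sum((x - y) ** 2 for x, y in zip(batch[a], batch[b]))

        triplets = []
        for anchor, label in enumerate(labels):
            positives = [i for i, l in enumerate(labels) if l == label and i != anchor]
            negatives = [i for i, l in enumerate(labels) if l != label]
            if not positives or not negatives:
                continue  # this anchor cannot form a valid triplet
            positive = self.rng.choice(positives)
            negative = min(negatives, key=lambda j: sq_dist(anchor, j))
            triplets.append((anchor, positive, negative))
        return triplets
```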

Implemented Methods

For a detailed explanation of everything, please refer to the supplementary material of our paper!

DML criteria

Evaluation Metrics

Metrics based on Euclidean Distances
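As an illustration of a metric in this family, Recall@k counts a query as correct if at least one of its k nearest neighbours (by Euclidean distance, excluding the query itself) shares its class. The function below is a hypothetical brute-force sketch for small inputs only; faiss, installed above, is the usual tool for fast nearest-neighbour search at dataset scale.

```python
def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5


def recall_at_k(embeddings, labels, k=1):
    """Fraction of queries whose k nearest neighbours contain a same-class sample."""
    n = len(embeddings)
    hits = 0
    for q in range(n):
        # Rank all other samples by distance to the query (the query itself is excluded).
        others = sorted(
            (j for j in range(n) if j != q),
            key=lambda j: euclidean(embeddings[q], embeddings[j]),
        )
        if any(labels[j] == labels[q] for j in others[:k]):
            hits += 1
    return hits / n
```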

Metrics based on Cosine Similarities (not included by default)
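The relation between Euclidean- and cosine-based metrics can be stated in one line. For embeddings x and y with unit norm, ‖x‖ = ‖y‖ = 1:

```latex
\|x - y\|_2^2 = \|x\|_2^2 + \|y\|_2^2 - 2\,x^\top y = 2 - 2\cos(x, y).
```

Since the squared distance is a decreasing affine function of the cosine similarity, nearest-neighbour rankings, and hence retrieval metrics such as Recall@k, coincide for L2-normalized embeddings; the Euclidean and cosine variants of such metrics can only differ when embeddings are not normalized.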

