Characterizing Generalization under Out-of-Distribution shifts in Deep Metric Learning

Links: [Paper]


This repository contains the code and implementations used in our generalization study on Deep Metric Learning (DML) under out-of-distribution (OOD) shifts. The underlying DML pipeline is adapted from an existing repository and extended with the key elements introduced in our paper.

If you wish to use the ooDML splits with your own codebase, simply copy the corresponding pickle-files from datasplits. The internal structure of the pickled dicts is explained below!
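As a minimal sketch of reading one split from such a pickle file: the key names `train`, `test` and `fid` follow the structure described in the ooDML benchmark section, while the function name `load_split`, the integer split ID and the toy file written below are assumptions made for illustration only.

```python
import pickle
import tempfile
from pathlib import Path


def load_split(pkl_path, split_id):
    """Return (train_classes, test_classes, fid) for one split of a benchmark."""
    with open(str(pkl_path), "rb") as f:
        splits = pickle.load(f)  # dictionary keyed by split ID
    entry = splits[split_id]
    return list(entry["train"]), list(entry["test"]), entry["fid"]


# Toy stand-in for a real split file; class names and the FID value are made up.
toy_splits = {1: {"train": ["class_a", "class_b"], "test": ["class_c"], "fid": 0.0}}
with tempfile.TemporaryDirectory() as tmp:
    toy_path = Path(tmp) / "toy_splits.pkl"
    toy_path.write_bytes(pickle.dumps(toy_splits))
    train_classes, test_classes, fid = load_split(toy_path, split_id=1)

print(len(train_classes), "train classes,", len(test_classes), "test classes")
```

With a real file, pass the path of the copied pickle file instead of the toy path. If the split IDs turn out to be stored as strings rather than integers, adjust `split_id` accordingly.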

Quick Usage


An exemplary setup of a virtual environment containing everything needed:

(1) wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
(2) bash Miniconda3-latest-Linux-x86_64.sh (say yes to append path to bashrc)
(3) source ~/.bashrc
(4) conda create -n DL python=3.6
(5) conda activate DL
(6) conda install matplotlib scipy scikit-learn scikit-image tqdm pandas pillow
(7) conda install pytorch torchvision faiss-gpu cudatoolkit=10.0 -c pytorch
(8) pip install wandb pretrainedmodels
(9) Run the scripts!
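Before running the scripts, it can help to confirm that the environment actually contains the packages installed above. The following is a small sketch, not part of the repository: the helper name `missing_modules` is hypothetical, and the list holds the import names that correspond to the installed packages (e.g. `sklearn` for scikit-learn, `PIL` for pillow).

```python
import importlib.util

# Import names (not conda/pip package names) for the packages installed above.
REQUIRED_MODULES = (
    "matplotlib", "scipy", "sklearn", "skimage", "tqdm", "pandas", "PIL",
    "torch", "torchvision", "faiss", "wandb", "pretrainedmodels",
)


def missing_modules(names=REQUIRED_MODULES):
    """Return the entries of `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules()
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules found.")
```

The check only looks the modules up and does not import them, so it does not verify CUDA or GPU support.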

The ooDML benchmark

Information about the utilised split progressions is fully available in the .pkl-files located in the datasplits directory. For each base benchmark, the respective .pkl-file contains a dictionary with the following structure:

Split_ID: 1, 2, ..., 8/9 (depending on benchmark)
└─── train: List of training classes.
|       └─── String of classname
|       ...
└─── test: List of test classes.
|       └─── String of classname
|       ...
└─── fid: FID score (using R50) for given split.
└─── test_episodes: Episode data for few-shot evaluation of DML, given per Shot-setup.
|       └─── 2
|       |    └─── Episode-ID: 1, ..., 10
|       |    |       └─── classname
|       |    |       |       └─── Support samples to use for few-shot adaptation. The complement will be used to generate the query data.
|       |    |       |       ...
|       └─── 5 (see 2-Shot setting)
|       ...
|       └─── 10 (not used in the paper)
|       ...
└─── split_train, split_val: Train/val splits for hyperparameter tuning before running on final test set. Also used in few-shot experiments to provide default train/validation splits.
|       ...
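The few-shot part of this structure can be used as sketched below: for one shot setting and one episode, the stored support samples are taken as-is and the remaining samples of each class form the query set. This is an illustrative sketch only. The function name, the integer keys for shots and episodes, and the assumption that support samples are stored as identifiers comparable to the ones in your own per-class sample lists are not confirmed by the repository.

```python
def support_query_split(split_entry, shots, episode_id, samples_per_class):
    """Split each test class into support and query samples for one episode.

    split_entry:       dictionary of one split (contains "test_episodes").
    samples_per_class: dict mapping classname -> list of all sample identifiers.
    """
    episode = split_entry["test_episodes"][shots][episode_id]
    support, query = {}, {}
    for classname, support_samples in episode.items():
        chosen = set(support_samples)
        all_samples = samples_per_class[classname]
        support[classname] = [s for s in all_samples if s in chosen]
        # The complement of the support set provides the query data.
        query[classname] = [s for s in all_samples if s not in chosen]
    return support, query


# Toy data; the sample identifiers are made up for illustration.
toy_entry = {"test_episodes": {2: {1: {"class_c": ["img_0", "img_3"]}}}}
toy_samples = {"class_c": ["img_0", "img_1", "img_2", "img_3"]}
support, query = support_query_split(toy_entry, shots=2, episode_id=1,
                                     samples_per_class=toy_samples)
print(support, query)
```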

Detailed Usage

There are four key scripts in this repository:

Here are exemplary uses for each script:

python ood_main.py --checkpoint --data_hardness $split_id --kernels 6 --source $datapath --n_epochs 200 --log_online --project DML_OOD-Shift_Study --group CUB_ID-1_Margin_b12_Distance --seed 0 --gpu $gpu --bs 112 --samples_per_class 2 --loss margin --batch_mining distance --arch resnet50_frozen_normalize --embed_dim 512
python ood_diva_main.py --checkpoint --data_hardness $split_id --kernels 6 --source $datapath --n_epochs 200 --log_online --project DML_OOD-Shift_Study --group CUB_ID-1_DiVA --seed 0 --gpu $gpu --bs 108 --samples_per_class 2 --loss margin --batch_mining distance --diva_rho_decorrelation 1500 1500 1500 --diva_alpha_ssl 0.3 --diva_alpha_intra 0.3 --diva_alpha_shared 0.3 --arch resnet50_frozen_normalize --embed_dim 128
python fewshot_ood_main.py --dataset cars196 --finetune_criterion margin --finetune_shots 2 --finetune_lr_multi 10 --finetune_iter 1000 --finetune_only_last --checkpoint --data_hardness -20 --kernels 6 --source $datapath --n_epochs 200 --log_online --project DML_OOD-Shift_FewShot_Study --group CAR_Shots-2_ID-1_Multisimilarity --seed 0 --gpu $gpu --bs 112 --samples_per_class 2 --loss multisimilarity --arch resnet50_frozen_normalize --embed_dim 512

An explanation for all utilised flags is provided in the respective help-string in parameters.py.
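To repeat a run over several splits, the `ood_main.py` call can be assembled programmatically. The sketch below reuses only the flags and values from the `ood_main.py` example above. The helper name `build_ood_command`, the placeholder data path, the chosen split IDs, and the idea that the `ID-1` part of the group name encodes the split ID are assumptions.

```python
import shlex


def build_ood_command(split_id, datapath, gpu, seed=0):
    """Assemble the ood_main.py call from the example above for one split."""
    # Assumption: the "ID-1" part of the example group name is the split ID.
    group = "CUB_ID-{}_Margin_b12_Distance".format(split_id)
    return [
        "python", "ood_main.py", "--checkpoint",
        "--data_hardness", str(split_id),
        "--kernels", "6",
        "--source", str(datapath),
        "--n_epochs", "200",
        "--log_online",
        "--project", "DML_OOD-Shift_Study",
        "--group", group,
        "--seed", str(seed),
        "--gpu", str(gpu),
        "--bs", "112",
        "--samples_per_class", "2",
        "--loss", "margin",
        "--batch_mining", "distance",
        "--arch", "resnet50_frozen_normalize",
        "--embed_dim", "512",
    ]


# Hypothetical sweep over three split IDs; data path and GPU index are placeholders.
commands = [build_ood_command(split_id, "/path/to/datasets", gpu=0)
            for split_id in (1, 2, 3)]
for cmd in commands:
    # Print only; use subprocess.run(cmd, check=True) to actually launch a run.
    print(" ".join(shlex.quote(part) for part in cmd))
```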

The data is assumed to have the following structure:

|    └───001.Black_footed_Albatross
|           │   Black_Footed_Albatross_0001_796111
|           │   ...
|    ...
|    └───bicycle_final
|           │   111085122871_0.jpg
|    ...
|    │   bicycle.txt
|    │   ...

Assuming your folder is placed in e.g. <$datapath/cub200>, pass $datapath as input to --source.
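A quick way to sanity-check the folder passed to `--source` is to count the images found per class folder. The sketch below is not part of the repository: the function name is hypothetical, `.jpg` is assumed as the image extension, and the search is recursive so that it does not depend on how many directory levels sit between the dataset folder and the class folders.

```python
import tempfile
from collections import Counter
from pathlib import Path


def count_images_per_class(source, dataset, pattern="*.jpg"):
    """Count images per class folder below <source>/<dataset>."""
    root = Path(source) / dataset
    if not root.is_dir():
        raise FileNotFoundError("Expected dataset folder at {}".format(root))
    # Each image is attributed to the folder that directly contains it.
    return Counter(path.parent.name for path in root.rglob(pattern))


# Toy layout in a temporary directory, standing in for $datapath.
with tempfile.TemporaryDirectory() as tmp:
    class_dir = Path(tmp) / "cub200" / "001.Black_footed_Albatross"
    class_dir.mkdir(parents=True)
    (class_dir / "Black_Footed_Albatross_0001_796111.jpg").touch()
    counts = count_images_per_class(tmp, "cub200")

print(dict(counts))
```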


