Triplet loss in TensorFlow
Author: Olivier Moindrot
This repository contains a triplet loss implementation in TensorFlow with online triplet mining. Please check the blog post for a full description.
The code structure is adapted from code I wrote for CS230 in this repository at tensorflow/vision.
A set of tutorials for this code can be found here.
Requirements
We recommend using python3 and a virtual environment, created either with the default venv module or with virtualenv for python3.
python3 -m venv .env
source .env/bin/activate
pip install -r requirements_cpu.txt
If you are using a GPU, you will need to install tensorflow-gpu instead, so run:
pip install -r requirements_gpu.txt
Triplet loss
Triplet loss on two positive faces (Obama) and one negative face (Macron)
The interesting part, defining triplet loss with triplet mining, can be found in model/triplet_loss.py.
Everything is explained in the blog post.
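For reference, the triplet loss of a single triplet (an anchor a, a positive p with the same label, a negative n with a different label) is usually written as max(d(a, p) - d(a, n) + margin, 0): it is zero once the negative is at least margin farther from the anchor than the positive. The pure-Python sketch below assumes Euclidean distance and embeddings given as plain lists of floats; the helper names euclidean and single_triplet_loss are hypothetical and are not part of model/triplet_loss.py.

```python
import math

def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(u, v)))

def single_triplet_loss(anchor, positive, negative, margin):
    """max(d(a, p) - d(a, n) + margin, 0) for one (anchor, positive, negative) triplet."""
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)

# Easy triplet: the negative is already more than `margin` farther away, so the loss is 0.
print(single_triplet_loss([0.0, 0.0], [1.0, 0.0], [3.0, 0.0], margin=0.5))   # 0.0
# Semi-hard triplet: the negative is farther than the positive, but by less than `margin`.
print(single_triplet_loss([0.0, 0.0], [1.0, 0.0], [1.25, 0.0], margin=0.5))  # 0.25
# Hard triplet: the negative is closer to the anchor than the positive.
print(single_triplet_loss([0.0, 0.0], [2.0, 0.0], [1.0, 0.0], margin=0.5))   # 1.5
```

Only semi-hard and hard triplets produce a non-zero loss, which is why online mining focuses on them.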
To use the "batch all" version, you can do:
from model.triplet_loss import batch_all_triplet_loss
loss, fraction_positive = batch_all_triplet_loss(labels, embeddings, margin, squared=False)
In this case, fraction_positive is a useful quantity to plot in TensorBoard: it tracks the fraction of valid triplets that have a positive loss, i.e. the hard and semi-hard triplets.
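To make the two return values concrete, here is a small pure-Python reference sketch of the "batch all" strategy, meant for checking toy inputs by hand rather than for training. It assumes that every valid triplet (anchor and positive are distinct samples with the same label, negative has a different label) is scored, that the loss is averaged over the triplets with a positive loss only, and that fraction_positive is the share of valid triplets with a positive loss. The names pairwise_distances and batch_all_reference are hypothetical; model/triplet_loss.py remains the authoritative TensorFlow version.

```python
import math

def pairwise_distances(embeddings, squared=False):
    """All-pairs Euclidean distances; squared=True returns squared distances."""
    dist = []
    for u in embeddings:
        sq = [sum((x - y) ** 2 for x, y in zip(u, v)) for v in embeddings]
        dist.append(sq if squared else [math.sqrt(s) for s in sq])
    return dist

def batch_all_reference(labels, embeddings, margin, squared=False):
    """Return (mean loss over positive triplets, fraction of valid triplets with positive loss)."""
    dist = pairwise_distances(embeddings, squared)
    n = len(labels)
    total, num_positive, num_valid = 0.0, 0, 0
    for a in range(n):
        for p in range(n):
            # A positive is a different sample with the same label as the anchor.
            if p == a or labels[p] != labels[a]:
                continue
            for k in range(n):
                # A negative has a different label (this also excludes the anchor itself).
                if labels[k] == labels[a]:
                    continue
                num_valid += 1
                loss = max(dist[a][p] - dist[a][k] + margin, 0.0)
                if loss > 1e-16:  # hard or semi-hard triplet
                    num_positive += 1
                    total += loss
    fraction_positive = num_positive / num_valid if num_valid else 0.0
    mean_loss = total / num_positive if num_positive else 0.0
    return mean_loss, fraction_positive

labels = [0, 0, 1, 1]
embeddings = [[0.0], [1.0], [1.5], [3.0]]
loss, fraction_positive = batch_all_reference(labels, embeddings, margin=1.0)
print(loss, fraction_positive)  # 1.1 0.625
```

On this toy batch, 5 of the 8 valid triplets have a positive loss, hence fraction_positive = 0.625. Averaging over positive triplets only (rather than over all valid ones) keeps the loss from shrinking as easy triplets start to dominate during training.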
To use the "batch hard" version, you can do:
from model.triplet_loss import batch_hard_triplet_loss
loss = batch_hard_triplet_loss(labels, embeddings, margin, squared=False)
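The "batch hard" strategy instead keeps, for each anchor, only its hardest positive (the farthest sample with the same label) and its hardest negative (the closest sample with a different label), then averages the resulting losses over anchors. The pure-Python sketch below assumes Euclidean (or, with squared=True, squared Euclidean) distances and lists of floats as embeddings. batch_hard_reference is a hypothetical name, and skipping anchors that have no positive or no negative is an assumption of this sketch that the TensorFlow version may not share.

```python
import math

def batch_hard_reference(labels, embeddings, margin, squared=False):
    """Mean over anchors of max(hardest positive dist - hardest negative dist + margin, 0)."""
    def distance(u, v):
        sq = sum((x - y) ** 2 for x, y in zip(u, v))
        return sq if squared else math.sqrt(sq)

    n = len(labels)
    losses = []
    for a in range(n):
        pos = [distance(embeddings[a], embeddings[p])
               for p in range(n) if p != a and labels[p] == labels[a]]
        neg = [distance(embeddings[a], embeddings[k])
               for k in range(n) if labels[k] != labels[a]]
        if not pos or not neg:
            continue  # assumption of this sketch: skip anchors without a valid positive/negative
        # Hardest positive = farthest same-label sample; hardest negative = closest other-label sample.
        losses.append(max(max(pos) - min(neg) + margin, 0.0))
    return sum(losses) / len(losses) if losses else 0.0

labels = [0, 0, 1, 1]
embeddings = [[0.0], [1.0], [1.5], [3.0]]
print(batch_hard_reference(labels, embeddings, margin=1.0))  # 1.125
```

Each anchor contributes exactly one term to the average, however many valid triplets it belongs to, so the cost of the loss grows with the batch size rather than with the number of triplets.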
Training on MNIST
To run a new experiment called base_model, do:
python train.py --model_dir experiments/base_model
You will first need to create a configuration file like this one: params.json. This JSON file specifies all the hyperparameters for the model.
All the weights and summaries will be saved in the model_dir.
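A minimal sketch of writing and reading such a configuration file with Python's standard json module. The keys shown are hypothetical examples of hyperparameters (margin and squared mirror the arguments of the loss functions); they are not necessarily the keys the repository's params.json uses. The sketch also assumes that params.json is stored inside the directory passed as --model_dir, which you should confirm against train.py.

```python
import json
import os
import tempfile

# Hypothetical hyperparameters: check the repository's own params.json for the real keys.
example_params = {
    "learning_rate": 1e-3,
    "batch_size": 64,
    "num_epochs": 20,
    "margin": 0.5,
    "squared": False,
}

def write_params(model_dir, params):
    """Write params as indented JSON to <model_dir>/params.json and return the path."""
    os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, "params.json")
    with open(path, "w") as f:
        json.dump(params, f, indent=4)
    return path

def load_params(model_dir):
    """Read <model_dir>/params.json back into a dict."""
    # Assumption: params.json lives inside the directory passed as --model_dir.
    with open(os.path.join(model_dir, "params.json")) as f:
        return json.load(f)

with tempfile.TemporaryDirectory() as tmp:
    model_dir = os.path.join(tmp, "experiments", "base_model")
    write_params(model_dir, example_params)
    params = load_params(model_dir)

print(params["margin"], params["squared"])  # 0.5 False
```

Keeping one params.json per experiment directory makes each run reproducible from its folder alone, next to the weights and summaries it produced.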
Once trained, you can visualize the embeddings by running:
python visualize_embeddings.py --model_dir experiments/base_model
Then run TensorBoard on the experiment directory:
tensorboard --logdir experiments/base_model
Here is the result (link to gif):
Embeddings of the MNIST test images visualized with t-SNE (perplexity 25)
Test
To run all the tests, run this from the project directory:
pytest
To run a specific test:
pytest model/tests/test_triplet_loss.py
Resources
- Blog post explaining this project.
- Source code for the built-in TensorFlow function for semi-hard online mining triplet loss: tf.contrib.losses.metric_learning.triplet_semihard_loss.
- FaceNet paper introducing online triplet mining
- Detailed explanation of online triplet mining in In Defense of the Triplet Loss for Person Re-Identification
- Blog post by Brandon Amos on online triplet mining: OpenFace 0.2.0: Higher accuracy and halved execution time.
- The Coursera lecture on triplet loss