

Optimal Transport Kernel Embedding

This repository implements the Optimal Transport Kernel Embedding (OTKE) described in the following paper:

Grégoire Mialon*, Dexiong Chen*, Alexandre d'Aspremont, Julien Mairal. A Trainable Optimal Transport Embedding for Feature Aggregation and its Relationship to Attention. ICLR 2021.
*Equal contribution

TL;DR: the paper demonstrates the advantage of our OTK Embedding over usual aggregation methods (e.g., mean pooling, max pooling, or attention) when faced with data composed of large sets of features, such as biological sequences, natural language sentences, or even images. Our embedding can be learned either with or without labels, which is especially useful when little annotated data is available, and it can be used alone as a kernel method or as a layer in larger models.

A short description of the module

The main module is implemented in otk/layers.py as OTKernel. It is generally used with a non-linear layer. Combined with that layer, it takes a sequence or image tensor as input and performs a non-linear embedding followed by an adaptive pooling (attention + pooling) based on optimal transport. Specifically, given an input sequence x, it first computes the optimal transport plan from x to a reference z (left figure). The transport plan, interpreted as attention scores, is then used to aggregate the non-linearly mapped features of x into a new sequence of the same size as z (right figure). See our paper for more details.

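[Figure: left, the optimal transport plan from an input sequence x to a reference z; right, the new sequence of the same size as z obtained through the non-linear mapping and the transport plan]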
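To make the pooling step concrete, here is a minimal pure-Python sketch; it is not the repository's implementation. It assumes the features x have already gone through the non-linear mapping, and it uses a squared Euclidean cost between inputs and references, uniform marginals, and entropic regularization solved with Sinkhorn iterations. The names `sq_dist_cost`, `sinkhorn`, `ot_pool`, `eps` and `n_iter` are hypothetical. Rescaling by the number of supports p, so that each output is a weighted average of the inputs, is our own choice; the paper's exact cost and normalization may differ.

```python
import math

def sq_dist_cost(x, z):
    """Hypothetical helper: squared Euclidean distance between every x_i and z_j."""
    return [[sum((a - b) ** 2 for a, b in zip(xi, zj)) for zj in z] for xi in x]

def sinkhorn(cost, eps=0.1, n_iter=100):
    """Entropic OT plan (n x p) between uniform marginals 1/n and 1/p."""
    n, p = len(cost), len(cost[0])
    a, b = 1.0 / n, 1.0 / p  # uniform weights on inputs and on supports
    gibbs = [[math.exp(-c / eps) for c in row] for row in cost]
    u, v = [1.0] * n, [1.0] * p
    for _ in range(n_iter):
        # alternately rescale rows and columns to match the marginals
        u = [a / sum(gibbs[i][j] * v[j] for j in range(p)) for i in range(n)]
        v = [b / sum(gibbs[i][j] * u[i] for i in range(n)) for j in range(p)]
    return [[u[i] * gibbs[i][j] * v[j] for j in range(p)] for i in range(n)]

def ot_pool(x, z, eps=0.1, n_iter=100):
    """Aggregate n input features into p outputs, weighted by the OT plan."""
    plan = sinkhorn(sq_dist_cost(x, z), eps, n_iter)
    n, p, d = len(x), len(z), len(x[0])
    # each column of the plan sums to 1/p, so scaling by p yields a weighted average
    return [[p * sum(plan[i][j] * x[i][k] for i in range(n)) for k in range(d)]
            for j in range(p)]
```

With two well-separated groups of inputs and one reference near each group, each pooled output lands close to the mean of the inputs transported to that reference.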

OTKernel can be trained either in an unsupervised fashion (with K-means) or in a supervised one (like the multi-head self-attention module in a Transformer). It can be used as a module in neural networks or alone as a kernel method.
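For the unsupervised option, references can be obtained by clustering input features with K-means. Below is a minimal sketch of plain Lloyd's K-means in pure Python, assuming Euclidean distance and random initialization from the data points. The function name `kmeans` and its parameters are hypothetical, and the repository's procedure (for example, how features are sampled or how centers are initialized) may differ.

```python
import random

def kmeans(points, k, n_iter=20, seed=0):
    """Plain Lloyd's K-means; returns k centers usable as references z (sketch)."""
    rng = random.Random(seed)
    centers = [list(c) for c in rng.sample(points, k)]  # init from data points
    for _ in range(n_iter):
        clusters = [[] for _ in range(k)]
        for pt in points:
            # assign each point to its nearest center (squared Euclidean distance)
            j = min(range(k),
                    key=lambda c: sum((a - b) ** 2 for a, b in zip(pt, centers[c])))
            clusters[j].append(pt)
        for j, members in enumerate(clusters):
            if members:  # keep the previous center if a cluster becomes empty
                centers[j] = [sum(col) / len(members) for col in zip(*members)]
    return centers
```

Each row of the returned list would play the role of one support of the reference z.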

Using OTKernel as a module in NNs

Here is an example of using OTKernel in a one-layer model:

import torch
from torch import nn
from otk.layers import OTKernel

in_dim = 128
hidden_size = 64
# create an OTK model with a single reference and 10 supports
otk_layer = nn.Sequential(
    nn.Linear(in_dim, hidden_size),
    OTKernel(in_dim=hidden_size, out_size=10, heads=1)
)
# create a batch of 2 sequences, each of length L=100 and dim=128
input = torch.rand(2, 100, in_dim)
# each output sequence has L=10 and dim=64
output = otk_layer(input) # 2 x 10 x 64

Using OTKernel alone as a kernel mapping

When OTKernel is used alone, the non-linear mapping is a Gaussian kernel (or a convolutional kernel). The full model for sequences is implemented in otk/models.py as SeqAttention. Here is an example:

import torch
from otk.models import SeqAttention

in_dim = 128
hidden_size = 64
nclass = 10
# create a classification model based on one CKN layer and OTK, with filter_size=1 and sigma=0.6 for the CKN and with 4 references and 10 supports for OTK
otk_model = SeqAttention(
    in_dim, nclass, [hidden_size], [1], [1], [0.6], out_size=10, heads=4
)
# create a batch of 2 sequences of L=100 and dim=128 (shape: batch x dim x L)
input = torch.rand(2, in_dim, 100)
# output: 2 x 10 (one score per class)
output = otk_model(input)
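The Gaussian kernel mentioned above is often written, as in convolutional kernel networks, on unit-normalized vectors, where the identity ||x - z||^2 = 2 - 2<x, z> turns it into an exponential of a dot product. The sketch below checks that identity numerically. The helper names are hypothetical, `sigma=0.6` mirrors the example above, and whether SeqAttention normalizes its inputs exactly this way is an assumption.

```python
import math

def normalize(v):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]

def gaussian_kernel(x, z, sigma=0.6):
    """exp(-||x - z||^2 / (2 sigma^2)) computed on unit-normalized inputs."""
    xn, zn = normalize(x), normalize(z)
    sq_dist = sum((a - b) ** 2 for a, b in zip(xn, zn))
    return math.exp(-sq_dist / (2 * sigma ** 2))

def dot_product_form(x, z, sigma=0.6):
    """Equivalent form exp((<x, z> - 1) / sigma^2), valid for unit-norm inputs."""
    xn, zn = normalize(x), normalize(z)
    dot = sum(a * b for a, b in zip(xn, zn))
    return math.exp((dot - 1.0) / sigma ** 2)
```

The dot-product form is convenient in practice because, for a whole sequence, all kernel values against the supports can be computed with a single matrix product followed by an elementwise exponential.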

Besides being trained with back-propagation, the otk_model above can also be trained without supervision, given a data_loader object created with torch.utils.data.DataLoader.

from torch.utils.data import DataLoader

# suppose that the data is stored in a Dataset object named dataset
data_loader = DataLoader(dataset, batch_size=256, shuffle=False)


We strongly recommend using miniconda to install the following packages:


Then run



Here we provide the commands to reproduce some of the results in our paper.

Reproducing results for SCOP 1.75

To reproduce the results in Table 2, run the following commands.

Reproducing results for DeepSEA

To reproduce the results (auROC=0.936, auPRC=0.360) in Table 3, run the following commands.

Reproducing results for SST-2

To reproduce the results in Table 4, run the following commands.