Please visit our website for documentation, contribution guidelines and tutorials.

Kernel Operations on the GPU, with autodiff, without memory overflows

The KeOps library lets you compute reductions of large arrays whose entries are given by a mathematical formula or a neural network. It combines efficient C++ routines with an automatic differentiation engine and can be used with Python (NumPy, PyTorch), Matlab and R.

It is perfectly suited to the computation of kernel matrix-vector products, K-nearest neighbors queries, N-body interactions, point cloud convolutions and the associated gradients. Crucially, it performs well even when the corresponding kernel or distance matrices do not fit into the RAM or GPU memory. Compared with a PyTorch GPU baseline, KeOps provides a x10-x100 speed-up on a wide range of geometric applications, from kernel methods to geometric deep learning.
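To make the memory issue concrete, here is a minimal NumPy sketch of a K-nearest neighbors query that never builds the full distance matrix. It is plain NumPy, not KeOps, and it is not how KeOps is implemented: the function name `knn_chunked` and the `chunk_size` parameter are hypothetical, chosen only for this illustration. The queries are processed in blocks, so that only a `(chunk_size, N)` slab of squared distances exists at any time.

```python
import numpy as np


def knn_chunked(x, y, k, chunk_size=1024):
    """For each row of x, return the indices of its k nearest rows of y.

    Hypothetical helper for illustration: only a (chunk_size, N) block
    of squared distances is held in memory at any time.
    """
    M = x.shape[0]
    out = np.empty((M, k), dtype=np.int64)
    for start in range(0, M, chunk_size):
        xb = x[start:start + chunk_size]                          # (B, D)
        # Squared distances between the current block and all of y: (B, N)
        d2 = ((xb[:, None, :] - y[None, :, :]) ** 2).sum(axis=2)
        # Indices of the k smallest entries per row (unordered) ...
        idx = np.argpartition(d2, k - 1, axis=1)[:, :k]
        # ... then sorted by increasing distance.
        order = np.argsort(np.take_along_axis(d2, idx, axis=1), axis=1)
        out[start:start + chunk_size] = np.take_along_axis(idx, order, axis=1)
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 3))
y = rng.standard_normal((2000, 3))
neighbors = knn_chunked(x, y, k=5, chunk_size=256)
print(neighbors.shape)  # (1000, 5)
```

Chunking by hand like this trades one large allocation for a Python loop. The point of a symbolic representation is to get the same bounded memory footprint without writing that loop yourself.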

Symbolic matrices

Why use KeOps? Math libraries represent most objects as matrices and tensors.

In practice, KeOps symbolic tensors are both fast and memory-efficient. We take advantage of the structure of CUDA registers to bypass costly memory transfers between arithmetic and memory circuits. This allows us to provide a x10-x100 speed-up to PyTorch GPU programs in a wide range of settings.
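The idea of reducing a matrix that is never stored can be sketched on the CPU with NumPy. The snippet below is only an analogy under stated assumptions: `gaussian_kernel_sum_tiled` and its `tile` parameter are hypothetical names, and the loop over tiles stands in for what the text describes as register-level work on the GPU. It accumulates a Gaussian kernel matrix-vector product tile by tile, so that the `(M, N)` kernel matrix is never materialized.

```python
import numpy as np


def gaussian_kernel_sum_tiled(x, y, b, tile=1024):
    """Compute a_i = sum_j exp(-|x_i - y_j|^2) * b_j tile by tile.

    x: (M, D), y: (N, D), b: (N, E). Returns an (M, E) array.
    Hypothetical helper: only an (M, tile) slab of the kernel matrix
    exists at any time, and partial sums go into a running accumulator.
    """
    M, N = x.shape[0], y.shape[0]
    a = np.zeros((M, b.shape[1]), dtype=x.dtype)      # running accumulator
    for j0 in range(0, N, tile):
        yb = y[j0:j0 + tile]                          # (T, D)
        bb = b[j0:j0 + tile]                          # (T, E)
        d2 = ((x[:, None, :] - yb[None, :, :]) ** 2).sum(axis=2)  # (M, T)
        a += np.exp(-d2) @ bb                         # add this tile's share
    return a


rng = np.random.default_rng(0)
x = rng.standard_normal((500, 3))
y = rng.standard_normal((800, 3))
b = rng.standard_normal((800, 2))
a = gaussian_kernel_sum_tiled(x, y, b, tile=128)
print(a.shape)  # (500, 2)
```

The result matches the dense product `exp(-D) @ b` up to floating-point rounding, while the peak memory use grows with `M * tile` rather than `M * N`.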

Using our Python interface, a typical sample of code looks like:

# Create two arrays with 3 columns and a (huge) number of rows, on the GPU
import torch  # NumPy, Matlab and R are also supported
M, N, D = 1000000, 2000000, 3
x = torch.randn(M, D, requires_grad=True).cuda()  # x.shape = (1e6, 3)
y = torch.randn(N, D).cuda()                      # y.shape = (2e6, 3)

# Turn our dense Tensors into KeOps symbolic variables with "virtual"
# dimensions at positions 0 and 1 (for "i" and "j" indices):
from pykeops.torch import LazyTensor
x_i = LazyTensor(x.view(M, 1, D))  # x_i.shape = (1e6, 1, 3)
y_j = LazyTensor(y.view(1, N, D))  # y_j.shape = (1, 2e6, 3)

# We can now perform large-scale computations, without memory overflows:
D_ij = ((x_i - y_j)**2).sum(dim=2)  # Symbolic (1e6,2e6,1) matrix of squared distances
K_ij = (- D_ij).exp()               # Symbolic (1e6,2e6,1) Gaussian kernel matrix

# We come back to vanilla PyTorch Tensors or NumPy arrays using
# reduction operations such as .sum(), .logsumexp() or .argmin()
# on one of the two "symbolic" dimensions 0 and 1.
# Here, the kernel density estimation   a_i = sum_j exp(-|x_i-y_j|^2)
# is computed using a CUDA scheme that has a linear memory footprint and
# outperforms standard PyTorch implementations by two orders of magnitude.
a_i = K_ij.sum(dim=1)  # Genuine torch.cuda.FloatTensor, a_i.shape = (1e6, 1)

# Crucially, KeOps fully supports automatic differentiation!
g_x = torch.autograd.grad((a_i ** 2).sum(), [x])
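
When experimenting with code like the sample above, it can help to have a small dense reference to compare against. The following NumPy sketch is an assumption-laden illustration, not part of KeOps: `kde_and_grad` is a hypothetical helper that evaluates the same quantity `a_i = sum_j exp(-|x_i - y_j|^2)` on small inputs, together with the closed-form gradient of `sum_i a_i^2` with respect to `x`, which is `-4 * a_i * sum_j K_ij * (x_i - y_j)`. It also works out how large the dense kernel matrix of the sample would be if it were stored in float32.

```python
import numpy as np


def kde_and_grad(x, y):
    """Dense reference for small inputs (hypothetical helper).

    Returns a with shape (M,), where a_i = sum_j exp(-|x_i - y_j|^2),
    and the gradient of sum_i a_i^2 with respect to x, with shape (M, D).
    """
    diff = x[:, None, :] - y[None, :, :]          # (M, N, D)
    K = np.exp(-(diff ** 2).sum(axis=2))          # (M, N) Gaussian kernel
    a = K.sum(axis=1)                             # (M,)
    # Chain rule: d(a_i^2)/dx_i = 2 a_i * sum_j K_ij * (-2) * (x_i - y_j)
    grad = -4.0 * a[:, None] * (K[:, :, None] * diff).sum(axis=1)
    return a, grad


rng = np.random.default_rng(0)
x_small = rng.standard_normal((100, 3))
y_small = rng.standard_normal((200, 3))
a_small, g_small = kde_and_grad(x_small, y_small)
print(a_small.shape, g_small.shape)  # (100,) (100, 3)

# Size of the dense (1e6, 2e6) kernel matrix of the sample in float32:
dense_bytes = 1_000_000 * 2_000_000 * 4
print(dense_bytes / 1e12, "TB")  # 8.0 TB
```

Note that this reference returns `a` with shape `(M,)`, whereas the sample above reports `a_i.shape = (1e6, 1)`; reshape one of the two before comparing values.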

KeOps allows you to get the most out of your hardware without compromising on usability.

Projects using KeOps

Symbolic matrices are to geometric learning what sparse matrices are to graph processing. KeOps can thus be used in a wide range of settings, from shape analysis (registration, geometric deep learning, optimal transport...) to machine learning (kernel methods, k-means, UMAP...), Gaussian processes, computational biology and physics.

KeOps provides core routines for the following projects and libraries:

Licensing, citation, academic use

This library is licensed under the permissive MIT license, which is fully compatible with both academic and commercial applications.

If you use this code in a research paper, please cite our original publication:

Charlier, B., Feydy, J., Glaunès, J. A., Collin, F.-D. & Durif, G. Kernel Operations on the GPU, with Autodiff, without Memory Overflows. Journal of Machine Learning Research 22, 1–6 (2021).

@article{JMLR:v22:20-275,
  author  = {Benjamin Charlier and Jean Feydy and Joan Alexis Glaunès and François-David Collin and Ghislain Durif},
  title   = {Kernel Operations on the GPU, with Autodiff, without Memory Overflows},
  journal = {Journal of Machine Learning Research},
  year    = {2021},
  volume  = {22},
  number  = {74},
  pages   = {1-6},
  url     = {http://jmlr.org/papers/v22/20-275.html}
}

For applications to geometric (deep) learning, you may also consider our NeurIPS 2020 paper:

@article{feydy2020fast,
    title={Fast geometric learning with symbolic matrices},
    author={Feydy, Jean and Glaun{\`e}s, Joan and Charlier, Benjamin and Bronstein, Michael},
    journal={Advances in Neural Information Processing Systems},
    volume={33},
    year={2020}
}

Authors

Please contact us for any bug report, question or feature request by filing a report on our GitHub issue tracker!

Core library - KeOps, PyKeOps, KeOpsLab:

R bindings - RKeOps:

Contributors:

Beyond explicit code contributions, KeOps has grown out of numerous discussions with applied mathematicians and machine learning experts. We would especially like to thank Alain Trouvé, Stanley Durrleman, Gabriel Peyré and Michael Bronstein for their valuable suggestions and financial support.

KeOps was awarded an open science prize by the French Ministry of Higher Education and Research in 2023 ("Espoir - Documentation").