Spline-Based Convolution Operator of SplineCNN


This is a PyTorch implementation of the spline-based convolution operator of SplineCNN, as described in our paper:

Matthias Fey, Jan Eric Lenssen, Frank Weichert, Heinrich Müller: SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels (CVPR 2018)

The operator works on all floating point data types and is implemented both for CPU and GPU.

Installation

Anaconda

Update: You can now install pytorch-spline-conv via Anaconda for all major OS/PyTorch/CUDA combinations 🤗
Given that you have pytorch >= 1.8.0 installed, simply run

conda install pytorch-spline-conv -c pyg

Binaries

We alternatively provide pip wheels for all major OS/PyTorch/CUDA combinations, see here.

PyTorch 2.4

To install the binaries for PyTorch 2.4.0, simply run

pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.0+${CUDA}.html

where ${CUDA} should be replaced by either cpu, cu118, cu121, or cu124 depending on your PyTorch installation.

|         | cpu | cu118 | cu121 | cu124 |
|---------|-----|-------|-------|-------|
| Linux   | ✅  | ✅    | ✅    | ✅    |
| Windows | ✅  | ✅    | ✅    | ✅    |
| macOS   | ✅  |       |       |       |
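The `${CUDA}` substitution above can be sketched in Python. This is a minimal illustration, not part of the package: `wheel_index_url` is a hypothetical helper, and the resulting URL only resolves for the OS/PyTorch/CUDA combinations listed in the tables.

```python
from typing import Optional


def wheel_index_url(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build the `-f` index URL for a PyTorch build (hypothetical helper)."""
    # Drop a local build suffix such as "+cu121" and keep only "2.4.0".
    release = torch_version.split("+")[0]
    # CPU-only builds report no CUDA version; "12.1" becomes "cu121".
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{release}+{tag}.html"


print(wheel_index_url("2.4.0+cu121", "12.1"))
print(wheel_index_url("2.4.0", None))
```

With PyTorch installed, the two arguments can be taken from `torch.__version__` and `torch.version.cuda`, and the result passed to `pip install torch-spline-conv -f <url>`.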

PyTorch 2.3

To install the binaries for PyTorch 2.3.0, simply run

pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.3.0+${CUDA}.html

where ${CUDA} should be replaced by either cpu, cu118, or cu121 depending on your PyTorch installation.

|         | cpu | cu118 | cu121 |
|---------|-----|-------|-------|
| Linux   | ✅  | ✅    | ✅    |
| Windows | ✅  | ✅    | ✅    |
| macOS   | ✅  |       |       |

Note: Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, and PyTorch 2.2.0/2.2.1/2.2.2 (following the same procedure). For older versions, you need to explicitly specify the latest supported version number or install via pip install --no-index in order to prevent a manual installation from source. You can look up the latest supported version number here.

From source

Ensure that at least PyTorch 1.4.0 is installed and verify that cuda/bin and cuda/include are in your $PATH and $CPATH, respectively, e.g.:

$ python -c "import torch; print(torch.__version__)"
>>> 1.4.0

$ echo $PATH
>>> /usr/local/cuda/bin:...

$ echo $CPATH
>>> /usr/local/cuda/include:...

Then run:

pip install torch-spline-conv

When running in a Docker container without an NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail. In this case, ensure that the compute capabilities are set via TORCH_CUDA_ARCH_LIST, e.g.:

export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
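The value is a space-separated list of compute capabilities, each optionally suffixed with `+PTX`. As a rough sketch, the string can be assembled from `(major, minor)` pairs; `format_arch_list` is a hypothetical helper, not part of PyTorch or this package.

```python
def format_arch_list(capabilities, ptx_for=()):
    """Build a TORCH_CUDA_ARCH_LIST value from (major, minor) pairs."""
    entries = []
    # Sort and de-duplicate so the output is stable.
    for major, minor in sorted(set(capabilities)):
        entry = f"{major}.{minor}"
        if (major, minor) in ptx_for:
            entry += "+PTX"  # also embed PTX for this architecture
        entries.append(entry)
    return " ".join(entries)


arch_list = format_arch_list(
    [(6, 0), (6, 1), (7, 2), (7, 5)],
    ptx_for={(7, 2), (7, 5)},
)
print(f'export TORCH_CUDA_ARCH_LIST="{arch_list}"')
```

On a machine where a GPU is visible, `torch.cuda.get_device_capability()` returns such a `(major, minor)` pair for the current device, which may help when choosing the list for the target hardware.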

Usage

from torch_spline_conv import spline_conv

out = spline_conv(x,
                  edge_index,
                  pseudo,
                  weight,
                  kernel_size,
                  is_open_spline,
                  degree=1,
                  norm=True,
                  root_weight=None,
                  bias=None)

Applies the spline-based convolution operator

<p align="center">
  <img width="50%" src="https://user-images.githubusercontent.com/6945922/38684093-36d9c52e-3e6f-11e8-9021-db054223c6b9.png" />
</p>

over several node features of an input graph. The kernel function is defined over the weighted B-spline tensor product basis, as shown below for different B-spline degrees.

<p align="center">
  <img width="45%" src="https://user-images.githubusercontent.com/6945922/38685443-3a2a0c68-3e72-11e8-8e13-9ce9ad8fe43e.png" />
  <img width="45%" src="https://user-images.githubusercontent.com/6945922/38685459-42b2bcae-3e72-11e8-88cc-4b61e41dbd93.png" />
</p>
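To make the tensor product basis more concrete, here is a minimal plain-Python sketch for the simplest case: degree 1 with open splines. It is an illustration, not the library's implementation. The function name is hypothetical, and it assumes that pseudo-coordinates in [0, 1] are scaled onto `kernel_size - 1` knot intervals per dimension.

```python
import itertools
import math


def linear_open_bspline_basis(pseudo, kernel_size):
    """Return [(per-dimension kernel index, basis weight)] for one edge."""
    degree = 1
    terms = []
    # Each dimension contributes its two neighbouring control points,
    # so a d-dimensional pseudo-coordinate touches 2**d kernel entries.
    for offsets in itertools.product((0, 1), repeat=len(pseudo)):
        weight = 1.0
        index = []
        for u, k, o in zip(pseudo, kernel_size, offsets):
            v = u * (k - degree)  # scale [0, 1] onto [0, k - 1] (assumption)
            lower = math.floor(v)
            frac = v - lower
            # Linear interpolation between the lower and upper control point.
            weight *= frac if o == 1 else 1.0 - frac
            index.append((lower + o) % k)
        terms.append((tuple(index), weight))
    return terms


for index, weight in linear_open_bspline_basis((0.3, 0.7), (5, 5)):
    print(index, round(weight, 4))
```

The basis weights of one edge sum to 1, so the kernel value at a pseudo-coordinate is a convex combination of the neighbouring trainable weights.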

Parameters

Returns

Example

import torch
from torch_spline_conv import spline_conv

x = torch.rand((4, 2), dtype=torch.float)  # 4 nodes with 2 features each
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])  # 6 edges
pseudo = torch.rand((6, 2), dtype=torch.float)  # two-dimensional edge attributes
weight = torch.rand((25, 2, 4), dtype=torch.float)  # 25 parameters for in_channels x out_channels
kernel_size = torch.tensor([5, 5])  # 5 parameters in each edge dimension
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8)  # only use open B-splines
degree = 1  # B-spline degree of 1
norm = True  # Normalize output by node degree.
root_weight = torch.rand((2, 4), dtype=torch.float)  # separately weight root nodes
bias = None  # do not apply an additional bias

out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
                  is_open_spline, degree, norm, root_weight, bias)

print(out.size())
torch.Size([4, 4])  # 4 nodes with 4 features each
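In the example, the first dimension of `weight` (25) is the product of the entries of `kernel_size` (5 x 5), followed by the input and output channel counts. A small sketch of that relationship, using a hypothetical helper that is not part of the package:

```python
import math


def expected_weight_shape(kernel_size, in_channels, out_channels):
    """Shape of `weight` implied by the example (hypothetical helper)."""
    # One weight matrix of in_channels x out_channels per kernel entry.
    return (math.prod(kernel_size), in_channels, out_channels)


print(expected_weight_shape([5, 5], 2, 4))  # (25, 2, 4)
```

Checking the shape this way before calling `spline_conv` can catch a mismatch between `kernel_size` and `weight` early.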

Cite

Please cite our paper if you use this code in your own work:

@inproceedings{Fey/etal/2018,
  title={{SplineCNN}: Fast Geometric Deep Learning with Continuous {B}-Spline Kernels},
  author={Fey, Matthias and Lenssen, Jan Eric and Weichert, Frank and M{\"u}ller, Heinrich},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2018},
}

Running tests

pytest

C++ API

torch-spline-conv also offers a C++ API that contains C++ equivalents of the Python models.

mkdir build
cd build
# Add -DWITH_CUDA=on for CUDA support if needed
cmake ..
make
make install