E(n)-equivariant Steerable CNNs (escnn)

:rocket: ~20% faster than PyTorch*

escnn_jax is a JAX port of the PyTorch escnn library for equivariant deep learning. It supports steerable CNNs equivariant to both 2D and 3D isometries, as well as equivariant MLPs.


The library is structured into four subpackages with different high-level features:

| Component | Dependency | Description |
|---|---|---|
| escnn.group | Pure Python | implements basic concepts of group and representation theory |
| escnn.gspaces | Pure Python | defines the Euclidean spaces and their symmetries |
| escnn.kernels | JAX | solves for spaces of equivariant convolution kernels |
| escnn.nn | Equinox | contains equivariant modules to build deep neural networks |

Getting Started

escnn_jax provides a high-level user interface that abstracts away most of the intricacies of group and representation theory. The following code snippet shows how to perform an equivariant convolution from an RGB image to 10 regular feature fields (corresponding to a group convolution).

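from escnn_jax import gspaces
from escnn_jax import nn
import jax

# Split one PRNG key: key1 initializes the layer, key2 samples the input
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key, 2)

# Symmetry: the plane R^2 under N=8 discrete rotations
r2_act = gspaces.rot2dOnR2(N=8)
# Input type: 3 trivial (scalar) fields, one per RGB channel
feat_type_in  = nn.FieldType(r2_act,  3*[r2_act.trivial_repr])
# Output type: 10 regular fields
feat_type_out = nn.FieldType(r2_act, 10*[r2_act.regular_repr])

# Equivariant convolution followed by a ReLU acting on the output type
conv = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=5, key=key1)
relu = nn.ReLU(feat_type_out)

# Random batch of 16 RGB images of size 32x32, tagged with the input type
x = jax.random.normal(key2, (16, 3, 32, 32))
x = feat_type_in(x)

y = relu(conv(x))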
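
To give some intuition for the "group convolution" mentioned above, here is a minimal sketch in plain NumPy. It is not how escnn_jax implements `R2Conv`; the helper names `correlate_valid` and `lift_c4` are made up for this illustration, and the group is restricted to the four 90-degree rotations so that kernels can be rotated exactly on the pixel grid. The sketch correlates one scalar image with all rotated copies of a kernel and checks the resulting equivariance property numerically.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def correlate_valid(image, kernel):
    """Cross-correlate a 2D image with a 2D kernel, without padding."""
    windows = sliding_window_view(image, kernel.shape)  # (H', W', k, k)
    return np.einsum("ijab,ab->ij", windows, kernel)


def lift_c4(image, kernel):
    """Correlate the image with all four 90-degree rotations of the kernel.

    Returns an array of shape (4, H', W'): one channel per rotation.
    """
    return np.stack(
        [correlate_valid(image, np.rot90(kernel, r)) for r in range(4)]
    )


rng = np.random.default_rng(0)
image = rng.normal(size=(9, 9))
kernel = rng.normal(size=(3, 3))

out = lift_c4(image, kernel)
out_rotated_input = lift_c4(np.rot90(image), kernel)

# Equivariance: rotating the input rotates every output map and
# cyclically shifts the four rotation channels by one step.
expected = np.rot90(np.roll(out, 1, axis=0), axes=(1, 2))
print(np.allclose(out_rotated_input, expected))
```

In this toy case a single scalar field is mapped to one regular field with 4 channels, one per rotation. With `N=8` as in the snippet above, each regular field has 8 channels.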

Dependencies

The library is based on Python 3.7 and requires the following packages:

jax
equinox
jaxtyping
numpy
scipy
lie_learn
joblib
py3nj

Optional:

pymanopt>=1.0.0
optax
chex

WARNING: py3nj enables fast computation of Clebsch-Gordan coefficients. If this package is not installed, the library falls back on a numerical method to estimate them. This numerical method is not guaranteed to return the same coefficients as py3nj (they can differ by a sign). For this reason, models built with and without py3nj might not be compatible.

To successfully install py3nj, you may need a Fortran compiler installed in your environment.
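
Because of this, it can help to check whether py3nj is available before building or loading a model, so that both environments match. The following is a small sketch using only the Python standard library; the helper name `is_installed` is hypothetical and not part of escnn_jax.

```python
from importlib.util import find_spec


def is_installed(module_name):
    """Return True if a top-level module can be found, without importing it."""
    return find_spec(module_name) is not None


if is_installed("py3nj"):
    print("py3nj found: Clebsch-Gordan coefficients are computed with py3nj")
else:
    print("py3nj missing: Clebsch-Gordan coefficients are estimated numerically")
```

Recording the result of this check alongside saved model weights is one way to notice a mismatch later.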

Installation

You can install the latest release as

pip install escnn_jax

or install the current version directly from the repository with

pip install git+https://github.com/QUVA-Lab/escnn_jax

Contributing

Would you like to contribute to escnn_jax? That's great!

Then, check the instructions in CONTRIBUTING.md and help us to improve the library!

Cite

The development of this library was part of the work done for our papers A Program to Build E(N)-Equivariant Steerable CNNs and General E(2)-Equivariant Steerable CNNs. Please cite these works if you use our code:


    @inproceedings{cesa2022a,
        title={A Program to Build {E(N)}-Equivariant Steerable {CNN}s},
        author={Gabriele Cesa and Leon Lang and Maurice Weiler},
        booktitle={International Conference on Learning Representations},
        year={2022},
        url={https://openreview.net/forum?id=WE4qe9xlnQw}
    }

    @inproceedings{e2cnn,
        title={{General E(2)-Equivariant Steerable CNNs}},
        author={Weiler, Maurice and Cesa, Gabriele},
        booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
        year={2019}
    }

Feel free to contact us.

License

escnn_jax is distributed under the BSD Clear license. See the LICENSE file.