E(n)-equivariant Steerable CNNs (escnn)
🚀 ~20% faster than the PyTorch version (preliminary; proper benchmarking is still on the TODO list)
escnn_jax is a JAX port of the PyTorch escnn library for equivariant deep learning. escnn_jax supports steerable CNNs equivariant to both 2D and 3D isometries, as well as equivariant MLPs.
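For reference, here is the equivariance property these networks satisfy, written in the standard steerable-CNN formulation of the papers cited at the end of this README (the notation is our own shorthand). A feature field $f$ of type $\rho$ transforms under a translation $t \in \mathbb{R}^n$ and a point-group element $g \in G \le \mathrm{O}(n)$ as

$$
\big[(t,g)\cdot f\big](x) = \rho(g)\, f\!\left(g^{-1}(x - t)\right).
$$

A layer $\Phi$ is equivariant if $\Phi\big((t,g)\cdot f\big) = (t,g)\cdot \Phi(f)$. A convolution from $\rho_\text{in}$-fields to $\rho_\text{out}$-fields has this property when its kernel $k:\mathbb{R}^n \to \mathbb{R}^{c_\text{out}\times c_\text{in}}$ satisfies

$$
k(gx) = \rho_\text{out}(g)\, k(x)\, \rho_\text{in}(g)^{-1} \qquad \forall\, g \in G,\ x \in \mathbb{R}^n .
$$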
The library is structured into four subpackages with different high-level features:
| Component | Dependency | Description |
|---|---|---|
| `escnn_jax.group` | Pure Python | implements basic concepts of group and representation theory |
| `escnn_jax.gspaces` | Pure Python | defines the Euclidean spaces and their symmetries |
| `escnn_jax.kernels` | JAX | solves for spaces of equivariant convolution kernels |
| `escnn_jax.nn` | Equinox | contains equivariant modules to build deep neural networks |
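The following pure-NumPy sketch gives a feel for the kind of objects `escnn_jax.group` and `escnn_jax.kernels` handle. It is independent of the library, and the helpers `regular_rep` and `equivariant_basis` are hypothetical names, not escnn_jax API. It builds the regular representation of the cyclic group C_N as permutation matrices. It then solves the linear equivariance constraint ρ_out(g) W = W ρ_in(g) through a null-space computation. That is the simpler constraint behind equivariant linear layers (as in equivariant MLPs). Solving for convolution kernels also has to account for how the group acts on space.

```python
import numpy as np

def regular_rep(N, k):
    """Regular representation of the rotation r^k in the cyclic group C_N.

    It is the N x N permutation matrix sending basis vector e_j to e_{(j + k) mod N}.
    """
    return np.roll(np.eye(N), k, axis=0)

def equivariant_basis(rho_in, rho_out, n_in, n_out, tol=1e-9):
    """Basis of all W (n_out x n_in) with rho_out[i] @ W == W @ rho_in[i] for every i.

    With row-major vectorisation, vec(A @ W @ B) = kron(A, B.T) @ vec(W), so the
    constraints form one linear system whose null space (found via SVD) is the answer.
    """
    blocks = [np.kron(ro, np.eye(n_in)) - np.kron(np.eye(n_out), ri.T)
              for ri, ro in zip(rho_in, rho_out)]
    A = np.concatenate(blocks, axis=0)
    _, s, vt = np.linalg.svd(A)
    rank = int(np.sum(s > tol))
    return vt[rank:].reshape(-1, n_out, n_in)

# homomorphism property in C_8: rho(r^3) rho(r^6) = rho(r^9) = rho(r^1)
assert np.allclose(regular_rep(8, 3) @ regular_rep(8, 6), regular_rep(8, 1))

# constraints for the generator r suffice: every element of C_4 is a power of r
P = regular_rep(4, 1)
trivial = np.eye(1)

reg_to_reg = equivariant_basis([P], [P], 4, 4)
triv_to_reg = equivariant_basis([trivial], [P], 1, 4)
print(reg_to_reg.shape)   # (4, 4, 4): a 4-dimensional space of equivariant maps
print(triv_to_reg.shape)  # (1, 4, 1): only constant vectors are equivariant
```

The 4-dimensional solution space for regular-to-regular maps is the space of 4 x 4 circulant matrices, which are exactly the matrices commuting with a cyclic shift.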
TODOs
Priority
- reproduce examples and baselines
  - `mlp.ipynb`
    - apart from the `IIDBatchNorm1d` module
  - `introduction.ipynb`
  - `model.ipynb`
  - `octahedral_cnn.ipynb`
- mimic PyTorch's `requires_grad=False` for 'buffer' variables, to avoid including them in `opt_state` and `grads`
  - added to `EquivariantModule` the methods `set_buffer` and `get_buffer`, which wrap the variable in `lax.stop_gradient`
  - added to `EquivariantModule` the methods `set_parameter` and `get_parameter`, which wrap the array in a custom type `escnn_jax.nn.ParameterArray` that can later be used to filter the parameters
- enhance `model.eval()` behaviour; make `EquivariantModule.eval` recursively call submodules?
- speed up modules' `__init__`, e.g. `nn.Linear` and `nn.R2Conv`
- speed up modules' `__call__`, if possible?
- better `__repr__` for `EquivariantModule` and, more generally, `eqx.nn.Module`
- make sure that tests pass for implemented modules and kernels
- bug? `InnerBatchNorm.eval()` without training returns high values
- add an `export` method for layers
- properly measure the speed-up with respect to the PyTorch version
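The buffer/parameter item in the list above combines two ideas, sketched here in plain JAX. The class `Param` is a hypothetical stand-in for `escnn_jax.nn.ParameterArray`. How it is used (splitting a state dict into `params` and `buffers`) is an assumption for illustration, not the library's actual code. The sketch shows two separate effects:
- `jax.lax.stop_gradient` makes the gradient with respect to a buffer zero.
- Filtering on the marker type keeps buffers out of `grads`, and so out of any optimizer state.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Param:
    """Hypothetical marker type flagging a trainable array (stand-in for ParameterArray)."""
    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

state = {
    "weight": Param(jnp.array([1.0, 2.0])),  # trainable
    "running_mean": jnp.array([0.5, 0.5]),   # buffer: running statistics, not trained
}

# filter on the marker type: only Param entries are differentiated (and optimised)
params = {k: v for k, v in state.items() if isinstance(v, Param)}
buffers = {k: v for k, v in state.items() if not isinstance(v, Param)}

def loss(params, buffers, x):
    # stop_gradient blocks any gradient flowing into the buffer
    mean = jax.lax.stop_gradient(buffers["running_mean"])
    return jnp.sum((x - mean) * params["weight"].value)

x = jnp.array([1.0, 3.0])
grads = jax.grad(loss)(params, buffers, x)                    # only has "weight": x - mean
buffer_grads = jax.grad(loss, argnums=1)(params, buffers, x)  # all zeros, due to stop_gradient
```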
Nice to have
- add support for `haiku`/`flax` under `escnn.nn.haiku`/`escnn.nn.flax`
- `jaxlinop` for the `Representation` class, akin to `emlp`, and more generally rewrite `escnn_jax.group` in `jax`?
- add missing modules (cf. `/nn/__init__.py`)
Getting Started
escnn_jax is easy to use since it provides a high-level user interface that abstracts away most intricacies of group and representation theory. The following code snippet shows how to perform an equivariant convolution from an RGB image to 10 regular feature fields (corresponding to a group convolution).
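```python
from escnn_jax import gspaces
from escnn_jax import nn
import jax

# JAX uses explicit PRNG keys: one to initialise the convolution, one to sample the input
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key, 2)

# the symmetry group: the 8 discrete rotations acting on the plane R^2
r2_act = gspaces.rot2dOnR2(N=8)
# input: the 3 RGB channels, each a scalar (trivial) field
feat_type_in = nn.FieldType(r2_act, 3*[r2_act.trivial_repr])
# output: 10 regular feature fields of C_8, i.e. 10 x 8 = 80 channels
feat_type_out = nn.FieldType(r2_act, 10*[r2_act.regular_repr])

conv = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=5, key=key1)
relu = nn.ReLU(feat_type_out)

# a batch of 16 random 32x32 RGB images, wrapped together with their field type
x = jax.random.normal(key2, (16, 3, 32, 32))
x = feat_type_in(x)

# equivariant convolution followed by an equivariant non-linearity
y = relu(conv(x))
```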
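The property the snippet above relies on can be checked by hand on a simpler case. The following NumPy-only sketch is a conceptual illustration, not the library's implementation: escnn_jax solves for a space of equivariant kernels instead of rotating filters. The helpers `correlate2d_valid` and `c4_lifting_conv` are hypothetical. The sketch builds a C_4 "lifting" group convolution by correlating an image with the four 90° rotations of a single filter. This yields one regular field with 4 channels. It uses C_4 instead of C_8 because 90° rotations are exact on the pixel grid. Rotating the input by 90° should then rotate every output map and cyclically shift the 4 channels, which is how a regular field transforms.

```python
import numpy as np

def correlate2d_valid(x, w):
    """Plain 'valid' 2D cross-correlation of a square image x with a square kernel w."""
    n, k = x.shape[0], w.shape[0]
    m = n - k + 1
    out = np.empty((m, m))
    for i in range(m):
        for j in range(m):
            out[i, j] = np.sum(x[i:i + k, j:j + k] * w)
    return out

def c4_lifting_conv(x, w):
    """Correlate x with the 4 rotated copies of w, giving one regular C_4 field (4 channels)."""
    return np.stack([correlate2d_valid(x, np.rot90(w, r)) for r in range(4)])

rng = np.random.default_rng(0)
x = rng.normal(size=(9, 9))  # a single scalar (trivial) input field
w = rng.normal(size=(3, 3))  # an unconstrained filter

y = c4_lifting_conv(x, w)
y_rot = c4_lifting_conv(np.rot90(x), w)

# rotated input -> each output map rotated, and channel r takes the old channel r - 1
expected = np.rot90(np.roll(y, 1, axis=0), axes=(1, 2))
print(np.allclose(y_rot, expected))  # True
```

A pointwise ReLU commutes with both the spatial rotation and the channel permutation, so applying it after the convolution preserves equivariance.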
Dependencies
The library is based on Python 3.7 and depends on:
- jax
- equinox
- jaxtyping
- numpy
- scipy
- lie_learn
- joblib
- py3nj

Optional:
- pymanopt>=1.0.0
- optax
- chex
WARNING: `py3nj` enables a fast computation of Clebsch-Gordan coefficients. If this package is not installed, our library relies on a numerical method to estimate them. This numerical method is not guaranteed to return the same coefficients computed by `py3nj` (they can differ by a sign). For this reason, models built with and without `py3nj` might not be compatible.
To successfully install `py3nj`, you may need a Fortran compiler installed in your environment.
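A toy illustration of why a sign difference breaks compatibility follows. Assume, for the sake of the example, that a layer stores its learned weights as coefficients over a basis of equivariant kernels, and that this basis is built from Clebsch-Gordan coefficients. Then flipping the sign of one basis element changes the kernel produced by the same saved weights. The random arrays below are a hypothetical stand-in basis, not coefficients computed by either method.

```python
import numpy as np

rng = np.random.default_rng(0)

# two equally valid bases of the same (hypothetical) 3-dimensional kernel space:
# basis_b differs from basis_a only in the sign of its second element
basis_a = rng.normal(size=(3, 5, 5))
basis_b = basis_a.copy()
basis_b[1] *= -1

weights = rng.normal(size=3)  # coefficients saved from a model built with basis_a

kernel_a = np.tensordot(weights, basis_a, axes=1)  # the kernel the model was trained with
kernel_b = np.tensordot(weights, basis_b, axes=1)  # same weights loaded with the other basis
print(np.allclose(kernel_a, kernel_b))  # False: the checkpoint does not transfer as-is

# flipping the matching coefficient recovers the original kernel
fixed = weights * np.array([1.0, -1.0, 1.0])
print(np.allclose(kernel_a, np.tensordot(fixed, basis_b, axes=1)))  # True
```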
Installation
You can install the latest release with
```shell
pip install escnn_jax
```
or install it directly from the GitHub repository with
```shell
pip install git+https://github.com/QUVA-Lab/escnn_jax
```
Contributing
Would you like to contribute to escnn_jax? That's great!
Then, check the instructions in CONTRIBUTING.md and help us to improve the library!
Cite
The development of this library was part of the work done for our papers A Program to Build E(N)-Equivariant Steerable CNNs and General E(2)-Equivariant Steerable CNNs. Please cite these works if you use our code:
```bibtex
@inproceedings{cesa2022a,
    title={A Program to Build {E(N)}-Equivariant Steerable {CNN}s},
    author={Gabriele Cesa and Leon Lang and Maurice Weiler},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=WE4qe9xlnQw}
}

@inproceedings{e2cnn,
    title={{General E(2)-Equivariant Steerable CNNs}},
    author={Weiler, Maurice and Cesa, Gabriele},
    booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
    year={2019},
}
```
Feel free to contact us.
License
escnn_jax is distributed under the BSD Clear license. See the LICENSE file.