Awesome
JAX bindings to FINUFFT
This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to directly expose the FINUFFT library to JAX's XLA backend, as well as implementing differentiation rules for the transforms.
Included features
This library includes CPU and GPU (CUDA) support. GPU support is implemented through the cuFINUFFT interface of the FINUFFT library.
Type 1 and 2 transforms
are supported in 1, 2, and 3 dimensions on the CPU, and 2 and 3 dimensions on the GPU.
All of these functions support forward, reverse, and higher-order differentiation,
as well as batching using vmap
.
[!NOTE] The GPU backend does not currently support 1D (#125).
Installation
The easiest ways to install jax-finufft is to install a pre-compiled binary from PyPI or conda-forge, but if you need GPU support or want to get tuned performance, you'll want to follow the instructions to install from source as described below.
Install binary from PyPI
[!NOTE] Only the CPU-enabled build of jax-finufft is available as a binary wheel on PyPI. For a GPU-enabled build, you'll need to build from source as described below.
To install a binary wheel from PyPI using pip, run the following commands:
python -m pip install "jax[cpu]"
python -m pip install jax-finufft
If this fails, you may need to use a conda-forge binary, or install from source.
Install binary from conda-forge
[!NOTE] Only the CPU-enabled build of jax-finufft is available as a binary from conda-forge. For a GPU-enabled build, you'll need to build from source as described below.
To install using mamba (or conda), run:
mamba install -c conda-forge jax-finufft
Install from source
Dependencies
Unsurprisingly, a key dependency is JAX, which can be installed following the directions in the JAX documentation. If you're going to want to run on a GPU, make sure that you install the appropriate JAX build.
The non-Python dependencies that you'll need are:
Older versions of CUDA may work, but they are untested.
Below we provide some example workflows for installing the required dependencies:
<details> <summary>Install CPU dependencies with mamba or conda</summary>mamba create -n jax-finufft -c conda-forge python jax fftw cxx-compiler
mamba activate jax-finufft
</details>
<details>
<summary>Install GPU dependencies with mamba or conda</summary>
For a GPU build, while the CUDA libraries and compiler are nominally available through conda, our experience trying to install them this way suggests that the "traditional" way of obtaining the CUDA Toolkit directly from NVIDIA may work best (see related advice for Horovod). After installing the CUDA Toolkit, one can set up the rest of the dependencies with:
mamba create -n gpu-jax-finufft -c conda-forge python numpy scipy fftw 'gxx<12'
mamba activate gpu-jax-finufft
export CMAKE_PREFIX_PATH=$CONDA_PREFIX:$CMAKE_PREFIX_PATH
python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Other ways of installing JAX are given on the JAX website; the "local CUDA" install methods are preferred for jax-finufft as this ensures the CUDA extensions are compiled with the same Toolkit version as the CUDA runtime. However, this is not required as long as both JAX and jax-finufft use CUDA with the same major version.
</details> <details> <summary>Install GPU dependencies using Flatiron module system</summary>ml modules/2.3 \
gcc \
python/3.11 \
fftw \
cuda/12
export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"
</details>
Configuring the build
There are several important CMake variables that control aspects of the jax-finufft and (cu)finufft builds. These include:
JAX_FINUFFT_USE_CUDA
[disabled by default]: build with GPU supportCMAKE_CUDA_ARCHITECTURES
[defaultnative
]: the target GPU architecture.native
means the GPU arch of the build system.FINUFFT_ARCH_FLAGS
[default-march=native
]: the target CPU architecture. The default is the native CPU arch of the build system.
Each of these can be set as -Ccmake.define.NAME=VALUE
arguments to pip install
. For example,
to build with GPU support from the repo root, run:
pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON .
Use multiple -C
arguments to set multiple variables. The -C
argument will work with any of the source installation methods (e.g. PyPI source dist, GitHub, etc).
Build options can also be set with the CMAKE_ARGS
environment variable. For example:
export CMAKE_ARGS="$CMAKE_ARGS -DJAX_FINUFFT_USE_CUDA=ON"
GPU build configuration
Building with GPU support requires passing JAX_FINUFFT_USE_CUDA=ON
to CMake. See Configuring the build.
By default, jax-finufft will build for the GPU of the build machine. If you need to target
a different compute capability, such as 8.0 for Ampere, set CMAKE_CUDA_ARCHITECTURES
as a CMake define:
pip install -Ccmake.define.JAX_FINUFFT_USE_CUDA=ON -Ccmake.define.CMAKE_CUDA_ARCHITECTURES=80 .
CMAKE_CUDA_ARCHITECTURES
also takes a semicolon-separated list.
To detect the arch for a specific GPU, one can run:
$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
8.0
The values are also listed on the NVIDIA website.
In some cases, you may also need the following at runtime:
export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
If CUDA_PATH
isn't set, you'll need to replace it with the path to your CUDA
installation in the above line, often something like /usr/local/cuda
.
Install source from PyPI
The source code for all released versions of jax-finufft are available on PyPI, and this can be installed using:
python -m pip install --no-binary jax-finufft
Install source from GitHub
Alternatively, you can check out the source repository from GitHub:
git clone --recurse-submodules https://github.com/flatironinstitute/jax-finufft
cd jax-finufft
[!NOTE] Don't forget the
--recurse-submodules
argument when cloning the repo because the upstream FINUFFT library is included as a git submodule. If you do forget, you can rungit submodule update --init --recursive
in your local copy to checkout the submodule after the initial clone.
After cloning the repository, you can install the local copy using:
python -m pip install -e .
where the -e
flag optionally runs an "editable" install.
As yet another alternative, the latest development version from GitHub can be installed directly (i.e. without cloning first) with
python -m pip install git+https://github.com/flatironinstitute/jax-finufft.git
Usage
This library provides two high-level functions (and these should be all that you
generally need to interact with): nufft1
and nufft2
(for the two "types" of
transforms). If you're already familiar with the Python
interface to FINUFFT,
please note that the function signatures here are different!
For example, here's how you can do a 1-dimensional type 1 transform (only works on CPU):
import numpy as np
from jax_finufft import nufft1
M = 100000
N = 200000
x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)
Noting that the eps
and iflag
are optional, and that (for good reason, I
promise!) the order of the positional arguments is reversed from the finufft
Python package.
The syntax for a 2-, or 3-dimensional transform (CPU or GPU) is:
f = nufft1((Nx, Ny), c, x, y) # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z) # 3D
The syntax for a type 2 transform is (also allowing optional iflag
and eps
parameters):
c = nufft2(f, x) # 1D
c = nufft2(f, x, y) # 2D
c = nufft2(f, x, y, z) # 3D
All of these functions support batching using vmap
, and forward and reverse
mode differentiation.
Selecting a platform
If you compiled jax-finufft with GPU support, you can force it to use a particular
backend by setting the environment variable JAX_PLATFORMS=cpu
or JAX_PLATFORMS=cuda
.
Advanced usage
The tuning parameters for the library can be set using the opts
parameter to
nufft1
and nufft2
. For example, to explicitly set the CPU up-sampling
factor that FINUFFT should
use, you can update the example from above as follows:
from jax_finufft import options
opts = options.Opts(upsampfac=2.0)
nufft1(N, c, x, opts=opts)
The corresponding option for the GPU is gpu_upsampfac
. In fact, all options
for the GPU are prefixed with gpu_
.
One complication here is that the vector-Jacobian
product
for a NUFFT requires evaluating a NUFFT of a different type. This means that you
might want to separately tune the options for the forward and backward pass.
This can be achieved using the options.NestedOpts
interface. For example, to
use a different up-sampling factor for the forward and backward passes, the code
from above becomes:
import jax
opts = options.NestedOpts(
forward=options.Opts(upsampfac=2.0),
backward=options.Opts(upsampfac=1.25),
)
jax.grad(lambda args: nufft1(N, *args, opts=opts).real.sum())((c, x))
or, in this case equivalently:
opts = options.NestedOpts(
type1=options.Opts(upsampfac=2.0),
type2=options.Opts(upsampfac=1.25),
)
See the FINUFFT docs for
descriptions of all the CPU tuning parameters. The corresponding GPU parameters
are currently only listed in source code form in
cufinufft_opts.h
.
Similar libraries
- finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
- mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.
License & attribution
This package, developed by Dan Foreman-Mackey is licensed under the Apache License, Version 2.0, with the following copyright:
Copyright 2021-2024 The Simons Foundation, Inc.
If you use this software, please cite the primary references listed on the FINUFFT docs.