Unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.

Unxt supports JAX's compelling features, including JIT compilation, vectorization (`vmap`), and automatic differentiation (`grad`).

Best of all, unxt doesn't force you to use special unit-compatible re-exports of JAX libraries. You can use unxt with existing JAX code: with Quax's simple `quaxify` decorator, JAX functions work with `unxt.Quantity`.


Installation

```bash
pip install unxt
```
<details>
  <summary>using <code>uv</code></summary>

```bash
uv add unxt
```

</details>
<details>
  <summary>from source, using pip</summary>

```bash
pip install git+https://github.com/GalacticDynamics/unxt.git
```

</details>
<details>
  <summary>building from source</summary>

```bash
cd /path/to/parent
git clone https://github.com/GalacticDynamics/unxt.git
cd unxt
pip install -e .  # editable mode
```

</details>



Quick example

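```python
import jax
import jax.numpy as jnp
import unxt as u
jax.config.update("jax_enable_x64", True)  # the outputs shown assume 64-bit mode
x = u.Quantity(jnp.arange(1, 5, dtype=float), "km")
print(x)
# Quantity['length'](Array([1., 2., 3., 4.], dtype=float64), unit='km')
```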

The constituent value and unit are accessible as attributes:

```python
x.value
# Array([1., 2., 3., 4.], dtype=float64)

x.unit
# Unit("km")
```

Quantity objects obey the rules of unitful arithmetic.

```python
# Addition / Subtraction
print(x + x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='km')

# Multiplication / Division
print(2 * x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='km')

y = u.Quantity(jnp.arange(4, 8, dtype=float), "yr")

print(x / y)
# Quantity['speed'](Array([0.25      , 0.4       , 0.5       , 0.57142857], dtype=float64), unit='km / yr')

# Exponentiation
print(x**2)
# Quantity['area'](Array([ 1.,  4.,  9., 16.], dtype=float64), unit='km2')

# Unit checking on operations
try:
    x + y
except Exception as e:
    print(e)
# 'yr' (time) and 'km' (length) are not convertible
```

Quantities can be converted to different units:

```python
print(u.uconvert("m", x))  # via function
# Quantity['length'](Array([1000., 2000., 3000., 4000.], dtype=float64), unit='m')

print(x.uconvert("m"))  # via method
# Quantity['length'](Array([1000., 2000., 3000., 4000.], dtype=float64), unit='m')
```

Since Quantity is parametric, it can do runtime dimension checking!

```python
LengthQuantity = u.Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](Array(2, dtype=int64, weak_type=True), unit='km')

try:
    LengthQuantity(2, "s")
except ValueError as e:
    print(e)
# Physical type mismatch.
```

unxt is built on Quax, which enables custom array-like objects in JAX. For convenience we use the `quaxed` library, which simply wraps JAX's functions in `quax.quaxify` to avoid boilerplate code.


Using `quaxed` is optional. You can use `quax.quaxify` directly, and even apply it to a top-level function instead of to individual functions.

```python
from quaxed import grad, vmap
import quaxed.numpy as qnp

print(qnp.square(x))
# Quantity['area'](Array([ 1.,  4.,  9., 16.], dtype=float64), unit='km2')

print(qnp.power(x, 3))
# Quantity['volume'](Array([ 1.,  8., 27., 64.], dtype=float64), unit='km3')

print(vmap(grad(lambda x: x**3))(x))
# Quantity['area'](Array([ 3., 12., 27., 48.], dtype=float64), unit='km2')
```

See the documentation for more examples and for details on JIT compilation and automatic differentiation (AD).



If you found this library to be useful and want to support the development and maintenance of lower-level code libraries for the scientific community, please consider citing this work.


We welcome contributions!

