Dataclass Array

A DataclassArray is a dataclass that behaves like a numpy array (it can be batched, reshaped, sliced, ...). It is compatible with Jax, TensorFlow, and numpy (torch support is planned).

This reduces boilerplate and improves readability. See the motivating examples section below.

To view an example of dataclass arrays used in practice, see visu3d.

Documentation

Definition

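To create a dca.DataclassArray, take a frozen dataclass, inherit from dca.DataclassArray, and annotate the array fields with types from dataclass_array.typing to specify each field's inner shape and dtype: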

import dataclass_array as dca
from dataclass_array.typing import FloatArray


class Ray(dca.DataclassArray):
  pos: FloatArray['*batch_shape 3']
  dir: FloatArray['*batch_shape 3']

Usage

Afterwards, the dataclass can be used as a numpy array:

import jax
import jax.numpy as jnp

ray = Ray(pos=jnp.zeros((3, 3)), dir=jnp.eye(3))


ray.shape == (3,)  # 3 rays batched together
ray.pos.shape == (3, 3)  # Individual fields still available

# Numpy slicing/indexing/masking
ray = ray[..., 1:2]
ray = ray[jnp.linalg.norm(ray.dir, axis=-1) > 1e-7]  # Mask has shape `ray.shape`

# Shape transformation
ray = ray.reshape((1, 3))
ray = ray.reshape('h w -> w h')  # Native einops support
ray = ray.flatten()

# Stack multiple dataclass arrays together
ray = dca.stack([ray0, ray1, ...])

# Supports TF, Jax, Numpy (torch planned) and can be easily converted
ray = ray.as_jax()  # as_np(), as_tf()
ray.xnp == jax.numpy  # `numpy`, `jax.numpy`, `tf.experimental.numpy`

# Compatibility with `jax.tree_util`, `jax.vmap`,...
ray = jax.tree_util.tree_map(lambda x: x+1, ray)
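
To make the relation between `ray.shape` and the field shapes concrete, here is a minimal NumPy-only sketch. It assumes the batch shape is what remains after stripping each field's fixed inner shape (here `(3,)`) from the end, and that all fields share the same batch shape. `infer_batch_shape` is a hypothetical helper written for illustration; it is not part of the `dataclass_array` API.

```python
import numpy as np


def infer_batch_shape(fields, inner_shapes):
  """Returns the leading shape shared by all fields (hypothetical helper)."""
  batch_shapes = set()
  for name, array in fields.items():
    inner = inner_shapes[name]
    split = array.ndim - len(inner)
    # The trailing dimensions must match the declared inner shape
    if split < 0 or array.shape[split:] != inner:
      raise ValueError(
          f'{name}: expected trailing shape {inner}, got {array.shape}'
      )
    batch_shapes.add(array.shape[:split])
  if len(batch_shapes) != 1:
    raise ValueError(f'Inconsistent batch shapes: {sorted(batch_shapes)}')
  return batch_shapes.pop()


# Same field shapes as the `Ray` example: 3 rays, each with a 3D pos and dir
fields = {'pos': np.zeros((3, 3)), 'dir': np.eye(3)}
inner_shapes = {'pos': (3,), 'dir': (3,)}
batch_shape = infer_batch_shape(fields, inner_shapes)
```

With these inputs `batch_shape` is `(3,)`, matching `ray.shape` above, while each field keeps its full `(3, 3)` shape.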

A DataclassArray has 2 types of fields: array fields, which carry the batch shape, and static fields (everything else):

class MyArray(dca.DataclassArray):
  # Array fields
  a: FloatArray['*batch_shape 3']  # Defined by `etils.array_types`
  b: FloatArray['*batch_shape _ _']  # Dynamic shape
  c: Ray  # Nested DataclassArray (equivalent to `Ray['*batch_shape']`)
  d: Ray['*batch_shape 6']

  # Array fields explicitly defined
  e: Any = dca.field(shape=(3,), dtype=np.float32)
  f: Any = dca.field(shape=(None, None), dtype=np.float32)  # Dynamic shape
  g: Ray = dca.field(shape=(3,), dtype=Ray)  # Nested DataclassArray

  # Static field (everything not defined as above)
  static0: float
  static1: np.array

Vectorization

@dca.vectorize_method allows your dataclass methods to automatically support batching:

  1. Implement the method as if self.shape == ()
  2. Decorate the method with dca.vectorize_method

class Camera(dca.DataclassArray):
  K: FloatArray['*batch_shape 4 4']
  resolution: tuple[int, int]  # Static field (annotated, not assigned)

  @dca.vectorize_method
  def rays(self) -> Ray:
    # Inside `@dca.vectorize_method`, shape is always guaranteed to be `()`
    assert self.shape == ()
    assert self.K.shape == (4, 4)

    # Compute the ray as if there was only a single camera
    return Ray(pos=..., dir=...)

Afterward, we can generate rays for multiple cameras batched together:

cams = Camera(K=K, resolution=(h, w))  # K.shape == (num_cams, 4, 4)
rays = cams.rays()  # Generate the rays for all the cameras

cams.shape == (num_cams,)
rays.shape == (num_cams, h, w)
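
One way to picture the effect of the decorator is the NumPy-only sketch below. It assumes the decorated method is applied to every element of the batch and the results are stacked back into the original batch shape. `vectorize_over_batch` and `first_column` are hypothetical names used for illustration; this is not how `dataclass_array` is implemented.

```python
import numpy as np


def vectorize_over_batch(fn, array, inner_ndim):
  """Applies `fn` to each batch element, then restores the batch shape."""
  split = array.ndim - inner_ndim
  batch_shape, inner_shape = array.shape[:split], array.shape[split:]
  flat = array.reshape((-1,) + inner_shape)  # Collapse the batch into one axis
  out = np.stack([np.asarray(fn(x)) for x in flat])  # `fn` sees one element
  return out.reshape(batch_shape + out.shape[1:])


def first_column(K):
  # Written as if there was a single camera: K.shape == (4, 4)
  assert K.shape == (4, 4)
  return K[:, 0]


K = np.broadcast_to(np.eye(4), (2, 5, 4, 4))  # A (2, 5) batch of 4x4 matrices
columns = vectorize_over_batch(first_column, K, inner_ndim=2)
```

Here `columns.shape` is `(2, 5, 4)`: the `(2, 5)` batch shape is preserved and the per-element output shape `(4,)` is appended, just as `rays.shape` extends `cams.shape` above.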

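@dca.vectorize_method is similar to jax.vmap, but it vectorizes over the whole of self.shape rather than a single axis, and batch dimensions of size 1 are broadcast between self and the arguments.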

For example, with __matmul__(self, x: T) -> T:

() @ (*x,) -> (*x,)
(b,) @ (b, *x) -> (b, *x)
(b,) @ (1, *x) -> (b, *x)
(1,) @ (b, *x) -> (b, *x)
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
(a, *x) @ (b, *x) -> Error: Incompatible a != b
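
The shape rule in the table can be checked with a short NumPy sketch. It assumes the first `len(self.shape)` dimensions of the argument are treated as batch dimensions and broadcast against `self.shape`, with the remaining dimensions passed through unchanged. `matmul_result_shape` is a hypothetical helper derived from the table, not a `dataclass_array` function.

```python
import numpy as np


def matmul_result_shape(self_shape, x_shape):
  """Output shape of `self @ x` under the rule read from the table above."""
  n = len(self_shape)
  x_batch, x_inner = x_shape[:n], x_shape[n:]
  # Raises ValueError when batch dimensions are incompatible (e.g. 2 vs 3)
  batch = np.broadcast_shapes(self_shape, x_batch)
  return batch + x_inner


shape_scalar = matmul_result_shape((), (3,))  # () @ (*x,)
shape_bcast = matmul_result_shape((4,), (1, 3))  # (b,) @ (1, *x)
shape_mixed = matmul_result_shape((1, 6, 7), (4, 1, 1, 3))  # (1, h, w) @ (b, 1, 1, *x)
```

With `*x == (3,)`, `b == 4`, `h == 6` and `w == 7`, the three results are `(3,)`, `(4, 3)` and `(4, 6, 7, 3)`. Calling `matmul_result_shape((2,), (3, 3))` raises a `ValueError`, which corresponds to the last row of the table.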

To test on Colab, see the visu3d dataclass Colab tutorial.

Motivating examples

dca.DataclassArray improves readability by simplifying common patterns.

Installation

pip install dataclass_array

This is not an official Google product.