JAX Scalify: end-to-end scaled arithmetic

Installation | Quickstart | Documentation

📣 Scalify has been accepted to the ICML 2024 workshop WANT! 📣

JAX Scalify is a library implementing end-to-end scale propagation and scaled arithmetic, allowing easy training and inference of deep neural networks in low precision (BF16, FP16, FP8).

Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Most of these works focus on ad-hoc approaches to scaling matrix multiplications (and sometimes reduction operations). Scalify adopts a more systematic approach with end-to-end scale propagation, i.e. transforming the full computational graph into a ScaledArray graph where every operation takes ScaledArray inputs and returns ScaledArray outputs:

from dataclasses import dataclass

from jax import Array


@dataclass
class ScaledArray:
    # Main data component, in low precision.
    data: Array
    # Scale, usually scalar, in FP32 or E8M0.
    scale: Array

    def __array__(self) -> Array:
        # Tensor represented as a `ScaledArray`.
        return self.data * self.scale.astype(self.data.dtype)
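To illustrate why carrying the scale separately helps, here is a minimal sketch in plain Python (no JAX dependency, so it stays self-contained). The names `ToyScaled`, `to_fp16` and `scaled_mul` are hypothetical and are not part of the jax-scalify API; FP16 rounding is emulated with the standard `struct` module. It shows an elementwise product that underflows to zero in plain FP16 but stays representable once the scales are multiplied on the side.

```python
import struct
from dataclasses import dataclass


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest FP16 value (overflow becomes +/-inf)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


@dataclass
class ToyScaled:
    # Hypothetical stand-in for a scaled array: FP16-rounded data plus one shared scale.
    data: list
    scale: float

    def to_float(self) -> list:
        # Values represented by the pair (data, scale).
        return [d * self.scale for d in self.data]


def scaled_mul(a: ToyScaled, b: ToyScaled) -> ToyScaled:
    # Data is multiplied in (emulated) FP16; scales are multiplied separately.
    data = [to_fp16(x * y) for x, y in zip(a.data, b.data)]
    return ToyScaled(data, a.scale * b.scale)


a = ToyScaled([1.5, 1.0], 2.0 ** -18)
b = ToyScaled([1.25, 0.5], 2.0 ** -18)

# Naive FP16: materialize the small values first, then multiply.
naive = [to_fp16(to_fp16(x) * to_fp16(y)) for x, y in zip(a.to_float(), b.to_float())]
# Scaled arithmetic: the product keeps its mantissa, the scale absorbs the magnitude.
scaled = scaled_mul(a, b)

print("naive FP16 product:", naive)
print("scaled product:", scaled.data, "x", scaled.scale)
```

In the naive path both products fall below the smallest FP16 subnormal and are flushed to zero, while the scaled path keeps the data at order one and pushes the magnitude into the scale.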

The main benefits of the scalify approach are:

Installation

JAX Scalify can be installed directly from PyPI:

pip install jax-scalify

Please follow the JAX documentation for a proper JAX installation on GPU/TPU.

The latest version of JAX Scalify is available directly from GitHub:

pip install git+https://github.com/graphcore-research/jax-scalify.git
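After either install command, one way to confirm which version ended up in the environment is to query the package metadata. This is a small sketch using only the standard library; the helper name `installed_version` is hypothetical, and the distribution name `jax-scalify` is taken from the `pip install` command above.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


version = installed_version("jax-scalify")
print("jax-scalify:", version if version is not None else "not installed")
```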

Quickstart

A typical JAX training loop just requires a couple of modifications to take advantage of scalify. More specifically:

The following (simplified) example shows how scalify can be incorporated into a JAX training loop.

import jax
import jax_scalify as jsa

# Scalify transform on FWD + BWD + optimizer.
# Propagating scale in the computational graph.
@jsa.scalify
def update(state, data, labels):
    # Forward and backward pass on the NN model.
    loss, grads = jax.value_and_grad(model)(state, data, labels)
    # Optimizer applied on scaled state.
    state = optimizer.apply(state, grads)
    return loss, state

# Model + optimizer state.
state = (model.init(...), optimizer.init(...))
# Transform state to scaled array(s)
sc_state = jsa.as_scaled_array(state)

for (data, labels) in dataset:
    # If necessary (e.g. images), scale input data.
    data = jsa.as_scaled_array(data)
    # State update, with full scale propagation.
    loss, sc_state = update(sc_state, data, labels)
    # Optional dynamic rescaling of state.
    sc_state = jsa.ops.dynamic_rescale_l2(sc_state)

As presented in the code above, the model state is represented as a JAX PyTree of ScaledArray, propagated end-to-end through the model (forward and backward passes) as well as the optimizer.
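The optional dynamic rescaling step in the loop above can be pictured as moving the magnitude of the data into the scale without changing the value represented. The sketch below is one plausible reading of an L2-based rescale, written in plain Python; the helpers `pow2_round` and `rescale_l2` are hypothetical, and this is not the implementation of `jsa.ops.dynamic_rescale_l2`. It assumes the scale is kept as a power of two so that the division of the data is exact.

```python
import math


def pow2_round(x: float) -> float:
    """Round a positive float to the nearest power of two (nearest in log2 space)."""
    return 2.0 ** round(math.log2(x))


def rescale_l2(data, scale):
    """Move the RMS magnitude of `data` into `scale`, keeping data * scale unchanged."""
    rms = math.sqrt(sum(d * d for d in data) / len(data))
    if rms == 0.0:
        # Nothing to normalize: leave an all-zero tensor untouched.
        return list(data), scale
    factor = pow2_round(rms)
    # Dividing by a power of two only changes the exponent, so no mantissa bits are lost
    # (as long as the result stays in the normal range).
    return [d / factor for d in data], scale * factor


data, scale = [1024.0, -2048.0, 512.0, 4096.0], 1.0
new_data, new_scale = rescale_l2(data, scale)

print("data :", new_data)
print("scale:", new_scale)
```

After the call, the data sits close to unit magnitude, which is where low-precision formats have the most headroom against both overflow and underflow, while `data * scale` still reproduces the original values.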

A full collection of examples is available:

Documentation

Development

For a local development setup, we recommend an editable install:

git clone git@github.com:graphcore-research/jax-scalify.git
cd jax-scalify
pip install -e ./

Running pre-commit and pytest on the JAX Scalify repository:

pip install pre-commit
pre-commit run --all-files
pytest -v ./tests

A Python wheel can be built with the usual command: python -m build.