
<div align="center"> <img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo" /> <h1>TessellateIPU Library</h1> </div>



Features | Installation guide | Quickstart | Documentation | Projects

:red_circle: :warning: Non-official Graphcore Product :warning: :red_circle:

TessellateIPU is a library that brings low-level Poplar IPU programming to Python ML frameworks (currently JAX, with PyTorch support planned).

The package is maintained by the Graphcore Research team. Expect bugs and sharp edges! Please let us know what you think!

Features

TessellateIPU brings low-level Poplar IPU programming to Python while remaining fully compatible with the standard APIs of ML frameworks.

The TessellateIPU API allows easy and efficient implementation of algorithms on IPUs, while keeping compatibility with other backends (CPU, GPU, TPU). For more details on the API, please refer to the TessellateIPU documentation, or try it on IPU Paperspace Gradient.
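To build intuition for the programming model, here is a plain NumPy analogy (an assumption for illustration, not the real TessellateIPU API): `tile_put_sharded` shards an array along its first axis, one slice per tile, and `tile_map` applies a primitive independently on each tile's slice.

```python
import numpy as np

# NumPy analogy only (hypothetical sketch, not the TessellateIPU API):
# a "sharded" array is modelled as a dict mapping tile id -> slice.
def sketch_tile_put_sharded(data, tiles):
    # One slice of the first axis per tile.
    assert data.shape[0] == len(tiles)
    return {tile: data[i] for i, tile in enumerate(tiles)}

def sketch_tile_map(fn, *sharded_args):
    # Apply fn independently on each tile's slice.
    tiles = sharded_args[0].keys()
    return {t: fn(*(arg[t] for arg in sharded_args)) for t in tiles}

tiles = (0, 1, 3)
x = sketch_tile_put_sharded(np.arange(6.0).reshape(3, 2), tiles)
y = sketch_tile_map(np.negative, x)  # per-tile elementwise negation
```

On a real IPU, each slice lives in a tile's local SRAM and the mapped primitive runs on that tile; the dict here only mimics the tile-to-data assignment.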

Installation guide

This package requires JAX IPU experimental (available for Python 3.8 and Poplar SDK versions 3.1 or 3.2). For Poplar SDK 3.2:

pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html

Change `sdk320` to `sdk310` if you are using Poplar SDK 3.1.

Since TessellateIPU is a pure Python package, it can then be installed directly from GitHub using pip:

pip install git+https://github.com/graphcore-research/tessellate-ipu.git@main

Note: `main` can be replaced with any tag (`v0.1`, ...) or commit hash to install a specific version.

Local pip install is also supported after cloning the GitHub repository:

git clone git@github.com:graphcore-research/tessellate-ipu.git
pip install ./tessellate-ipu

Minimal example

The following is a simple example showing how to set the tile mapping of JAX arrays, and run a JAX LAX operation on these tiles.

import numpy as np
import jax
from tessellate_ipu import tile_put_sharded, tile_map

# Which IPU tiles do we want to use?
tiles = (0, 1, 3)

@jax.jit
def compute_fn(data0, data1):
    # Shard the arrays across the tiles, along the first axis.
    input0 = tile_put_sharded(data0, tiles)
    input1 = tile_put_sharded(data1, tiles)
    # Map a JAX LAX primitive on tiles.
    output = tile_map(jax.lax.add_p, input0, input1)
    return output

data = np.random.rand(len(tiles), 2, 3).astype(np.float32)
output = compute_fn(data, 3 * data)

print("Output:", output)
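As a quick sanity check of the arithmetic (runnable anywhere, no IPU required): `compute_fn` adds `data` and `3 * data` elementwise on each tile, so the result should equal `4 * data`.

```python
import numpy as np

# Verify the expected numerical result of the example above:
# data + 3 * data == 4 * data, elementwise, for every tile's slice.
data = np.random.rand(3, 2, 3).astype(np.float32)
expected = data + 3 * data
assert np.allclose(expected, 4 * data)
```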

Useful environment variables and flags

JAX IPU experimental flags, set after `from jax.config import config`:

| Flag | Description |
| --- | --- |
| `config.FLAGS.jax_platform_name = 'ipu'/'cpu'` | Configure the default JAX backend. Useful for CPU initialization. |
| `config.FLAGS.jax_ipu_use_model = True` | Use the IPU model emulator. |
| `config.FLAGS.jax_ipu_model_num_tiles = 8` | Set the number of tiles in the IPU model. |
| `config.FLAGS.jax_ipu_device_count = 2` | Set the number of IPUs visible in JAX. Can be any of the local IPUs available. |
| `config.FLAGS.jax_ipu_visible_devices = '0,1'` | Set the specific collection of local IPUs to be visible in JAX. |

Alternatively, like other JAX flags, these can be set using environment variables (for example JAX_IPU_USE_MODEL and JAX_IPU_MODEL_NUM_TILES).
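For example, the IPU model emulator can be configured through the environment variables named above, as long as they are set before JAX is imported (a minimal sketch; the values shown are illustrative):

```python
import os

# Configure the IPU model emulator via environment variables.
# This must happen before JAX is imported, since JAX reads flag
# environment variables at import time.
os.environ["JAX_IPU_USE_MODEL"] = "true"
os.environ["JAX_IPU_MODEL_NUM_TILES"] = "8"
```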

PopVision environment variables:

Documentation

Projects using TessellateIPU

License

Copyright (c) 2023 Graphcore Ltd. The project is licensed under the Apache License 2.0.

TessellateIPU is implemented using C++ custom operations. These have the following C++ libraries as dependencies, statically compiled into a shared library:

| Component | Description | License |
| --- | --- | --- |
| fastbase64 | Base64 fast decoder library | Simplified BSD (FreeBSD) License |
| fmt | A modern C++ formatting library | MIT license |
| half | IEEE-754 conformant half-precision library | MIT license |
| json | JSON for modern C++ | MIT license |
| nanobind | Tiny C++/Python bindings | BSD 3-Clause License |