<div align="center"> <img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img> <h1>TessellateIPU Library</h1> </div>

Features | Installation guide | Quickstart | Documentation | Projects
:red_circle: :warning: Non-official Graphcore Product :warning: :red_circle:
TessellateIPU is a library bringing low-level Poplar IPU programming to Python ML frameworks (JAX at the moment, and PyTorch in the near future).
The package is maintained by the Graphcore Research team. Expect bugs and sharp edges! Please let us know what you think!
Features
TessellateIPU brings low-level Poplar IPU programming to Python, while being fully compatible with ML framework standard APIs. The main features are:
- Control tile mapping of arrays using `tile_put_replicated` or `tile_put_sharded`
- Support for standard JAX LAX operations at tile level using `tile_map` (see operations supported)
- Easy integration of custom IPU C++ vertices (see vertex example)
- Access to low-level IPU hardware functionalities such as cycle count and random seed set/get
- Full compatibility with other backends
The TessellateIPU API allows easy and efficient implementation of algorithms on IPUs, while keeping compatibility with other backends (CPU, GPU, TPU). For more details on the API, please refer to the TessellateIPU documentation, or try it on IPU Paperspace Gradient.
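As a rough mental model of the two placement functions, the pure NumPy sketch below emulates their data layout on the host, under our reading of the API: `tile_put_sharded` places slice `data[i]` on tile `tiles[i]` (so the first axis must match the number of tiles), while `tile_put_replicated` places a full copy of the array on every tile. The helpers `shard_on_tiles` and `replicate_on_tiles` are hypothetical and not part of TessellateIPU; nothing here runs on an IPU.

```python
import numpy as np


def shard_on_tiles(data, tiles):
    """Hypothetical host-side model of tile sharding: slice i goes to tiles[i]."""
    if data.shape[0] != len(tiles):
        raise ValueError(
            f"First axis ({data.shape[0]}) must equal the number of tiles ({len(tiles)})."
        )
    return {tile: data[i] for i, tile in enumerate(tiles)}


def replicate_on_tiles(data, tiles):
    """Hypothetical host-side model of tile replication: every tile gets a full copy."""
    return {tile: data.copy() for tile in tiles}


tiles = (0, 1, 3)
data = np.arange(3 * 2, dtype=np.float32).reshape(3, 2)

sharded = shard_on_tiles(data, tiles)
replicated = replicate_on_tiles(data, tiles)
print({t: a.shape for t, a in sharded.items()})     # {0: (2,), 1: (2,), 3: (2,)}
print({t: a.shape for t, a in replicated.items()})  # {0: (3, 2), 1: (3, 2), 3: (3, 2)}
```

Sharding keeps the per-tile memory footprint proportional to one slice, whereas replication costs a full copy per tile; the latter is typically what you want for small shared operands.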
Installation guide
This package requires JAX IPU experimental (available for Python 3.8 and Poplar SDK versions 3.1 or 3.2). For Poplar SDK 3.2:
```bash
pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html
```
If using Poplar SDK 3.1, please change `sdk320` into `sdk310`.
As a pure Python repo, TessellateIPU can then be installed directly from GitHub using `pip`:
```bash
pip install git+https://github.com/graphcore-research/tessellate-ipu.git@main
```
Note: `main` can be replaced with any tag (`v0.1`, ...) or commit hash in order to install a specific version.
Local `pip` install is also supported after cloning the GitHub repository:
```bash
git clone git@github.com:graphcore-research/tessellate-ipu.git
pip install ./tessellate-ipu
```
Minimal example
The following is a simple example showing how to set the tile mapping of JAX arrays, and run a JAX LAX operation on these tiles.
```python
import numpy as np
import jax
from tessellate_ipu import tile_put_sharded, tile_map

# Which IPU tiles do we want to use?
tiles = (0, 1, 3)

@jax.jit
def compute_fn(data0, data1):
    # Tile sharding arrays along the first axis.
    input0 = tile_put_sharded(data0, tiles)
    input1 = tile_put_sharded(data1, tiles)
    # Map a JAX LAX primitive on tiles.
    output = tile_map(jax.lax.add_p, input0, input1)
    return output

data = np.random.rand(len(tiles), 2, 3).astype(np.float32)
output = compute_fn(data, 3 * data)
print("Output:", output)
```
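Since `tile_map` applies `jax.lax.add_p` independently to each tile's shard, the example should return `data + 3 * data`, i.e. `4 * data`, whichever backend runs it. The NumPy reference below is a sketch for sanity-checking that output; `reference_tile_add` is a hypothetical helper, not part of TessellateIPU, and it only mirrors the per-tile computation on the host.

```python
import numpy as np

tiles = (0, 1, 3)


def reference_tile_add(data0, data1, tiles):
    """Hypothetical NumPy reference: add the shards tile by tile, then restack them."""
    assert data0.shape == data1.shape and data0.shape[0] == len(tiles)
    # One shard per tile, mirroring tile sharding along the first axis.
    per_tile = [data0[i] + data1[i] for i in range(len(tiles))]
    return np.stack(per_tile)


data = np.random.rand(len(tiles), 2, 3).astype(np.float32)
expected = reference_tile_add(data, 3 * data, tiles)
# The per-tile result matches a plain vectorised computation.
np.testing.assert_allclose(expected, 4 * data, rtol=1e-6)
# With the minimal example's `output` in scope, compare against the JAX result:
# np.testing.assert_allclose(np.asarray(output), expected, rtol=1e-6)
```

A tolerance is used rather than exact equality because `data + 3 * data` rounds twice in float32, while `4 * data` is an exact scaling.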
Useful environment variables and flags
JAX IPU experimental flags, set using `from jax.config import config`:
Flag | Description |
---|---|
`config.FLAGS.jax_platform_name = 'ipu'/'cpu'` | Configure the default JAX backend. Useful for CPU initialization. |
`config.FLAGS.jax_ipu_use_model = True` | Use the IPU model emulator. |
`config.FLAGS.jax_ipu_model_num_tiles = 8` | Set the number of tiles in the IPU model. |
`config.FLAGS.jax_ipu_device_count = 2` | Set the number of IPUs visible in JAX, up to the number of local IPUs available. |
`config.FLAGS.jax_ipu_visible_devices = '0,1'` | Set the specific collection of local IPUs to be visible in JAX. |
Alternatively, like other JAX flags, these can be set using environment variables (for example `JAX_IPU_USE_MODEL` and `JAX_IPU_MODEL_NUM_TILES`).
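For instance, the IPU model emulator can be selected from Python by setting these variables before JAX is imported. The sketch assumes, as for other JAX flags, that defaults are read from the environment when `jax` is first imported, and that the usual JAX boolean parsing applies (`"true"` or `"1"` enables a flag). The helper `use_ipu_model` is hypothetical.

```python
import os


def use_ipu_model(num_tiles: int = 8) -> None:
    """Hypothetical helper: select the IPU model emulator via environment variables.

    Must run before `import jax`, since the flag defaults are assumed to be
    read from the environment at import time.
    """
    os.environ["JAX_IPU_USE_MODEL"] = "true"
    os.environ["JAX_IPU_MODEL_NUM_TILES"] = str(num_tiles)


use_ipu_model(num_tiles=8)
# import jax  # Import only after the variables are set.
```

From a shell, the equivalent is prefixing the command, e.g. `JAX_IPU_USE_MODEL=true JAX_IPU_MODEL_NUM_TILES=8 python my_script.py` (with `my_script.py` standing in for your own script).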
PopVision environment variables:
- Generate a PopVision System Analyser profile: `PVTI_OPTIONS='{"enable":"true", "directory":"./reports"}'`
- Generate a PopVision Graph Analyser profile: `POPLAR_ENGINE_OPTIONS='{"autoReport.all":"true", "debug.allowOutOfMemory":"true"}'`
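Both variables take JSON strings, so they can equally be built with `json.dumps` and set from Python. A minimal sketch, assuming the variables only need to be in the environment before the Poplar engine is created (setting them before importing `jax` is the simplest way to guarantee that):

```python
import json
import os

# PopVision trace options (same JSON as the PVTI_OPTIONS command above).
pvti_options = {"enable": "true", "directory": "./reports"}
# Poplar engine reporting options (same JSON as the POPLAR_ENGINE_OPTIONS command above).
engine_options = {"autoReport.all": "true", "debug.allowOutOfMemory": "true"}

os.environ["PVTI_OPTIONS"] = json.dumps(pvti_options)
os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(engine_options)

# import jax  # Import and run the IPU program only after the variables are set.
```

Building the strings with `json.dumps` avoids shell-quoting mistakes with the nested double quotes.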
Documentation
Projects using TessellateIPU
- PySCF IPU: Molecular quantum chemistry simulation on Graphcore IPUs.
License
Copyright (c) 2023 Graphcore Ltd. The project is licensed under the Apache License 2.0.
TessellateIPU is implemented using C++ custom operations. These have the following C++ libraries as dependencies, statically compiled into a shared library:
Component | Description | License |
---|---|---|
fastbase64 | Base64 fast decoder library | Simplified BSD (FreeBSD) License |
fmt | A modern C++ formatting library | MIT License |
half | IEEE-754 conformant half-precision library | MIT License |
json | JSON for modern C++ | MIT License |
nanobind | Tiny C++/Python bindings | BSD 3-Clause License |