<div align="center">
  <img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo">
</div>

# :red_circle: Non-official experimental :red_circle: JAX on Graphcore IPU
Install guide | Quickstart | IPU JAX on Paperspace | Documentation
:red_circle: :warning: Non-official experimental :warning: :red_circle:
This is a very thin fork of https://github.com/google/jax for the Graphcore IPU. This package is provided by Graphcore Research for experimentation purposes only, not for production use (inference or training).
## Features and limitations of experimental JAX on IPUs
The following features are supported:
- Vanilla JAX API: no additional IPU-specific API; any code written for IPUs is backward compatible with other backends (CPU/GPU/TPU);
- JAX asynchronous dispatch on the IPU backend;
- Multiple IPUs with collectives, using `pmap` and (experimental) `pjit`;
- Large coverage of JAX `lax` operators;
- Support of JAX buffer donation to keep parameters in IPU SRAM (a sketch follows this list);
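
For instance, buffer donation can keep model parameters resident in IPU SRAM across update steps. A minimal sketch (`sgd_update` and its arguments are hypothetical illustrations, not part of this fork's API):

```python
from functools import partial

import jax
import numpy as np

# Donating argument 0 (`params`) lets the IPU backend reuse its buffer,
# so the updated parameters stay in IPU SRAM between calls.
@partial(jax.jit, backend="ipu", donate_argnums=(0,))
def sgd_update(params, grads, lr):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": np.ones((128,), dtype=np.float32)}
grads = {"w": np.full((128,), 0.1, dtype=np.float32)}
params = sgd_update(params, grads, 0.01)  # old `params` buffer is donated and reused
```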
Known limitations of the project:
- No eager mode (every JAX call has to be compiled, loaded, and finally executed on the IPU device);
- Generated IPU code can be larger than official Graphcore TensorFlow or PopTorch code (limiting batch size or model size);
- Multi-IPU collectives have topology restrictions (following the Graphcore GCL API);
- Missing linear algebra operators;
- Incomplete support of JAX random number generation on the IPU device;
- Deactivated support of JAX infeeds and outfeeds.
This is a research project, not an official Graphcore product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!
## Installation
The experimental JAX wheels require Ubuntu 20.04, Graphcore Poplar SDK 3.1 or 3.2, and Python 3.8. They can be installed as follows:

```bash
pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk310 -f https://graphcore-research.github.io/jax-experimental/wheels.html
```

For SDK 3.2, please change the `jaxlib` version to `jaxlib==0.3.15+ipu.sdk320`.
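
As a quick sanity check (assuming IPU hardware or the IPU model emulator is available), one can list the devices exposed by the IPU backend:

```python
import jax

# Lists the IPU devices visible to JAX; fails if the IPU backend is unavailable.
print(jax.devices("ipu"))
```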
## Minimal example
The following example can be run on Graphcore IPU Paperspace (or on a non-IPU machine using the IPU emulator):
```python
from functools import partial

import jax
import numpy as np

@partial(jax.jit, backend="ipu")
def ipu_function(data):
    return data**2 + 1

data = np.array([1, -2, 3], np.float32)
output = ipu_function(data)
print(output, output.device())
```
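
Running this should print the computed array `[ 2.  5. 10.]` together with the IPU device on which it was executed.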
## JAX on IPU Paperspace notebooks
- JAX on IPU quickstart
- JAX `pmap` on IPUs quickstart
- Stateful linear regression on IPU
- MNIST neural net training on IPU
- GNN training on IPU
- JAX `pjit` on IPUs quickstart
Additional JAX on IPU examples:
- JAX on IPU quickstart notebook;
- MNIST classifier training on IPU;
- MNIST classifier training on multiple IPUs using `pmap` (a minimal `pmap` sketch follows this list);
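
To give a flavour of `pmap` collectives on IPUs, here is a minimal sketch (assuming two IPUs, or an emulated equivalent, are visible):

```python
from functools import partial

import jax
import numpy as np

# All-reduce mean across 2 IPUs using a `pmap` collective.
@partial(jax.pmap, backend="ipu", axis_name="ipus")
def normalize(x):
    return x / jax.lax.pmean(x, axis_name="ipus")

# The leading axis must match the number of IPUs mapped over (here 2).
data = np.arange(8, dtype=np.float32).reshape(2, 4)
print(normalize(data))
```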
## Useful JAX backend flags
As standard in JAX, these flags can be set on the `config` object after the `from jax.config import config` import.
| Flag | Description |
|---|---|
| `config.FLAGS.jax_platform_name = 'ipu'` (or `'cpu'`) | Configure the default JAX backend. Useful for CPU initialization. |
| `config.FLAGS.jax_ipu_use_model = True` | Use the IPU model emulator. |
| `config.FLAGS.jax_ipu_model_num_tiles = 8` | Set the number of tiles in the IPU model. |
| `config.FLAGS.jax_ipu_device_count = 2` | Set the number of IPUs visible in JAX. Can be any number of locally available IPUs. |
| `config.FLAGS.jax_ipu_visible_devices = '0,1'` | Set the specific collection of local IPUs to be visible in JAX. |
Alternatively, like other JAX flags, these can be set using environment variables (e.g. `JAX_IPU_USE_MODEL`, `JAX_IPU_MODEL_NUM_TILES`, ...).
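
For example, a minimal sketch running on the IPU model emulator rather than hardware (assuming, as with most JAX flags, that they are set before the first IPU computation is compiled):

```python
from jax.config import config

# Run on the IPU model emulator instead of IPU hardware.
config.FLAGS.jax_ipu_use_model = True
config.FLAGS.jax_ipu_model_num_tiles = 8

import jax

print(jax.devices("ipu"))  # should report the emulated IPU device
```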
## Useful PopVision environment variables
- Generate a PopVision Graph Analyser profile:
```bash
POPLAR_ENGINE_OPTIONS='{"autoReport.all":"true", "debug.allowOutOfMemory":"true"}'
```
- Generate a PopVision System Analyser profile:
```bash
PVTI_OPTIONS='{"enable":"true", "directory":"./reports"}'
```
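
As a usage sketch (the script name `my_jax_script.py` is a placeholder), both analysers can be enabled on a single run:

```bash
# Profile a JAX-on-IPU run with both the Graph and System Analysers enabled.
POPLAR_ENGINE_OPTIONS='{"autoReport.all":"true", "debug.allowOutOfMemory":"true"}' \
PVTI_OPTIONS='{"enable":"true", "directory":"./reports"}' \
python my_jax_script.py
```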
## Documentation
- Performance tips for JAX on IPUs;
- How to build experimental JAX Python wheels for IPUs;
- Original JAX readme;
## License
The project remains licensed under the Apache License 2.0, with the following files unchanged:
The additional dependencies introduced for Graphcore IPU support are: