# BLOOM 🌸 Inference in JAX
## Structure
- CPU Host: as defined in TPU manager
- TPU Host: as defined in Host worker
- `ray`: distributes load from CPU host -> TPU hosts (see the sketch below)
- Example usage: `run.py`
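To make the `ray` layer concrete, here is a minimal, hypothetical sketch of how a ray cluster spanning the CPU host and the TPU hosts could be brought up by hand. The port and the head-node address are placeholders; the repository's own scripts handle the actual orchestration, so treat this as an illustration of the topology rather than a required step.

```bash
# Illustrative sketch only: link the CPU host (manager) and the TPU hosts (workers) with ray.
# The port and IP below are placeholders.

# On the CPU host, start the ray head node:
ray start --head --port=6379

# On each TPU host, attach to the head node via the CPU host's internal IP:
ray start --address='<CPU_HOST_INTERNAL_IP>:6379'
```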
## Setting Up a TPU Manager
The TPU hosts are managed by a single TPU manager, which runs on a separate CPU VM.
First, create this CPU VM in the same region as the TPU pod. This is important so that the TPU manager can communicate with the TPU hosts. A suitable machine configuration is as follows (a sample `gcloud` command is sketched after the list):
- Region & Zone: TO MATCH TPU ZONE
- Machine type: c2-standard-8
- CPU platform: Intel Cascade Lake
- Boot disk: 256GB balanced persistent disk
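As a rough sketch, a VM matching this configuration could be created from the command line along the following lines. The instance name is a placeholder, and the zone must be changed to match your TPU pod:

```bash
# Sketch only: create the CPU VM that will act as the TPU manager.
# Replace the instance name and zone with your own values (the zone must match the TPU pod).
gcloud compute instances create bloom-tpu-manager \
  --zone=europe-west4-a \
  --machine-type=c2-standard-8 \
  --min-cpu-platform="Intel Cascade Lake" \
  --boot-disk-size=256GB \
  --boot-disk-type=pd-balanced
```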
SSH into the CPU VM and set up a Python environment with the same Python version as the TPUs. The default TPU Python version is 3.8.10, so make sure the CPU's Python version matches this.
```bash
python3.8 -m venv /path/to/venv
```
If the above does not work, run the following and then repeat:
```bash
sudo apt-get update
sudo apt-get install python3-venv
```
Activate Python env:
```bash
source /path/to/venv/bin/activate
```
Check that the Python version is 3.8.10:
```bash
python --version
```
Clone the repository and install requirements:
```bash
git clone https://github.com/huggingface/bloom-jax-inference.git
cd bloom-jax-inference
pip install -r requirements.txt
```
Authenticate `gcloud`, which will require copying and pasting a command into a terminal window on a machine with a browser installed:
```bash
gcloud auth login
```
Now SSH into one of the workers. This will generate an SSH key:
```bash
gcloud alpha compute tpus tpu-vm ssh patrick-tpu-v3-32 --zone europe-west4-a --worker 0
```
Log out of the TPU worker:
```bash
logout
```
You should now be back on the CPU host.
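As an optional sanity check (not part of the repository's instructions), you can confirm that the generated key gives the manager access to the whole pod by running a trivial command on all workers:

```bash
# Optional check: run a no-op command on every TPU worker to confirm SSH access.
gcloud alpha compute tpus tpu-vm ssh patrick-tpu-v3-32 \
  --zone europe-west4-a \
  --worker all \
  --command "hostname"
```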