The annotated MNIST image classification example with Flax Linen and Optax

UPDATE: Use the up-to-date Flax Quickstart on the official Flax site.


Author: @8bitmp3

This tutorial uses Flax, a high-performance deep learning library for JAX designed for flexibility, to show you how to construct a simple convolutional neural network (CNN) with the Linen API and train it with Optax for image classification on the MNIST dataset.

If you're new to JAX, check out:

To learn more about Flax and its Linen API, refer to:

This tutorial has the following workflow: set up the environment, build a CNN with the Linen API, define metric and data-loading functions, write the training and evaluation steps, initialize the model parameters and the Optax optimizer, and finally train the network and evaluate it on the test set.

If you're using Google Colaboratory (Colab), enable GPU acceleration (Runtime > Change runtime type > Hardware accelerator: GPU).

Setup

  1. Install JAX, Flax, Optax, and TensorFlow Datasets (TFDS). Flax can use any data-loading pipeline; this example uses TFDS.
!pip install --upgrade -q pip jax jaxlib flax optax tensorflow-datasets
  2. Import JAX, JAX NumPy (which lets you run NumPy-style code on GPUs and TPUs), Flax, ordinary NumPy, and TFDS.
import jax
import jax.numpy as jnp               # JAX NumPy

from flax import linen as nn          # The Linen API
from flax.training import train_state
import optax                          # The Optax gradient processing and optimization library

import numpy as np                    # Ordinary NumPy
import tensorflow_datasets as tfds    # TFDS for MNIST

Build a model

Build a convolutional neural network with the Flax Linen API by subclassing flax.linen.Module. Because the architecture in this example is relatively simple (you're just stacking layers), you can define the submodules inline within the __call__ method and wrap it with the @nn.compact decorator (flax.linen.compact). A short sanity check follows the class definition below.

class CNN(nn.Module):

  @nn.compact
  # With @nn.compact, submodules (and their parameters) are defined
  # inline in the forward pass instead of in a separate setup() method.
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1)) # Flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)    # There are 10 classes in MNIST
    return x
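
As an optional sanity check (not part of the original tutorial), you can initialize the model with a dummy MNIST-shaped batch and confirm the logits come out with shape (1, 10). The key and input below are placeholders for illustration only:

model = CNN()
dummy_key = jax.random.PRNGKey(0)        # placeholder key for this check only
dummy_batch = jnp.ones([1, 28, 28, 1])   # one fake 28x28 grayscale image
variables = model.init(dummy_key, dummy_batch)   # initializes the parameters
logits = model.apply(variables, dummy_batch)
print(logits.shape)                      # (1, 10): one logit per MNIST class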

Create a metrics function

For loss and accuracy metrics, create a separate function:

def compute_metrics(logits, labels):
  loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10)))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }
  return metrics
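
As a tiny illustration with made-up values (not part of the training loop), you can call the function on a small batch of logits and integer labels:

example_logits = jnp.array([[2.0, 0.1, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                            [0.0, 3.0,  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
example_labels = jnp.array([0, 1])
print(compute_metrics(example_logits, example_labels))
# Expect accuracy == 1.0, since the argmax matches the label in both rows.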

The dataset

Define a function that downloads the MNIST dataset with TFDS, loads the full training and test splits into memory as NumPy arrays, and normalizes the image pixel values to the [0, 1] range (a quick shape check follows the function):

def get_datasets():
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  # Split into training/test sets
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  # Normalize the pixel values to floating point in the [0, 1] range
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
  return train_ds, test_ds
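
For illustration, this is the same call that the Load the dataset section makes later; the splits come back as dictionaries of full in-memory arrays, and you can verify their shapes:

train_ds, test_ds = get_datasets()   # downloads MNIST on first use
print(train_ds['image'].shape, train_ds['label'].shape)   # (60000, 28, 28, 1) (60000,)
print(test_ds['image'].shape, test_ds['label'].shape)     # (10000, 28, 28, 1) (10000,)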

Training and evaluation functions

  1. Write a training step function that evaluates the network on a batch with Module.apply, computes the cross-entropy loss, takes the gradients with jax.value_and_grad (see the small has_aux sketch after the function), updates the parameters with TrainState.apply_gradients, and returns the new state together with the batch metrics.

Use JAX's @jit decorator to trace the entire train_step function and just-in-time (JIT) compile it with XLA into fused device operations that run faster and more efficiently on hardware accelerators.

@jax.jit
def train_step(state, batch):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = jnp.mean(optax.softmax_cross_entropy(
        logits=logits, 
        labels=jax.nn.one_hot(batch['label'], num_classes=10)))
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits, batch['label'])
  return state, metrics
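
The has_aux=True flag tells jax.value_and_grad that loss_fn returns a (loss, auxiliary) pair and that gradients should only be taken with respect to the loss. A minimal sketch of that pattern on a toy function (unrelated to the CNN above):

# Toy illustration of value_and_grad with has_aux=True.
def toy_loss_fn(w):
  prediction = w * 3.0             # pretend forward pass
  loss = (prediction - 6.0) ** 2
  return loss, prediction           # (loss, auxiliary output)

toy_grad_fn = jax.value_and_grad(toy_loss_fn, has_aux=True)
(loss, prediction), grad = toy_grad_fn(1.0)
print(loss, prediction, grad)       # 9.0, 3.0, and d(loss)/dw = 2*(3w - 6)*3 = -18.0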
  2. Create a JIT-compiled function that evaluates the model on the test set using flax.linen.Module.apply:
@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits, batch['label'])
  3. Define a function that trains the model for one epoch: it shuffles the training data with a PRNG key, splits it into full batches (dropping any incomplete final batch), runs train_step on each batch, and averages the per-batch metrics over the epoch:
def train_epoch(state, train_ds, batch_size, epoch, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # Skip an incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  batch_metrics = []

  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  training_batch_metrics = jax.device_get(batch_metrics)
  training_epoch_metrics = {
      k: np.mean([metrics[k] for metrics in training_batch_metrics])
      for k in training_batch_metrics[0]}

  print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))

  return state, training_epoch_metrics
  4. Create a model evaluation function that runs eval_step on the full test set, pulls the metrics back to the host with jax.device_get, and returns the test loss and accuracy:
def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  eval_summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return eval_summary['loss'], eval_summary['accuracy']

Load the dataset

Download and preprocess the dataset with the get_datasets function you defined earlier:

train_ds, test_ds = get_datasets()

Initialize the parameters with PRNGs and instantiate the optimizer

  1. PRNGs: Before you start training the model, you need to randomly initialize the parameters.

In NumPy, you would usually use a stateful pseudorandom number generator (PRNG).

JAX, however, uses an explicit PRNG: you pass a key into every random function and split the key whenever you need fresh, independent randomness (refer to JAX - The Sharp Bits for details):

Note that in JAX and Flax you can have separate PRNG chains (with different names, such as rng and init_rng below) inside Modules for different applications. (Learn more about PRNG chains and the JAX PRNG design; a short illustration of key splitting follows the snippet below.)

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
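
To see what explicit keys buy you, here is an illustration with a throwaway key: the same key always produces the same values, and split keys produce independent streams.

demo_key = jax.random.PRNGKey(42)          # throwaway key, illustration only
print(jax.random.normal(demo_key, (2,)))   # the same key always gives the same values
print(jax.random.normal(demo_key, (2,)))   # identical to the line above
key_a, key_b = jax.random.split(demo_key)  # two new, independent keys
print(jax.random.normal(key_a, (2,)), jax.random.normal(key_b, (2,)))  # different values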
  2. Instantiate the CNN model and initialize its parameters with the init_rng key and a dummy input that has the MNIST image shape (you can inspect the resulting parameter shapes as sketched below):
cnn = CNN()
params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']
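
As an optional check (not required for training), the initialized params form a nested dictionary (a pytree), and you can print the shape of every array in it:

# Optional: print the shape of every parameter array in the pytree.
print(jax.tree_util.tree_map(lambda p: p.shape, params))
# Expect entries such as Conv_0/kernel with shape (3, 3, 1, 32)
# and Dense_1/kernel with shape (256, 10).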
  3. Create the SGD optimizer with Nesterov momentum using Optax (a brief sketch of Optax's lower-level init/update API follows the snippet):
learning_rate = 0.001
momentum = 0.9
tx = optax.sgd(learning_rate=learning_rate, momentum=momentum, nesterov=True)
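
Under the hood, tx is an optax.GradientTransformation: a pair of pure init and update functions. The TrainState created below calls them for you, but a hedged sketch of the manual pattern looks like this (the all-zero grads pytree is a placeholder standing in for real gradients):

# Manual Optax usage, for illustration; TrainState wraps this pattern.
opt_state = tx.init(params)                                # build the optimizer state (momentum buffers)
grads = jax.tree_util.tree_map(jnp.zeros_like, params)     # placeholder all-zero gradients
updates, opt_state = tx.update(grads, opt_state, params)   # transform the raw gradients
new_params = optax.apply_updates(params, updates)          # apply them to the parameters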
  4. Create a flax.training.train_state.TrainState instance that bundles the model's apply function, the parameters, and the optimizer, and that applies the gradients and updates the optimizer state and parameters during training (see the short apply_gradients sketch after the snippet).
state = train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
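
TrainState keeps a step counter alongside the parameters and optimizer state, and apply_gradients returns a new state rather than mutating the old one. A minimal illustration with placeholder all-zero gradients:

# Illustration only: apply_gradients is functional and returns a new TrainState.
zero_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)  # placeholder gradients
new_state = state.apply_gradients(grads=zero_grads)
print(state.step, new_state.step)   # 0 and 1: the original state is left unchanged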

Train the network and evaluate it

  1. Set the default number of epochs and the size of each batch:
num_epochs = 10
batch_size = 32
  2. Finally, begin training and evaluating the model over 10 epochs:
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run an optimization step over a training batch
  state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print('Testing - epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))
Training - epoch: 1, loss: 1.7941, accuracy: 62.73
Testing - epoch: 1, loss: 0.93, accuracy: 82.31
Training - epoch: 2, loss: 0.6114, accuracy: 85.10
Testing - epoch: 2, loss: 0.44, accuracy: 88.47
Training - epoch: 3, loss: 0.4128, accuracy: 88.40
Testing - epoch: 3, loss: 0.36, accuracy: 89.89
Training - epoch: 4, loss: 0.3598, accuracy: 89.67
Testing - epoch: 4, loss: 0.32, accuracy: 90.81
Training - epoch: 5, loss: 0.3280, accuracy: 90.50
Testing - epoch: 5, loss: 0.30, accuracy: 91.54
Training - epoch: 6, loss: 0.3047, accuracy: 91.18
Testing - epoch: 6, loss: 0.28, accuracy: 91.94
Training - epoch: 7, loss: 0.2853, accuracy: 91.71
Testing - epoch: 7, loss: 0.26, accuracy: 92.26
Training - epoch: 8, loss: 0.2680, accuracy: 92.15
Testing - epoch: 8, loss: 0.24, accuracy: 92.90
Training - epoch: 9, loss: 0.2522, accuracy: 92.72
Testing - epoch: 9, loss: 0.23, accuracy: 93.15
Training - epoch: 10, loss: 0.2384, accuracy: 92.99
Testing - epoch: 10, loss: 0.22, accuracy: 93.56