UPDATE: Use the up-to-date Flax Quickstart on the official Flax site.
Author: @8bitmp3
This tutorial uses Flax, a high-performance deep learning library for JAX designed for flexibility, to show you how to build a simple convolutional neural network (CNN) with the Linen API and Optax, and how to train the network for image classification on the MNIST dataset.
If you're new to JAX, start with the official JAX documentation first.
To learn more about Flax and its Linen API, refer to:
- Flax basics
- Flax patterns: Managing state and parameters
- Linen design principles
- Linen introduction
- More notebooks (including a more concise version of this MNIST notebook by @andsteing)
This tutorial has the following workflow:
- Perform a quick setup
- Build a convolutional neural network model with the Linen API that classifies images
- Define a loss and accuracy metrics function
- Create a dataset function with TensorFlow Datasets
- Define training and evaluation functions
- Load the MNIST dataset
- Initialize the parameters with PRNGs and instantiate the optimizer with Optax
- Train the network and evaluate it
If you're using Google Colaboratory (Colab), enable GPU acceleration (Runtime > Change runtime type > Hardware accelerator: GPU).
- Install JAX, Flax, Optax, and TensorFlow Datasets (TFDS). Flax can use any data-loading pipeline; this example uses TFDS.
!pip install --upgrade -q pip jax jaxlib flax optax tensorflow-datasets
- Import JAX, JAX NumPy (which lets you run code on GPUs and TPUs), Flax, ordinary NumPy, and TFDS.
import jax
import jax.numpy as jnp # JAX NumPy
from flax import linen as nn # The Linen API
from flax.training import train_state
import optax # The Optax gradient processing and optimization library
import numpy as np # Ordinary NumPy
import tensorflow_datasets as tfds # TFDS for MNIST
Build a convolutional neural network with the Flax Linen API by subclassing flax.linen.Module. Because the architecture in this example is relatively simple (you're just stacking layers), you can define the inlined submodules directly within the __call__ method and wrap it with the @compact decorator (flax.linen.compact).
class CNN(nn.Module):
    # With @nn.compact, submodules are defined inline in __call__ and their
    # parameters are registered automatically on the first call.
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # Flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)  # There are 10 classes in MNIST
        return x
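As an optional sanity check (the dummy batch below is only for illustration), you can initialize the model with a single 28x28x1 input and confirm that the forward pass produces one logit per class:

check_rng = jax.random.PRNGKey(0)
dummy_batch = jnp.ones([1, 28, 28, 1])            # one fake MNIST-sized image
variables = CNN().init(check_rng, dummy_batch)    # initialize the parameters
print(CNN().apply(variables, dummy_batch).shape)  # expected: (1, 10)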
For loss and accuracy metrics, create a separate function:

- Optax has a built-in softmax cross-entropy loss (optax.softmax_cross_entropy). You will also define and compute this loss inside the training step function later.
- The labels can be one-hot encoded with jax.nn.one_hot, as demonstrated below.
def compute_metrics(logits, labels):
    loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10)))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy
    }
    return metrics
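For intuition, here is a tiny hand-made example (the logits and labels are made up) of what compute_metrics returns:

fake_logits = jnp.array([[5.0] + [0.0] * 9])      # one sample that strongly predicts class 0
fake_labels = jnp.array([0])
print(compute_metrics(fake_logits, fake_labels))  # small loss, accuracy of 1.0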
Define a function that:
- Uses TFDS to load and prepare the MNIST dataset; and
- Converts the samples to floating-point numbers.
def get_datasets():
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    # Split into training/test sets
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    # Convert the images to floating-point numbers in [0, 1]
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
    return train_ds, test_ds
- Write a training step function that:
  - Evaluates the neural network given the parameters and a batch of input images with the flax.linen.Module.apply method.
  - Defines and computes the cross-entropy loss inside a loss_fn function.
  - Evaluates the loss function and its gradient using jax.value_and_grad (check the JAX autodiff cookbook to learn more).
  - Applies a pytree of gradients (flax.training.train_state.TrainState.apply_gradients) to the optimizer to update the model's parameters.
  - Returns the updated state and the metrics computed with compute_metrics (defined earlier).

Use JAX's @jit decorator to trace the entire train_step function and just-in-time (JIT) compile it with XLA into fused device operations that run faster and more efficiently on hardware accelerators.
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = CNN().apply({'params': params}, batch['image'])
        loss = jnp.mean(optax.softmax_cross_entropy(
            logits=logits,
            labels=jax.nn.one_hot(batch['label'], num_classes=10)))
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, batch['label'])
    return state, metrics
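If the (_, logits), grads unpacking looks unfamiliar: with has_aux=True, jax.value_and_grad expects the wrapped function to return a (loss, aux) pair and returns ((loss, aux), grads). A minimal toy illustration (the function below is hypothetical and not part of the tutorial):

def toy_loss(w):
    loss = jnp.sum(w ** 2)
    aux = 2.0 * w  # any auxiliary output, analogous to logits above
    return loss, aux

(value, aux), grads = jax.value_and_grad(toy_loss, has_aux=True)(jnp.array([1.0, 2.0]))
# value == 5.0, grads == [2.0, 4.0]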
- Create a jit-compiled function that evaluates the model on the test set using flax.linen.Module.apply:
@jax.jit
def eval_step(params, batch):
    logits = CNN().apply({'params': params}, batch['image'])
    return compute_metrics(logits, batch['label'])
- Define a training function for one epoch that:
  - Shuffles the training data before each epoch using jax.random.permutation, which takes a PRNGKey as a parameter (discussed in more detail later in this tutorial and in JAX - the sharp bits).
  - Runs an optimization step for each batch.
  - Retrieves the training metrics from the device with jax.device_get and computes their mean across each batch in an epoch.
  - Returns the optimizer state with updated parameters and the training loss and accuracy metrics (training_epoch_metrics).
def train_epoch(state, train_ds, batch_size, epoch, rng):
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]  # Skip an incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    training_batch_metrics = jax.device_get(batch_metrics)
    training_epoch_metrics = {
        k: np.mean([metrics[k] for metrics in training_batch_metrics])
        for k in training_batch_metrics[0]}

    print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (
        epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))

    return state, training_epoch_metrics
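To see how the permutation logic above produces batches, here is a scaled-down sketch (10 examples and a batch size of 4; the sizes are made up):

demo_rng = jax.random.PRNGKey(42)
perm = jax.random.permutation(demo_rng, 10)  # shuffled indices 0..9
perm = perm[:(10 // 4) * 4]                  # drop the incomplete batch, keeping 8 indices
print(perm.reshape((2, 4)))                  # 2 batches of 4 indices each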
- Create a model evaluation function that:
  - Evaluates the model on the test set.
  - Retrieves the evaluation metrics from the device with jax.device_get.
  - Copies the metrics data stored in a JAX pytree.
  - Returns the test loss and accuracy.
def eval_model(model, test_ds):
    metrics = eval_step(model, test_ds)
    metrics = jax.device_get(metrics)
    eval_summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return eval_summary['loss'], eval_summary['accuracy']
Download the dataset and preprocess it with the get_datasets function you defined earlier:
train_ds, test_ds = get_datasets()
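Optionally, you can verify the loaded arrays; the standard MNIST split has 60,000 training and 10,000 test images of shape 28x28x1:

print(train_ds['image'].shape, train_ds['label'].shape)  # (60000, 28, 28, 1) (60000,)
print(test_ds['image'].shape, test_ds['label'].shape)    # (10000, 28, 28, 1) (10000,)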
- PRNGs: Before you start training the model, you need to randomly initialize its parameters. In NumPy, you would usually use a stateful pseudorandom number generator (PRNG). JAX, however, uses explicit PRNG keys that you pass around yourself (refer to JAX - the sharp bits for details).

Note that in JAX and Flax you can have separate PRNG chains (with different names, such as rng and init_rng below) inside Modules for different applications. (Learn more about PRNG chains and JAX PRNG design.)
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
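Because JAX's PRNG is explicit, the same key always produces the same values; fresh randomness comes only from splitting into new subkeys. A small sketch (using a separate demo_key so the tutorial's rng is left untouched):

demo_key = jax.random.PRNGKey(0)
print(jax.random.normal(demo_key, (2,)))     # deterministic for this key
print(jax.random.normal(demo_key, (2,)))     # identical to the line above
_, demo_subkey = jax.random.split(demo_key)
print(jax.random.normal(demo_subkey, (2,)))  # different values from the new subkey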
- Instantiate the CNN model and initialize its parameters using a PRNG key:
cnn = CNN()
params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']
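If you want to inspect what init produced, the parameters form a nested pytree of arrays; a quick, optional way to view their shapes:

print(jax.tree_util.tree_map(lambda p: p.shape, params))  # e.g. Conv_0 kernel: (3, 3, 1, 32)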
- Instantiate the SGD optimizer with Nesterov momentum using Optax:
nesterov_momentum = 0.9
learning_rate = 0.001
tx = optax.sgd(learning_rate=learning_rate, momentum=nesterov_momentum, nesterov=True)
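Optax optimizers are composable, so you could swap in a different setup. For example (a variant not used in this tutorial), clipping gradients and using Adam would look like this:

alt_tx = optax.chain(
    optax.clip_by_global_norm(1.0),  # clip gradients to a maximum global norm of 1.0
    optax.adam(learning_rate=1e-3),  # Adam instead of SGD
)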
- Create a TrainState data class instance that stores the model's apply function, the parameters, and the optimizer, and that applies gradients to update the optimizer state and parameters:
state = train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
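The resulting TrainState bundles the step counter, the apply function, the parameters, and the optimizer state, so a single object can be threaded through the training loop:

print(state.step)             # 0 before any training step has run
print(type(state.opt_state))  # the Optax optimizer state for SGD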
- Set the default number of epochs and the size of each batch:
num_epochs = 10
batch_size = 32
- Finally, begin training and evaluating the model over 10 epochs:
  - For your training function (train_epoch), you need to pass a PRNG key used to permute the image data during shuffling. Since you already created a PRNG key when initializing the parameters of your network, you just need to split or "fork" the PRNG state into two (while maintaining the usual desirable PRNG properties) to get a new subkey (input_rng, in this example) and the previous key (rng). Use jax.random.split to carry this out. (Learn more about JAX PRNG design.)
  - Run an optimization step over a training batch (train_epoch).
  - Evaluate on the test set after each training epoch (eval_model).
  - Retrieve the metrics from the device and print them.
for epoch in range(1, num_epochs + 1):
    # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
    # Run an optimization step over a training batch
    state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    # Evaluate on the test set after each training epoch
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    print('Testing - epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))
Training - epoch: 1, loss: 1.7941, accuracy: 62.73
Testing - epoch: 1, loss: 0.93, accuracy: 82.31
Training - epoch: 2, loss: 0.6114, accuracy: 85.10
Testing - epoch: 2, loss: 0.44, accuracy: 88.47
Training - epoch: 3, loss: 0.4128, accuracy: 88.40
Testing - epoch: 3, loss: 0.36, accuracy: 89.89
Training - epoch: 4, loss: 0.3598, accuracy: 89.67
Testing - epoch: 4, loss: 0.32, accuracy: 90.81
Training - epoch: 5, loss: 0.3280, accuracy: 90.50
Testing - epoch: 5, loss: 0.30, accuracy: 91.54
Training - epoch: 6, loss: 0.3047, accuracy: 91.18
Testing - epoch: 6, loss: 0.28, accuracy: 91.94
Training - epoch: 7, loss: 0.2853, accuracy: 91.71
Testing - epoch: 7, loss: 0.26, accuracy: 92.26
Training - epoch: 8, loss: 0.2680, accuracy: 92.15
Testing - epoch: 8, loss: 0.24, accuracy: 92.90
Training - epoch: 9, loss: 0.2522, accuracy: 92.72
Testing - epoch: 9, loss: 0.23, accuracy: 93.15
Training - epoch: 10, loss: 0.2384, accuracy: 92.99
Testing - epoch: 10, loss: 0.22, accuracy: 93.56
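Once training finishes, the trained parameters live in state.params and can be used for inference directly. A minimal sketch (the five-image slice is arbitrary):

sample_images = test_ds['image'][:5]
sample_logits = CNN().apply({'params': state.params}, sample_images)
print(jnp.argmax(sample_logits, -1))  # predicted digits
print(test_ds['label'][:5])           # ground-truth labels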