Jax
Prerequisites: make sure to read the relevant sections of the documentation before using this example.
The full source code for this example is available on the mila-docs GitHub repository.
job.sh
#!/bin/bash
#SBATCH --gpus-per-task=rtx8000:1
#SBATCH --cpus-per-task=4
#SBATCH --ntasks-per-node=1
#SBATCH --mem=16G
#SBATCH --time=00:15:00
set -e # exit on error.
echo "Date: $(date)"
echo "Hostname: $(hostname)"
# Stage dataset into $SLURM_TMPDIR
mkdir -p $SLURM_TMPDIR/data
cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
# General-purpose alternatives combining copy and unpack:
# unzip /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
# tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/
# Execute Python script
# Use `uv run --offline` on clusters without internet access on compute nodes.
uv run python main.py
main.py
"""Single-GPU training example.
This Jax example is heavily based on the following examples:
* https://juliusruseckas.github.io/ml/flax-cifar10.html
* https://github.com/fattorib/Flax-ResNets/blob/master/main_flax.py
"""
import argparse
import logging
import math
import os
from pathlib import Path
import sys
from typing import Any, Sequence
import PIL.Image
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rich.console
import rich.logging
import torch
from flax.training import train_state, common_utils
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from model import ResNet
class TrainState(train_state.TrainState):
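    """Train state that also carries the model's BatchNorm statistics."""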
batch_stats: Any
class ToArray(torch.nn.Module):
"""convert image to float and 0-1 range"""
dtype = np.float32
def __call__(self, x):
assert isinstance(x, PIL.Image.Image)
x = np.asarray(x, dtype=self.dtype)
x /= 255.0
return x
def numpy_collate(batch: Sequence):
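    """Collate function that stacks samples into NumPy arrays rather than torch Tensors."""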
if isinstance(batch[0], np.ndarray):
return np.stack(batch)
elif isinstance(batch[0], (tuple, list)):
transposed = zip(*batch)
return [numpy_collate(samples) for samples in transposed]
else:
return np.array(batch)
def main():
# Use an argument parser so we can pass hyperparameters from the command line.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--learning-rate", type=float, default=5e-4)
parser.add_argument("--weight-decay", type=float, default=1e-4)
parser.add_argument("--batch-size", type=int, default=128)
args = parser.parse_args()
epochs: int = args.epochs
learning_rate: float = args.learning_rate
weight_decay: float = args.weight_decay
# NOTE: This is the "local" batch size, per-GPU.
batch_size: int = args.batch_size
# Check that the GPU is available
assert torch.cuda.is_available() and torch.cuda.device_count() > 0
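    # NOTE: torch is only used for data loading in this example; the model itself
    # runs on the GPU through JAX.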
rng = jax.random.PRNGKey(0)
# Setup logging (optional, but much better than using print statements)
# Uses the `rich` package to make logs pretty.
logging.basicConfig(
level=logging.INFO,
format="%(message)s",
handlers=[
rich.logging.RichHandler(
markup=True,
console=rich.console.Console(
                # Allow wider log lines in sbatch output files than on the terminal.
width=120 if not sys.stdout.isatty() else None
),
)
],
)
logger = logging.getLogger(__name__)
# Create a model.
model = ResNet(
10,
channel_list=[64, 128, 256, 512],
num_blocks_list=[2, 2, 2, 2],
        strides=[1, 1, 2, 2, 2],  # the first entry is used by the stem, the rest by the stages
head_p_drop=0.3,
)
@jax.jit
def initialize(params_rng, image_size=32):
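        """Initializes the model parameters (and batch stats) from a dummy input."""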
init_rngs = {"params": params_rng}
input_shape = (1, image_size, image_size, 3)
variables = model.init(
init_rngs, jnp.ones(input_shape, jnp.float32), train=False
)
return variables
# Setup CIFAR10
num_workers = get_num_workers()
dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data"
train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path))
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
collate_fn=numpy_collate,
)
valid_dataloader = DataLoader(
valid_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=False,
collate_fn=numpy_collate,
)
test_dataloader = DataLoader( # NOTE: Not used in this example.
test_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=False,
collate_fn=numpy_collate,
)
train_steps_per_epoch = math.ceil(len(train_dataset) / batch_size)
num_train_steps = train_steps_per_epoch * epochs
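    # NOTE: `cosine_onecycle_schedule` ramps the learning rate up to `peak_value`,
    # then anneals it back down over `transition_steps` total updates.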
    schedule_fn = optax.cosine_onecycle_schedule(
        transition_steps=num_train_steps, peak_value=learning_rate
    )
    optimizer = optax.adamw(learning_rate=schedule_fn, weight_decay=weight_decay)
params_rng, dropout_rng = jax.random.split(rng)
variables = initialize(params_rng)
state = TrainState.create(
apply_fn=model.apply,
params=variables["params"],
batch_stats=variables["batch_stats"],
tx=optimizer,
)
    # Check out the "checkpointing and preemption" example for more info!
logger.debug("Starting training from scratch.")
for epoch in range(epochs):
logger.debug(f"Starting epoch {epoch}/{epochs}")
# NOTE: using a progress bar from tqdm because it's nicer than using `print`.
progress_bar = tqdm(
total=len(train_dataloader),
desc=f"Train epoch {epoch}",
disable=not sys.stdout.isatty(), # Disable progress bar in non-interactive environments.
)
# Training loop
for input, target in train_dataloader:
batch = {
"image": input,
"label": target,
}
state, loss, accuracy = train_step(state, batch, dropout_rng)
logger.debug(f"Accuracy: {accuracy:.2%}")
logger.debug(f"Average Loss: {loss}")
            # Advance the progress bar by one step and update its postfix (nicer than printing).
progress_bar.update(1)
progress_bar.set_postfix(loss=loss, accuracy=accuracy)
progress_bar.close()
val_loss, val_accuracy = validation_loop(state, valid_dataloader)
logger.info(
f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}"
)
print("Done!")
def cross_entropy_loss(logits, labels, num_classes=10):
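    """Mean softmax cross-entropy between the logits and the integer labels.

    NOTE: optax also provides `softmax_cross_entropy_with_integer_labels`, which
    computes the same per-example losses without materializing one-hot labels.
    """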
one_hot_labels = common_utils.onehot(labels, num_classes=num_classes)
loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
loss = jnp.mean(loss)
return loss
@jax.jit
def train_step(state, batch, dropout_rng):
dropout_rng = jax.random.fold_in(dropout_rng, state.step)
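    # `fold_in` derives a new RNG key from the step number so that dropout uses
    # different random bits at every training step.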
def loss_fn(params):
variables = {"params": params, "batch_stats": state.batch_stats}
logits, new_model_state = state.apply_fn(
variables,
batch["image"],
train=True,
rngs={"dropout": dropout_rng},
mutable="batch_stats",
)
loss = cross_entropy_loss(logits, batch["label"])
        accuracy = jnp.mean(jnp.argmax(logits, -1) == batch["label"])
return loss, (accuracy, new_model_state)
(loss, (accuracy, new_model_state)), grads = jax.value_and_grad(
loss_fn, has_aux=True
)(state.params)
new_state = state.apply_gradients(
grads=grads, batch_stats=new_model_state["batch_stats"]
)
return new_state, loss, accuracy
@jax.jit
def validation_step(state, batch):
variables = {"params": state.params, "batch_stats": state.batch_stats}
logits = state.apply_fn(variables, batch["image"], train=False, mutable=False)
loss = cross_entropy_loss(logits, batch["label"])
batch_correct_predictions = jnp.sum(jnp.argmax(logits, -1) == batch["label"])
return loss, batch_correct_predictions
def validation_loop(state, dataloader: DataLoader):
    # NOTE: No equivalent of `torch.no_grad()` is needed here: JAX only computes
    # gradients when `jax.grad` / `jax.value_and_grad` is called explicitly.
    losses = []
    correct_predictions = []
    num_samples = 0
    for input, target in dataloader:
        batch = {
            "image": input,
            "label": target,
        }
        loss, batch_correct_predictions = validation_step(state, batch)
        losses.append(loss)
        correct_predictions.append(batch_correct_predictions)
        num_samples += len(target)
    avg_loss = np.mean(losses)
    # `correct_predictions` holds per-batch counts of correct predictions, so
    # divide by the total number of samples to get a fraction in [0, 1].
    accuracy = np.sum(correct_predictions) / num_samples
    return avg_loss, accuracy
def make_datasets(
dataset_path: str,
val_split: float = 0.1,
val_split_seed: int = 42,
):
"""Returns the training, validation, and test splits for CIFAR10.
NOTE: We don't use image transforms here for simplicity.
Having different transformations for train and validation would complicate things a bit.
Later examples will show how to do the train/val/test split properly when using transforms.
"""
train_dataset = CIFAR10(
root=dataset_path, transform=ToArray(), download=True, train=True
)
test_dataset = CIFAR10(
root=dataset_path, transform=ToArray(), download=True, train=False
)
# Split the training dataset into a training and validation set.
n_samples = len(train_dataset)
n_valid = int(val_split * n_samples)
n_train = n_samples - n_valid
train_dataset, valid_dataset = random_split(
train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed)
)
return train_dataset, valid_dataset, test_dataset
def get_num_workers() -> int:
"""Gets the optimal number of DatLoader workers to use in the current job."""
if "SLURM_CPUS_PER_TASK" in os.environ:
return int(os.environ["SLURM_CPUS_PER_TASK"])
if hasattr(os, "sched_getaffinity"):
return len(os.sched_getaffinity(0))
return torch.multiprocessing.cpu_count()
if __name__ == "__main__":
main()
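The core of train_step is jax.value_and_grad(loss_fn, has_aux=True), which evaluates the loss and its gradients in one pass while passing auxiliary outputs (here the accuracy and the updated batch statistics) through unchanged. Below is a minimal sketch of that pattern on a toy loss; it is an illustration only, not part of the example files.

import jax
import jax.numpy as jnp

def loss_fn(w):
    loss = jnp.sum(w**2)  # scalar loss that gets differentiated
    aux = {"w_norm": jnp.sqrt(loss)}  # auxiliary outputs, returned as-is
    return loss, aux

(loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(jnp.ones(3))
# loss == 3.0, grads == array([2., 2., 2.]), and `aux` plays the same role as
# (accuracy, new_model_state) in train_step above.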
model.py
from functools import partial
from typing import Any, Sequence
import jax.numpy as jnp
from flax import linen as nn
ModuleDef = Any
class ConvBlock(nn.Module):
channels: int
kernel_size: int
norm: ModuleDef
stride: int = 1
act: bool = True
@nn.compact
def __call__(self, x):
x = nn.Conv(
self.channels,
(self.kernel_size, self.kernel_size),
strides=self.stride,
padding="SAME",
use_bias=False,
kernel_init=nn.initializers.kaiming_normal(),
)(x)
x = self.norm()(x)
if self.act:
x = nn.swish(x)
return x
class ResidualBlock(nn.Module):
channels: int
conv_block: ModuleDef
@nn.compact
def __call__(self, x):
channels = self.channels
conv_block = self.conv_block
shortcut = x
residual = conv_block(channels, 3)(x)
residual = conv_block(channels, 3, act=False)(residual)
if shortcut.shape != residual.shape:
shortcut = conv_block(channels, 1, act=False)(shortcut)
gamma = self.param("gamma", nn.initializers.zeros, 1, jnp.float32)
out = shortcut + gamma * residual
out = nn.swish(out)
return out
class Stage(nn.Module):
channels: int
num_blocks: int
stride: int
block: ModuleDef
@nn.compact
def __call__(self, x):
stride = self.stride
if stride > 1:
x = nn.max_pool(x, (stride, stride), strides=(stride, stride))
for _ in range(self.num_blocks):
x = self.block(self.channels)(x)
return x
class Body(nn.Module):
channel_list: Sequence[int]
num_blocks_list: Sequence[int]
strides: Sequence[int]
stage: ModuleDef
@nn.compact
def __call__(self, x):
for channels, num_blocks, stride in zip(
self.channel_list, self.num_blocks_list, self.strides
):
x = self.stage(channels, num_blocks, stride)(x)
return x
class Stem(nn.Module):
channel_list: Sequence[int]
stride: int
conv_block: ModuleDef
@nn.compact
def __call__(self, x):
stride = self.stride
for channels in self.channel_list:
x = self.conv_block(channels, 3, stride=stride)(x)
stride = 1
return x
class Head(nn.Module):
classes: int
dropout: ModuleDef
@nn.compact
def __call__(self, x):
x = jnp.mean(x, axis=(1, 2))
x = self.dropout()(x)
x = nn.Dense(self.classes)(x)
return x
class ResNet(nn.Module):
classes: int
channel_list: Sequence[int]
num_blocks_list: Sequence[int]
strides: Sequence[int]
head_p_drop: float = 0.0
@nn.compact
def __call__(self, x, train=True):
norm = partial(nn.BatchNorm, use_running_average=not train)
dropout = partial(nn.Dropout, rate=self.head_p_drop, deterministic=not train)
conv_block = partial(ConvBlock, norm=norm)
residual_block = partial(ResidualBlock, conv_block=conv_block)
stage = partial(Stage, block=residual_block)
x = Stem([32, 32, 64], self.strides[0], conv_block)(x)
x = Body(self.channel_list, self.num_blocks_list, self.strides[1:], stage)(x)
x = Head(self.classes, dropout)(x)
return x
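To sanity-check the model in isolation, a quick smoke test could look like the following. This is a sketch (not part of the example files) and assumes it is run from the directory containing model.py.

import jax
import jax.numpy as jnp
from model import ResNet

model = ResNet(
    10,
    channel_list=[64, 128, 256, 512],
    num_blocks_list=[2, 2, 2, 2],
    strides=[1, 1, 2, 2, 2],
    head_p_drop=0.3,
)
# Initialize with a dummy batch of one 32x32 RGB image (NHWC layout).
dummy_input = jnp.ones((1, 32, 32, 3), jnp.float32)
variables = model.init({"params": jax.random.PRNGKey(0)}, dummy_input, train=False)
logits = model.apply(variables, dummy_input, train=False, mutable=False)
print(logits.shape)  # (1, 10): one row of class logits per image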
Running this example
$ sbatch job.sh
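The hyperparameters all have defaults and are read by the argument parser in main.py, so you can override them by editing the last line of job.sh, for example:

uv run python main.py --epochs 20 --learning-rate 1e-3 --batch-size 256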