Jax

Prerequisites: Make sure to read the following sections of the documentation before using this example:

The full source code for this example is available on the mila-docs GitHub repository.

job.sh

 # distributed/single_gpu/job.sh -> frameworks/jax/job.sh
 #!/bin/bash
 #SBATCH --gpus-per-task=rtx8000:1
 #SBATCH --cpus-per-task=4
 #SBATCH --ntasks-per-node=1
 #SBATCH --mem=16G
 #SBATCH --time=00:15:00
 
 
 # Echo time and hostname into log
 echo "Date:     $(date)"
 echo "Hostname: $(hostname)"
 
 
 # Ensure only anaconda/3 module loaded.
 module --quiet purge
 # This example uses Conda to manage package dependencies.
 # See https://docs.mila.quebec/Userguide.html#conda for more information.
 module load anaconda/3
-module load cuda/11.7
 
 # Creating the environment for the first time:
-# conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \
-#     pytorch-cuda=11.7 -c pytorch -c nvidia
-# Other conda packages:
-# conda install -y -n pytorch -c conda-forge rich tqdm
+# conda create -y -n jax_ex -c "nvidia/label/cuda-11.8.0" cuda python=3.9 virtualenv pip
+# conda activate jax_ex
+# Install Jax using `pip`
+# *Please note* that as soon as you install packages from `pip install`, you
+# should not install any more packages using `conda install`
+# pip install --upgrade "jax[cuda11_pip]" \
+#    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+# Other pip packages:
+# pip install pillow optax rich torch torchvision flax tqdm
 
-# Activate pre-existing environment.
-conda activate pytorch
+# Activate the environment:
+conda activate jax_ex
 
 
 # Stage dataset into $SLURM_TMPDIR
 mkdir -p $SLURM_TMPDIR/data
 cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
 # General-purpose alternatives combining copy and unpack:
 #     unzip   /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
 #     tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/
 
 
-# Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
-unset CUDA_VISIBLE_DEVICES
-
 # Execute Python script
 python main.py

main.py

 # distributed/single_gpu/main.py -> frameworks/jax/main.py
-"""Single-GPU training example."""
+"""Single-GPU training example.
+
+This Jax example is heavily based on the following examples:
+
+* https://juliusruseckas.github.io/ml/flax-cifar10.html
+* https://github.com/fattorib/Flax-ResNets/blob/master/main_flax.py
+"""
 import argparse
 import logging
+import math
 import os
 from pathlib import Path
+from typing import Any, Sequence
 
+import PIL.Image
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
 import rich.logging
 import torch
-from torch import Tensor, nn
-from torch.nn import functional as F
+
+from flax.training import train_state, common_utils
 from torch.utils.data import DataLoader, random_split
-from torchvision import transforms
 from torchvision.datasets import CIFAR10
-from torchvision.models import resnet18
 from tqdm import tqdm
 
+from model import ResNet
+
+
+class TrainState(train_state.TrainState):
+    batch_stats: Any
+
+
+class ToArray(torch.nn.Module):
+    """convert image to float and 0-1 range"""
+    dtype = np.float32
+
+    def __call__(self, x):
+        assert isinstance(x, PIL.Image.Image)
+        x = np.asarray(x, dtype=self.dtype)
+        x /= 255.0
+        return x
+
+
+def numpy_collate(batch: Sequence):
+    if isinstance(batch[0], np.ndarray):
+        return np.stack(batch)
+    elif isinstance(batch[0], (tuple, list)):
+        transposed = zip(*batch)
+        return [numpy_collate(samples) for samples in transposed]
+    else:
+        return np.array(batch)
+
 
 def main():
     # Use an argument parser so we can pass hyperparameters from the command line.
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument("--epochs", type=int, default=10)
     parser.add_argument("--learning-rate", type=float, default=5e-4)
     parser.add_argument("--weight-decay", type=float, default=1e-4)
     parser.add_argument("--batch-size", type=int, default=128)
     args = parser.parse_args()
 
     epochs: int = args.epochs
     learning_rate: float = args.learning_rate
     weight_decay: float = args.weight_decay
+    # NOTE: This is the "local" batch size, per-GPU.
     batch_size: int = args.batch_size
 
     # Check that the GPU is available
     assert torch.cuda.is_available() and torch.cuda.device_count() > 0
-    device = torch.device("cuda", 0)
+    rng = jax.random.PRNGKey(0)
 
     # Setup logging (optional, but much better than using print statements)
     logging.basicConfig(
         level=logging.INFO,
         handlers=[rich.logging.RichHandler(markup=True)],  # Very pretty, uses the `rich` package.
     )
 
     logger = logging.getLogger(__name__)
 
-    # Create a model and move it to the GPU.
-    model = resnet18(num_classes=10)
-    model.to(device=device)
+    # Create a model.
+    model = ResNet(
+        10,
+        channel_list = [64, 128, 256, 512],
+        num_blocks_list = [2, 2, 2, 2],
+        strides = [1, 1, 2, 2, 2],
+        head_p_drop = 0.3
+    )
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+    @jax.jit
+    def initialize(params_rng, image_size=32):
+        init_rngs = {'params': params_rng}
+        input_shape = (1, image_size, image_size, 3)
+        variables = model.init(init_rngs, jnp.ones(input_shape, jnp.float32), train=False)
+        return variables
 
     # Setup CIFAR10
     num_workers = get_num_workers()
     dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data"
     train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path))
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=True,
+        collate_fn=numpy_collate,
     )
     valid_dataloader = DataLoader(
         valid_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=False,
+        collate_fn=numpy_collate,
     )
     test_dataloader = DataLoader(  # NOTE: Not used in this example.
         test_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=False,
+        collate_fn=numpy_collate,
+    )
+
+    train_steps_per_epoch = math.ceil(len(train_dataset) / batch_size)
+    num_train_steps = train_steps_per_epoch * epochs
+    schedule_fn = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=learning_rate)
+    optimizer = optax.adamw(learning_rate=schedule_fn, weight_decay=weight_decay)
+
+    params_rng, dropout_rng = jax.random.split(rng)
+    variables = initialize(params_rng)
+
+    state = TrainState.create(
+        apply_fn = model.apply,
+        params = variables['params'],
+        batch_stats = variables['batch_stats'],
+        tx = optimizer
     )
 
     # Checkout the "checkpointing and preemption" example for more info!
     logger.debug("Starting training from scratch.")
 
     for epoch in range(epochs):
         logger.debug(f"Starting epoch {epoch}/{epochs}")
 
-        # Set the model in training mode (important for e.g. BatchNorm and Dropout layers)
-        model.train()
-
         # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
         progress_bar = tqdm(
             total=len(train_dataloader),
             desc=f"Train epoch {epoch}",
         )
 
         # Training loop
-        for batch in train_dataloader:
-            # Move the batch to the GPU before we pass it to the model
-            batch = tuple(item.to(device) for item in batch)
-            x, y = batch
-
-            # Forward pass
-            logits: Tensor = model(x)
-
-            loss = F.cross_entropy(logits, y)
+        for input, target in train_dataloader:
+            batch = {
+                'image': input,
+                'label': target,
+            }
+            state, loss, accuracy = train_step(state, batch, dropout_rng)
 
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+            logger.debug(f"Accuracy: {accuracy:.2%}")
+            logger.debug(f"Average Loss: {loss}")
 
-            # Calculate some metrics:
-            n_correct_predictions = logits.detach().argmax(-1).eq(y).sum()
-            n_samples = y.shape[0]
-            accuracy = n_correct_predictions / n_samples
-
-            logger.debug(f"Accuracy: {accuracy.item():.2%}")
-            logger.debug(f"Average Loss: {loss.item()}")
-
-            # Advance the progress bar one step and update the progress bar text.
+            # Advance the progress bar one step and update the "postfix" text of the progress bar.
             progress_bar.update(1)
-            progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
+            progress_bar.set_postfix(loss=loss, accuracy=accuracy)
         progress_bar.close()
 
-        val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
+        val_loss, val_accuracy = validation_loop(state, valid_dataloader)
         logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}")
 
     print("Done!")
 
 
-@torch.no_grad()
-def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
-    model.eval()
+def cross_entropy_loss(logits, labels, num_classes=10):
+    one_hot_labels = common_utils.onehot(labels, num_classes=num_classes)
+    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
+    loss = jnp.mean(loss)
+    return loss
 
-    total_loss = 0.0
-    n_samples = 0
-    correct_predictions = 0
 
-    for batch in dataloader:
-        batch = tuple(item.to(device) for item in batch)
-        x, y = batch
+@jax.jit
+def train_step(state, batch, dropout_rng):
+    dropout_rng = jax.random.fold_in(dropout_rng, state.step)
 
-        logits: Tensor = model(x)
-        loss = F.cross_entropy(logits, y)
+    def loss_fn(params):
+        variables = {'params': params, 'batch_stats': state.batch_stats}
+        logits, new_model_state = state.apply_fn(variables, batch['image'], train=True,
+                                                 rngs={'dropout': dropout_rng}, mutable='batch_stats')
+        loss = cross_entropy_loss(logits, batch['label'])
+        accuracy = jnp.mean(jnp.argmax(logits, -1) == batch['label'])
+        return loss, (accuracy, new_model_state)
 
-        batch_n_samples = x.shape[0]
-        batch_correct_predictions = logits.argmax(-1).eq(y).sum()
+    (loss, (accuracy, new_model_state)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
+    new_state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
+    return new_state, loss, accuracy
 
-        total_loss += loss.item()
-        n_samples += batch_n_samples
-        correct_predictions += batch_correct_predictions
 
-    accuracy = correct_predictions / n_samples
+@jax.jit
+def validation_step(state, batch):
+    variables = {'params': state.params, 'batch_stats': state.batch_stats}
+    logits = state.apply_fn(variables, batch['image'], train=False, mutable=False)
+    loss = cross_entropy_loss(logits, batch['label'])
+    batch_correct_predictions = jnp.sum(jnp.argmax(logits, -1) == batch['label'])
+    return loss, batch_correct_predictions
+
+
+@torch.no_grad()
+def validation_loop(state, dataloader: DataLoader):
+    losses = []
+    correct_predictions = []
+    for input, target in dataloader:
+        batch = {
+            'image': input,
+            'label': target,
+        }
+        loss, batch_correct_predictions = validation_step(state, batch)
+        losses.append(loss)
+        correct_predictions.append(batch_correct_predictions)
+
+    total_loss = np.sum(losses)
+    accuracy = np.sum(correct_predictions) / len(dataloader.dataset)
     return total_loss, accuracy
 
 
 def make_datasets(
     dataset_path: str,
     val_split: float = 0.1,
     val_split_seed: int = 42,
 ):
     """Returns the training, validation, and test splits for CIFAR10.
 
     NOTE: We don't use image transforms here for simplicity.
     Having different transformations for train and validation would complicate things a bit.
     Later examples will show how to do the train/val/test split properly when using transforms.
     """
     train_dataset = CIFAR10(
-        root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
+        root=dataset_path, transform=ToArray(), download=True, train=True
     )
     test_dataset = CIFAR10(
-        root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
+        root=dataset_path, transform=ToArray(), download=True, train=False
     )
     # Split the training dataset into a training and validation set.
     n_samples = len(train_dataset)
     n_valid = int(val_split * n_samples)
     n_train = n_samples - n_valid
     train_dataset, valid_dataset = random_split(
         train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed)
     )
     return train_dataset, valid_dataset, test_dataset
 
 
 def get_num_workers() -> int:
     """Gets the optimal number of DatLoader workers to use in the current job."""
     if "SLURM_CPUS_PER_TASK" in os.environ:
         return int(os.environ["SLURM_CPUS_PER_TASK"])
     if hasattr(os, "sched_getaffinity"):
         return len(os.sched_getaffinity(0))
     return torch.multiprocessing.cpu_count()
 
 
 if __name__ == "__main__":
     main()

model.py

from functools import partial
from typing import Any, Sequence

import jax.numpy as jnp

from flax import linen as nn


# Type alias for anything that builds a Flax module when called — typically an
# `nn.Module` subclass or a `functools.partial` of one (see `ResNet.__call__`).
ModuleDef = Any


class ConvBlock(nn.Module):
    """Convolution -> normalization -> (optional) swish activation."""
    channels: int
    kernel_size: int
    norm: ModuleDef
    stride: int = 1
    act: bool = True

    @nn.compact
    def __call__(self, x):
        # Square kernel; the bias is omitted because the following norm
        # layer supplies its own learnable shift.
        conv = nn.Conv(
            self.channels,
            (self.kernel_size, self.kernel_size),
            strides=self.stride,
            padding='SAME',
            use_bias=False,
            kernel_init=nn.initializers.kaiming_normal(),
        )
        out = self.norm()(conv(x))
        return nn.swish(out) if self.act else out


class ResidualBlock(nn.Module):
    """Two-conv residual block with a zero-initialized residual scale.

    The residual branch is multiplied by a learnable scalar `gamma` that
    starts at zero, so every block begins as an identity mapping.
    """
    channels: int
    conv_block: ModuleDef

    @nn.compact
    def __call__(self, x):
        make_conv = self.conv_block

        # Main branch: two 3x3 convolutions, the second without activation.
        branch = make_conv(self.channels, 3)(x)
        branch = make_conv(self.channels, 3, act=False)(branch)

        # Project the identity path with a 1x1 conv only when shapes differ.
        identity = x
        if identity.shape != branch.shape:
            identity = make_conv(self.channels, 1, act=False)(identity)

        gamma = self.param('gamma', nn.initializers.zeros, 1, jnp.float32)
        return nn.swish(identity + gamma * branch)


class Stage(nn.Module):
    """A stack of residual blocks, optionally preceded by max-pool downsampling."""
    channels: int
    num_blocks: int
    stride: int
    block: ModuleDef

    @nn.compact
    def __call__(self, x):
        pool = self.stride
        # Downsample spatially before the blocks when the stride exceeds 1.
        if pool > 1:
            x = nn.max_pool(x, (pool, pool), strides=(pool, pool))
        for _ in range(self.num_blocks):
            x = self.block(self.channels)(x)
        return x


class Body(nn.Module):
    """Sequence of stages, one per (channels, num_blocks, stride) triple."""
    channel_list: Sequence[int]
    num_blocks_list: Sequence[int]
    strides: Sequence[int]
    stage: ModuleDef

    @nn.compact
    def __call__(self, x):
        stage_configs = zip(self.channel_list, self.num_blocks_list, self.strides)
        for width, depth, stride in stage_configs:
            x = self.stage(width, depth, stride)(x)
        return x


class Stem(nn.Module):
    """Initial convolution stack; only the first conv uses the given stride."""
    channel_list: Sequence[int]
    stride: int
    conv_block: ModuleDef

    @nn.compact
    def __call__(self, x):
        for index, width in enumerate(self.channel_list):
            # Apply the configured stride on the first conv only.
            conv_stride = self.stride if index == 0 else 1
            x = self.conv_block(width, 3, stride=conv_stride)(x)
        return x


class Head(nn.Module):
    """Global average pooling, dropout, and the final linear classifier."""
    classes: int
    dropout: ModuleDef

    @nn.compact
    def __call__(self, x):
        # Average over the spatial dimensions (axes 1 and 2).
        pooled = jnp.mean(x, axis=(1, 2))
        pooled = self.dropout()(pooled)
        return nn.Dense(self.classes)(pooled)


class ResNet(nn.Module):
    """ResNet-style classifier assembled as Stem -> Body -> Head.

    `strides[0]` is consumed by the stem; `strides[1:]` (one entry per
    stage) are consumed by the body.
    """
    classes: int
    channel_list: Sequence[int]
    num_blocks_list: Sequence[int]
    strides: Sequence[int]
    head_p_drop: float = 0.

    @nn.compact
    def __call__(self, x, train=True):
        # BatchNorm and Dropout behave differently at train vs. eval time.
        norm_layer = partial(nn.BatchNorm, use_running_average=not train)
        dropout_layer = partial(nn.Dropout, rate=self.head_p_drop, deterministic=not train)
        conv = partial(ConvBlock, norm=norm_layer)
        block = partial(ResidualBlock, conv_block=conv)
        stage = partial(Stage, block=block)

        x = Stem([32, 32, 64], self.strides[0], conv)(x)
        x = Body(self.channel_list, self.num_blocks_list, self.strides[1:], stage)(x)
        x = Head(self.classes, dropout_layer)(x)
        return x

Running this example

$ sbatch job.sh