Checkpointing

Prerequisites

Make sure to read the following sections of the documentation before using this example:

The full source code for this example is available on the mila-docs GitHub repository.

job.sh

 # distributed/single_gpu/job.sh -> good_practices/checkpointing/job.sh
 #!/bin/bash
 #SBATCH --ntasks=1
 #SBATCH --ntasks-per-node=1
 #SBATCH --cpus-per-task=4
 #SBATCH --gpus-per-task=l40s:1
 #SBATCH --mem-per-gpu=16G
 #SBATCH --time=00:15:00
+#SBATCH --requeue           # allow SLURM to automatically requeue this job if it is preempted or a node fails.
+#SBATCH --signal=B:TERM@300 # tells SLURM to send SIGTERM to the batch shell (that is what
+                            # the "B:" prefix means) 5 minutes before the job reaches its
+                            # time limit, to give it a chance to clean up and save its state.
+                            # If you cancel the job manually, make sure you specify the
+                            # signal as TERM, i.e. `scancel --signal=TERM <jobid>`.
+                            # https://dhruveshp.com/blog/2021/signal-propagation-on-slurm/
 
 # Exit on error
 set -e
 
 # Echo time and hostname into log
 echo "Date:     $(date)"
 echo "Hostname: $(hostname)"
+echo "Job has been preempted $SLURM_RESTART_COUNT times."
 
 # To make your code as much reproducible as possible with
 # `torch.use_deterministic_algorithms(True)`, uncomment the following block:
 ## === Reproducibility ===
 ## Be warned that this can make your code slower. See
 ## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations
 ## for more details.
 # export CUBLAS_WORKSPACE_CONFIG=:4096:8
 ## === Reproducibility (END) ===
 
 # Stage dataset into $SLURM_TMPDIR
 mkdir -p $SLURM_TMPDIR/data
 cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
 # General-purpose alternatives combining copy and unpack:
 #     unzip   /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
 #     tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/
 
 # Execute Python script
 # Use the `--offline` option of `uv run` on clusters without internet access on compute nodes.
 # Using the `--locked` option can help make your experiments easier to reproduce (it forces
 # your uv.lock file to be up to date with the dependencies declared in pyproject.toml).
-srun uv run python main.py
+# Here we use `exec` to ensure that the signals are received and handled in the Python process.
+exec srun uv run python main.py
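
Why `exec`? With the "B:" prefix in `--signal`, SLURM signals only the batch shell, and bash
does not forward signals to its child processes on its own; `exec` replaces the shell with
`srun`, so the signal reaches the job step and, from there, the Python process. If you cannot
use `exec` (for example because you need cleanup commands after `srun`), a minimal sketch of
an alternative is to run `srun` in the background and forward the signal yourself (note the
second `wait`: the first one returns as soon as the trapped signal has been handled):

    srun uv run python main.py &
    child_pid=$!
    # Forward SIGTERM/SIGUSR1 to srun, which relays them to its tasks.
    trap 'kill -TERM "$child_pid"' TERM USR1
    wait "$child_pid" || true   # returns early when the signal arrives (`|| true` plays nice with `set -e`)
    wait "$child_pid"           # wait for the Python process to actually exit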

pyproject.toml

[project]
name = "checkpointing-example"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11,<3.14"
dependencies = [
    "rich>=14.0.0",
    "torch>=2.7.1",
    "torchvision>=0.22.1",
    "tqdm>=4.67.1",
]

main.py

 # distributed/single_gpu/main.py -> good_practices/checkpointing/main.py
-"""Single-GPU training example."""
+"""Checkpointing example."""
+
+from __future__ import annotations
 
 import argparse
 import logging
 import os
 import random
+import shutil
+import signal
 import sys
+import uuid
+import warnings
+from logging import getLogger as get_logger
 from pathlib import Path
+from types import FrameType
+from typing import Any, TypedDict
 
 import numpy as np
 import rich.logging
 import torch
 from torch import Tensor, nn
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, random_split
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18
 from tqdm import tqdm
 
+SCRATCH = Path(os.environ["SCRATCH"])
+SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"])
+SLURM_JOBID = os.environ["SLURM_JOBID"]
+
+CHECKPOINT_FILE_NAME = "checkpoint.pth"
+
 
 # To make your code as much reproducible as possible, uncomment the following
 # block:
 ## === Reproducibility ===
 ## Be warned that this can make your code slower. See
 ## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations
 ## for more details.
 # torch.use_deterministic_algorithms(True)
 ## === Reproducibility (END) ===
 
 
+class RunState(TypedDict):
+    """Typed dictionary containing the state of the training run which is saved at each epoch.
+
+    Using type hints helps prevent bugs and makes your code easier to read for both humans and
+    machines (e.g. Copilot). This leads to less time spent debugging and better code suggestions.
+    """
+
+    epoch: int
+    best_acc: float
+    model_state: dict[str, Tensor]
+    optimizer_state: dict[str, Any]
+
+    random_state: tuple[Any, ...]
+    numpy_random_state: dict[str, Any]
+    torch_random_state: Tensor
+    torch_cuda_random_state: list[Tensor]
+
+
 def main():
     # Use an argument parser so we can pass hyperparameters from the command line.
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument("--epochs", type=int, default=10)
     parser.add_argument("--learning-rate", type=float, default=5e-4)
     parser.add_argument("--weight-decay", type=float, default=1e-4)
     parser.add_argument("--batch-size", type=int, default=128)
+    parser.add_argument(
+        "--run-dir", type=Path, default=SCRATCH / "checkpointing_example" / SLURM_JOBID
+    )
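+    # NOTE: A requeued job keeps the same SLURM_JOBID, so after a preemption the default
+    # run directory (and therefore the checkpoint directory) stays the same.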
     parser.add_argument("--seed", type=int, default=42)
     args = parser.parse_args()
 
     epochs: int = args.epochs
     learning_rate: float = args.learning_rate
     weight_decay: float = args.weight_decay
     batch_size: int = args.batch_size
+    run_dir: Path = args.run_dir
     seed: int = args.seed
 
+    checkpoint_dir = run_dir / "checkpoints"
+    start_epoch: int = 0
+    best_acc: float = 0.0
+
     # Seed the random number generators as early as possible for reproducibility
     random.seed(seed)
     np.random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
 
     # Check that the GPU is available
     assert torch.cuda.is_available() and torch.cuda.device_count() > 0
     device = torch.device("cuda", 0)
 
     # Setup logging (optional, but much better than using print statements)
     # Uses the `rich` package to make logs pretty.
     logging.basicConfig(
         level=logging.INFO,
         format="%(message)s",
         handlers=[
             rich.logging.RichHandler(
                 markup=True,
                 console=rich.console.Console(
                      # Allow wider log lines in sbatch output files than on the terminal.
                     width=120 if not sys.stdout.isatty() else None
                 ),
             )
         ],
     )
 
     logger = logging.getLogger(__name__)
 
-    # Create a model and move it to the GPU.
+    # Create a model.
     model = resnet18(num_classes=10)
+
+    # Move the model to the GPU.
     model.to(device=device)
 
     optimizer = torch.optim.AdamW(
         model.parameters(), lr=learning_rate, weight_decay=weight_decay
     )
 
-    # Setup CIFAR10
+    # Try to resume from a checkpoint, if one exists.
+    checkpoint: RunState | None = load_checkpoint(checkpoint_dir, map_location=device)
+    if checkpoint:
+        start_epoch = checkpoint["epoch"] + 1  # +1 to start at the next epoch.
+        best_acc = checkpoint["best_acc"]
+        model.load_state_dict(checkpoint["model_state"])
+        optimizer.load_state_dict(checkpoint["optimizer_state"])
+        random.setstate(checkpoint["random_state"])
+        np.random.set_state(checkpoint["numpy_random_state"])
+        # NOTE: Need to move those tensors to CPU before they can be loaded.
+        torch.random.set_rng_state(checkpoint["torch_random_state"].cpu())
+        torch.cuda.random.set_rng_state_all(
+            t.cpu() for t in checkpoint["torch_cuda_random_state"]
+        )
+        logger.info(
+            f"Resuming training at epoch {start_epoch} (best_acc={best_acc:.2%})."
+        )
+    else:
+        logger.info(f"No checkpoints found in {checkpoint_dir}. Training from scratch.")
+
+    # Setup the dataset
     num_workers = get_num_workers()
-    dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data"
+    dataset_path = SLURM_TMPDIR / "data"
+
     train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path))
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=True,
+        # generator=torch.Generator().manual_seed(seed),
     )
     valid_dataloader = DataLoader(
         valid_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=False,
+        # generator=torch.Generator().manual_seed(seed),
     )
     _test_dataloader = DataLoader(  # NOTE: Not used in this example.
         test_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=False,
     )
 
-    # Checkout the "checkpointing and preemption" example for more info!
-    logger.debug("Starting training from scratch.")
-
-    for epoch in range(epochs):
+    def signal_handler(signum: int, frame: FrameType | None):
+        """Called before the job gets pre-empted or reaches the time-limit.
+
+        This should run quickly. Performing a full checkpoint here mid-epoch is not recommended.
+        """
+        signal_enum = signal.Signals(signum)
+        logger.error(f"Job received a {signal_enum.name} signal!")
+        # Perform quick actions that will help the job resume later.
+        # If you use Weights & Biases: https://docs.wandb.ai/guides/runs/resuming#preemptible-sweeps
+        # if wandb.run:
+        #     wandb.mark_preempting()
+
+    # SIGTERM is what SLURM sends on preemption and, with `--signal=B:TERM@300` in job.sh,
+    # also shortly before the time limit. SIGUSR1 is registered as well, since some clusters
+    # are configured to send it instead.
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGUSR1, signal_handler)
+
+    for epoch in range(start_epoch, epochs):
         logger.debug(f"Starting epoch {epoch}/{epochs}")
 
-        # Set the model in training mode (important for e.g. BatchNorm and Dropout layers)
+        # Set the model in training mode (this is important for e.g. BatchNorm and Dropout layers)
         model.train()
 
-        # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
+        # NOTE: using a progress bar from tqdm is much nicer than using `print`s.
         progress_bar = tqdm(
             total=len(train_dataloader),
             desc=f"Train epoch {epoch}",
             disable=not sys.stdout.isatty(),  # Disable progress bar in non-interactive environments.
         )
 
         # Training loop
+        batch: tuple[Tensor, Tensor]
         for batch in train_dataloader:
             # Move the batch to the GPU before we pass it to the model
             batch = tuple(item.to(device) for item in batch)
             x, y = batch
 
             # Forward pass
             logits: Tensor = model(x)
 
             loss = F.cross_entropy(logits, y)
 
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
 
             # Calculate some metrics:
             n_correct_predictions = logits.detach().argmax(-1).eq(y).sum()
             n_samples = y.shape[0]
             accuracy = n_correct_predictions / n_samples
 
             logger.debug(f"Accuracy: {accuracy.item():.2%}")
             logger.debug(f"Average Loss: {loss.item()}")
 
-            # Advance the progress bar one step and update the progress bar text.
+            # Advance the progress bar one step, and update the text displayed in the progress bar.
             progress_bar.update(1)
             progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
         progress_bar.close()
 
         val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
         logger.info(
             f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}"
         )
 
+        # Remember the best accuracy so far and save the current state.
+        is_best = val_accuracy > best_acc
+        best_acc = max(val_accuracy, best_acc)
+
+        if checkpoint_dir is not None:
+            save_checkpoint(
+                checkpoint_dir,
+                is_best,
+                RunState(
+                    epoch=epoch,
+                    model_state=model.state_dict(),
+                    optimizer_state=optimizer.state_dict(),
+                    random_state=random.getstate(),
+                    numpy_random_state=np.random.get_state(legacy=False),
+                    torch_random_state=torch.random.get_rng_state(),
+                    torch_cuda_random_state=torch.cuda.random.get_rng_state_all(),
+                    best_acc=best_acc,
+                ),
+            )
+
     print("Done!")
 
 
 @torch.no_grad()
 def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
     model.eval()
 
     total_loss = 0.0
     n_samples = 0
     correct_predictions = 0
 
     for batch in dataloader:
         batch = tuple(item.to(device) for item in batch)
         x, y = batch
 
         logits: Tensor = model(x)
         loss = F.cross_entropy(logits, y)
 
         batch_n_samples = x.shape[0]
-        batch_correct_predictions = logits.argmax(-1).eq(y).sum()
+        batch_correct_predictions = logits.argmax(-1).eq(y).sum().item()
 
         total_loss += loss.item()
         n_samples += batch_n_samples
-        correct_predictions += batch_correct_predictions
+        correct_predictions += int(batch_correct_predictions)
 
     accuracy = correct_predictions / n_samples
     return total_loss, accuracy
 
 
 def make_datasets(
     dataset_path: str,
     val_split: float = 0.1,
     val_split_seed: int = 42,
 ):
     """Returns the training, validation, and test splits for CIFAR10.
 
     NOTE: We don't use image transforms here for simplicity.
     Having different transformations for train and validation would complicate things a bit.
     Later examples will show how to do the train/val/test split properly when using transforms.
     """
     train_dataset = CIFAR10(
         root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
     )
     test_dataset = CIFAR10(
         root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
     )
     # Split the training dataset into a training and validation set.
-    n_samples = len(train_dataset)
-    n_valid = int(val_split * n_samples)
-    n_train = n_samples - n_valid
     train_dataset, valid_dataset = random_split(
-        train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed)
+        train_dataset,
+        ((1 - val_split), val_split),
+        torch.Generator().manual_seed(val_split_seed),
     )
     return train_dataset, valid_dataset, test_dataset
 
 
 def get_num_workers() -> int:
-    """Gets the optimal number of DatLoader workers to use in the current job."""
+    """Gets the optimal number of DataLoader workers to use in the current job."""
     if "SLURM_CPUS_PER_TASK" in os.environ:
         return int(os.environ["SLURM_CPUS_PER_TASK"])
     if hasattr(os, "sched_getaffinity"):
         return len(os.sched_getaffinity(0))
     return torch.multiprocessing.cpu_count()
 
 
+def load_checkpoint(checkpoint_dir: Path, **torch_load_kwargs) -> RunState | None:
+    """Loads the latest checkpoint if possible, otherwise returns `None`."""
+    logger = logging.getLogger(__name__)
+
+    checkpoint_file = checkpoint_dir / CHECKPOINT_FILE_NAME
+    restart_count = int(os.environ.get("SLURM_RESTART_COUNT", 0))
+    if restart_count:
+        logger.info(
+            f"NOTE: This job has been restarted {restart_count} times by SLURM."
+        )
+
+    if not checkpoint_file.exists():
+        logger.debug(f"No checkpoint found in checkpoints dir ({checkpoint_dir}).")
+        if restart_count:
+            logger.warning(
+                RuntimeWarning(
+                    f"This job has been restarted {restart_count} times by SLURM, but no "
+                    "checkpoint was found! This either means that your checkpointing code is "
+                    "broken, or that the job did not reach the checkpointing portion of your "
+                    "training loop."
+                )
+            )
+        return None
+
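+    # The checkpoint contains non-tensor objects (e.g. the RNG states), so it cannot be
+    # loaded with `weights_only=True`. Only load checkpoint files that you trust.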
+    torch_load_kwargs.setdefault("weights_only", False)
+    checkpoint_state: dict = torch.load(checkpoint_file, **torch_load_kwargs)
+
+    missing_keys = RunState.__required_keys__ - set(checkpoint_state.keys())
+    if missing_keys:
+        warnings.warn(
+            RuntimeWarning(
+                f"Checkpoint at {checkpoint_file} is missing the following keys: {missing_keys}. "
+                f"Ignoring this checkpoint."
+            )
+        )
+        return None
+
+    logger.debug(f"Resuming from the checkpoint file at {checkpoint_file}")
+    state: RunState = checkpoint_state  # type: ignore
+    return state
+
+
+def save_checkpoint(checkpoint_dir: Path, is_best: bool, state: RunState):
+    """Saves a checkpoint with the current state of the run in the checkpoint dir.
+
+    The best checkpoint is also updated if `is_best` is `True`.
+
+    Parameters
+    ----------
+    checkpoint_dir: The checkpoint directory.
+    is_best: Whether this is the best checkpoint so far.
+    state: The dictionary containing all the things to save.
+    """
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_file = checkpoint_dir / CHECKPOINT_FILE_NAME
+
+    # Use a unique ID to avoid any potential collisions.
+    unique_id = uuid.uuid1()
+    temp_checkpoint_file = checkpoint_file.with_suffix(f".temp{unique_id}")
+
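+    # Write to a temporary file first, then atomically replace the previous checkpoint,
+    # so that an interruption (e.g. a preemption) mid-write never leaves a corrupted
+    # checkpoint.pth behind.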
+    torch.save(state, temp_checkpoint_file)
+    os.replace(temp_checkpoint_file, checkpoint_file)
+
+    if is_best:
+        best_checkpoint_file = checkpoint_file.with_name("model_best.pth")
+        temp_best_checkpoint_file = best_checkpoint_file.with_suffix(
+            f".temp{unique_id}"
+        )
+        shutil.copyfile(checkpoint_file, temp_best_checkpoint_file)
+        os.replace(temp_best_checkpoint_file, best_checkpoint_file)
+
+
 if __name__ == "__main__":
     main()

Running this example

$ sbatch job.sh
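
When the job is preempted or requeued, SLURM restarts it under the same job ID, so the default
--run-dir ($SCRATCH/checkpointing_example/$SLURM_JOBID) resolves to the same checkpoint
directory and training resumes from the last epoch that was saved. To exercise the resume logic
without waiting for an actual preemption, you can requeue the running job yourself (assuming
<jobid> is your job's ID) and then look for the "Resuming training at epoch ..." message in the
output file:

$ scontrol requeue <jobid>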