
Checkpointing

Checkpointing is the process of periodically saving the full state of a training run so that it can be resumed later. Without checkpointing, any interruption means starting from scratch, wasting compute and time. Checkpointing is especially useful for long-running jobs, which can then be split into several shorter, more robust ones.

A complete checkpoint generally needs more than just the model weights. It also includes the optimizer state, the current epoch number, and the states of the random number generators (Python, NumPy, PyTorch, etc.). Restoring all of these is what allows a resumed run to reproduce the results of an uninterrupted one.
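
As a minimal sketch of why the generator states matter, the snippet below uses only Python's standard `random` module. NumPy and PyTorch expose analogous getters and setters, which main.py uses further down. The helper name `draw` is hypothetical and only serves the illustration.

```python
import random


def draw(n: int) -> list:
    """Draw n pseudo-random numbers from Python's global generator."""
    return [random.random() for _ in range(n)]


random.seed(42)
draw(3)                          # some "training" happens before the checkpoint
saved_state = random.getstate()  # this is what would go into the checkpoint

uninterrupted = draw(3)          # what an uninterrupted run would draw next

random.seed(0)                   # a restarted process starts in some other state
random.setstate(saved_state)     # restoring the checkpointed state...
resumed = draw(3)                # ...replays the same draws

assert resumed == uninterrupted
```

If the state were not restored, the resumed run would draw different numbers, and anything depending on them (shuffling, dropout, augmentation) would diverge from the uninterrupted run.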

You can also save multiple intermediate checkpoints for analysis or to keep different restart points. Be mindful of storage space, as checkpoints can consume a lot of it, especially if your model and optimizer states are large. Make sure to monitor your storage usage and clean up old checkpoints as needed.
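
One possible way to keep storage under control is to prune older checkpoints after each save. The sketch below assumes a hypothetical naming scheme (`epoch_0007.pth`, zero-padded so that sorting by name sorts by epoch); the function `prune_checkpoints` and its `keep_last` parameter are illustrative and not part of the example in this page, which overwrites a single checkpoint.pth.

```python
from pathlib import Path


def prune_checkpoints(checkpoint_dir: Path, keep_last: int = 3) -> list:
    """Delete all but the `keep_last` most recent `epoch_*.pth` files.

    Assumes zero-padded epoch numbers in the file names, so that
    lexicographic order matches epoch order. Other files (for example
    a best-model copy) are left untouched.
    """
    checkpoints = sorted(checkpoint_dir.glob("epoch_*.pth"))
    # Note: `checkpoints[:-0]` would be empty, so handle keep_last == 0 explicitly.
    stale = checkpoints[:-keep_last] if keep_last > 0 else checkpoints
    for path in stale:
        path.unlink()
    return stale
```

Calling `prune_checkpoints(checkpoint_dir, keep_last=2)` right after saving would leave only the two newest per-epoch files in the directory.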

Warning

You can use the $SCRATCH storage for checkpoints, but be aware that it is not backed up and may be cleared periodically, so it is not suitable for long-term storage. Once training ends, always copy your most valuable checkpoints (e.g., the best model or the latest checkpoint) to more permanent storage.

Checkpointing enables you to:

  • Resume training from the last saved checkpoint rather than from scratch after a preemption or time-limit expiration, avoiding wasted compute.
  • Request shorter wall-clock time allocations, which are easier to schedule on the cluster.
  • Access preemptible partitions where jobs get high-priority resources but may be interrupted.

Prerequisites

Make sure to read the following sections of the documentation before using this example:

Other resources:

Example

The full source code for this example is available on the mila-docs GitHub repository.

job.sh

Two SBATCH directives are added:

  • #SBATCH --requeue: allows the job to be requeued with the same job ID.
  • #SBATCH --signal=B:TERM@300: sends SIGTERM to the job 5 minutes before its time limit expires, giving the Python process time to exit cleanly.
 # distributed/single_gpu/job.sh -> good_practices/checkpointing/job.sh
 #!/bin/bash
 #SBATCH --ntasks=1
 #SBATCH --ntasks-per-node=1
 #SBATCH --cpus-per-task=4
 #SBATCH --gpus-per-task=l40s:1
 ## Or --gpus-per-task=rtx8000:1    to request a different GPU model
 ## Or --gpus-per-task=1            for any GPU model
 #SBATCH --mem-per-gpu=16G
 #SBATCH --time=00:15:00
+#SBATCH --requeue
+#SBATCH --signal=B:TERM@300 # tells the controller to send SIGTERM to the job 5
+                            # min before its time ends to give it a chance for
+                            # better cleanup.

 # Exit on error
 set -e

 # Echo time and hostname into log
 echo "Date:     $(date)"
 echo "Hostname: $(hostname)"
+echo "Job has been restarted ${SLURM_RESTART_COUNT:-0} times."

 # To make your code as reproducible as possible with
 # `torch.use_deterministic_algorithms(True)`, uncomment the following block:
 ## === Reproducibility ===
 ## Be warned that this can make your code slower. See
 ## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations
 ## for more details.
 # export CUBLAS_WORKSPACE_CONFIG=:4096:8
 ## === Reproducibility (END) ===

 # Stage dataset into $SLURM_TMPDIR
 mkdir -p $SLURM_TMPDIR/data
 cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
 # General-purpose alternatives combining copy and unpack:
 #     unzip   /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
 #     tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/

 # Execute Python script
 # Use the `--offline` option of `uv run` on clusters without internet access on compute nodes.
 # Using the `--locked` option can help make your experiments easier to reproduce (it forces
 # your uv.lock file to be up to date with the dependencies declared in pyproject.toml).
 srun uv run python main.py

pyproject.toml

[project]
name = "checkpointing-example"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11,<3.14"
dependencies = [
    "rich>=14.0.0",
    "torch>=2.7.1",
    "torchvision>=0.22.1",
    "tqdm>=4.67.1",
]

main.py

The additions fall into three main parts:

Checkpoint loading: Before training begins, load_checkpoint checks whether a checkpoint file already exists at the expected path. If it does, the model, optimizer, and all random states are restored, and training resumes at the epoch after the one stored in the checkpoint. Because SLURM preserves the job ID across requeues, the restarted job automatically finds the checkpoint left by the previous run. If no checkpoint is found, training starts from scratch.

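Checkpoint saving: At the end of each epoch, save_checkpoint writes a RunState dictionary to $SCRATCH/checkpointing_example/<job_id>/checkpoints/checkpoint.pth. The RunState includes the model weights, optimizer state, current epoch, best validation accuracy, and all random number generator states. The file is first written under a temporary name and then moved into place with os.replace, so an interruption during saving does not corrupt the previous checkpoint. When an epoch reaches a new best validation accuracy, the checkpoint is also copied to model_best.pth.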

Signal handling: The signal_handler function is registered for SIGTERM, which is sent on preemption or shortly before the time limit expires. As written, the handler only logs the signal; the commented-out lines show where to add quick cleanup code (e.g., to mark a W&B run as preempted, to requeue the job, or to exit with a custom code). Since the last complete checkpoint was saved at the end of the previous epoch, at most one epoch of progress is lost when training is interrupted.

 # distributed/single_gpu/main.py -> good_practices/checkpointing/main.py
-"""Single-GPU training example."""
+"""Checkpointing example."""

 import argparse
 import logging
 import os
 import random
+import shutil
+import signal
 import sys
+import uuid
+import warnings
 from pathlib import Path
+from types import FrameType
+from typing import Any, TypedDict

 import numpy as np
 import rich.logging
 import torch
 from torch import Tensor, nn
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, random_split
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18
 from tqdm import tqdm

+logger = logging.getLogger(__name__)
+
+SCRATCH = Path(os.environ["SCRATCH"])
+SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"])
+SLURM_JOBID = os.environ["SLURM_JOBID"]
+
+CHECKPOINT_FILE_NAME = "checkpoint.pth"
+

 # To make your code as reproducible as possible, uncomment the following
 # block:
 ## === Reproducibility ===
 ## Be warned that this can make your code slower. See
 ## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations
 ## for more details.
 # torch.use_deterministic_algorithms(True)
 ## === Reproducibility (END) ===


+class RunState(TypedDict):
+    """Typed dictionary containing the state of the training run which is saved at each epoch."""
+    epoch: int
+    best_acc: float
+    model_state: dict[str, Tensor]
+    optimizer_state: dict[str, Any]
+
+    random_state: tuple[Any, ...]
+    numpy_random_state: dict[str, Any]
+    torch_random_state: Tensor
+    torch_cuda_random_state: list[Tensor]
+
+
+def signal_handler(signum: int, _: FrameType | None):
+    """Called before the job gets pre-empted or reaches the time-limit.
+
+    This should run quickly. Performing a full checkpoint here mid-epoch is not recommended.
+    """
+    signal_enum = signal.Signals(signum)
+    logger.error(f"Job received a {signal_enum.name} signal!")
+
+    # Perform quick actions that will help the job resume later.
+    # If you use Weights & Biases: https://docs.wandb.ai/guides/runs/resuming#preemptible-sweeps
+    # if wandb.run:
+    #     wandb.mark_preempting()
+
+    # If you want to directly requeue the job:
+    # subprocess.run(["scontrol", "requeue", SLURM_JOBID])
+    # Be careful to avoid this if you already have the `--requeue` sbatch argument, as SLURM will
+    # automatically requeue the job on preemption without needing this call.
+
+    # If you want to return a custom exit code:
+    # sys.exit(<custom_code>)
+
+
 def main():
     # Use an argument parser so we can pass hyperparameters from the command line.
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument("--epochs", type=int, default=10)
     parser.add_argument("--learning-rate", type=float, default=5e-4)
     parser.add_argument("--weight-decay", type=float, default=1e-4)
     parser.add_argument("--batch-size", type=int, default=128)
+    parser.add_argument(
+        "--run-dir", type=Path, default=SCRATCH / "checkpointing_example" / SLURM_JOBID
+    )
     parser.add_argument("--seed", type=int, default=42)
     args = parser.parse_args()

     epochs: int = args.epochs
     learning_rate: float = args.learning_rate
     weight_decay: float = args.weight_decay
     batch_size: int = args.batch_size
+    run_dir: Path = args.run_dir
     seed: int = args.seed

+    checkpoint_dir = run_dir / "checkpoints"
+    start_epoch: int = 0
+    best_acc: float = 0.0
+
     # Seed the random number generators as early as possible for reproducibility
     random.seed(seed)
     np.random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)

     # Check that the GPU is available
     assert torch.cuda.is_available() and torch.cuda.device_count() > 0
     device = torch.device("cuda", 0)

     # Setup logging (optional, but much better than using print statements)
     # Uses the `rich` package to make logs pretty.
     logging.basicConfig(
         level=logging.INFO,
         format="%(message)s",
         handlers=[
             rich.logging.RichHandler(
                 markup=True,
                 console=rich.console.Console(
                     # Allow wider log lines in sbatch output files than on the terminal.
                     width=120 if not sys.stdout.isatty() else None
                 ),
             )
         ],
     )

-    logger = logging.getLogger(__name__)
-
     # Create a model and move it to the GPU.
     model = resnet18(num_classes=10)
     model.to(device=device)

     optimizer = torch.optim.AdamW(
         model.parameters(), lr=learning_rate, weight_decay=weight_decay
     )

+    # Try to resume from a checkpoint, if one exists.
+    checkpoint: RunState | None = load_checkpoint(checkpoint_dir, map_location=device)
+    if checkpoint:
+        start_epoch = checkpoint["epoch"] + 1  # +1 to start at the next epoch.
+        best_acc = checkpoint["best_acc"]
+        model.load_state_dict(checkpoint["model_state"])
+        optimizer.load_state_dict(checkpoint["optimizer_state"])
+        random.setstate(checkpoint["random_state"])
+        np.random.set_state(checkpoint["numpy_random_state"])
+        # NOTE: Need to move those tensors to CPU before they can be loaded.
+        torch.random.set_rng_state(checkpoint["torch_random_state"].cpu())
+        torch.cuda.random.set_rng_state_all(
+            t.cpu() for t in checkpoint["torch_cuda_random_state"]
+        )
+        logger.info(
+            f"Resuming training at epoch {start_epoch} (best_acc={best_acc:.2%})."
+        )
+    else:
+        logger.info(f"No checkpoints found in {checkpoint_dir}. Training from scratch.")
+
     # Setup CIFAR10
     num_workers = get_num_workers()
-    dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data"
+    # SLURM_TMPDIR is always set here (it is read from os.environ at import time).
+    dataset_path = SLURM_TMPDIR / "data"
+
     train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path))
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=True,
+        # generator=torch.Generator().manual_seed(seed),
     )
     valid_dataloader = DataLoader(
         valid_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=False,
+        # generator=torch.Generator().manual_seed(seed),
     )
     _test_dataloader = DataLoader(  # NOTE: Not used in this example.
         test_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=False,
     )

-    # Checkout the "checkpointing and preemption" example for more info!
-    logger.debug("Starting training from scratch.")
+    # Register the handler for SIGTERM, which is sent before the job is
+    # preempted and requeued or reaches the end of its time limit.
+    signal.signal(
+        signal.SIGTERM, signal_handler
+    )

-    for epoch in range(epochs):
+    for epoch in range(start_epoch, epochs):
         logger.debug(f"Starting epoch {epoch}/{epochs}")

         # Set the model in training mode (important for e.g. BatchNorm and Dropout layers)
         model.train()

         # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
         progress_bar = tqdm(
             total=len(train_dataloader),
             desc=f"Train epoch {epoch}",
             disable=not sys.stdout.isatty(),  # Disable progress bar in non-interactive environments.
         )

         # Training loop
         for batch in train_dataloader:
             # Move the batch to the GPU before we pass it to the model
             batch = tuple(item.to(device) for item in batch)
             x, y = batch

             # Forward pass
             logits: Tensor = model(x)

             loss = F.cross_entropy(logits, y)

             optimizer.zero_grad()
             loss.backward()
             optimizer.step()

             # Calculate some metrics:
             n_correct_predictions = logits.detach().argmax(-1).eq(y).sum()
             n_samples = y.shape[0]
             accuracy = n_correct_predictions / n_samples

             logger.debug(f"Accuracy: {accuracy.item():.2%}")
             logger.debug(f"Average Loss: {loss.item()}")

             # Advance the progress bar one step and update the progress bar text.
             progress_bar.update(1)
             progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
         progress_bar.close()

         val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
         logger.info(
             f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}"
         )

+        # Remember best accuracy and save the current state.
+        is_best = val_accuracy > best_acc
+        best_acc = max(val_accuracy, best_acc)
+
+        if checkpoint_dir is not None:
+            save_checkpoint(
+                checkpoint_dir,
+                is_best,
+                RunState(
+                    epoch=epoch,
+                    model_state=model.state_dict(),
+                    optimizer_state=optimizer.state_dict(),
+                    random_state=random.getstate(),
+                    numpy_random_state=np.random.get_state(legacy=False),
+                    torch_random_state=torch.random.get_rng_state(),
+                    torch_cuda_random_state=torch.cuda.random.get_rng_state_all(),
+                    best_acc=best_acc,
+                ),
+            )
+
     print("Done!")


 @torch.no_grad()
 def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
     model.eval()

     total_loss = 0.0
     n_samples = 0
     correct_predictions = 0

     for batch in dataloader:
         batch = tuple(item.to(device) for item in batch)
         x, y = batch

         logits: Tensor = model(x)
         loss = F.cross_entropy(logits, y)

         batch_n_samples = x.shape[0]
         batch_correct_predictions = logits.argmax(-1).eq(y).sum()

         total_loss += loss.item()
         n_samples += batch_n_samples
         correct_predictions += batch_correct_predictions

     accuracy = correct_predictions / n_samples
     return total_loss, accuracy


 def make_datasets(
     dataset_path: str,
     val_split: float = 0.1,
     val_split_seed: int = 42,
 ):
     """Returns the training, validation, and test splits for CIFAR10.

     NOTE: We don't use image transforms here for simplicity.
     Having different transformations for train and validation would complicate things a bit.
     Later examples will show how to do the train/val/test split properly when using transforms.
     """
     train_dataset = CIFAR10(
         root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
     )
     test_dataset = CIFAR10(
         root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
     )
     # Split the training dataset into a training and validation set.
-    n_samples = len(train_dataset)
-    n_valid = int(val_split * n_samples)
-    n_train = n_samples - n_valid
     train_dataset, valid_dataset = random_split(
-        train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed)
+        train_dataset,
+        ((1 - val_split), val_split),
+        torch.Generator().manual_seed(val_split_seed),
     )
     return train_dataset, valid_dataset, test_dataset


 def get_num_workers() -> int:
     """Gets the optimal number of DataLoader workers to use in the current job."""
     if "SLURM_CPUS_PER_TASK" in os.environ:
         return int(os.environ["SLURM_CPUS_PER_TASK"])
     if hasattr(os, "sched_getaffinity"):
         return len(os.sched_getaffinity(0))
     return torch.multiprocessing.cpu_count()


+def load_checkpoint(checkpoint_dir: Path, **torch_load_kwargs) -> RunState | None:
+    """Loads the latest checkpoint if possible, otherwise returns `None`."""
+    checkpoint_file = checkpoint_dir / CHECKPOINT_FILE_NAME
+    restart_count = int(os.environ.get("SLURM_RESTART_COUNT", 0))
+    if restart_count:
+        logger.info(
+            f"NOTE: This job has been restarted {restart_count} times by SLURM."
+        )
+
+    if not checkpoint_file.exists():
+        logger.debug(f"No checkpoint found in checkpoints dir ({checkpoint_dir}).")
+        if restart_count:
+            logger.warning(
+                RuntimeWarning(
+                    f"This job has been restarted {restart_count} times by SLURM, but no "
+                    "checkpoint was found! This either means that your checkpointing code is "
+                    "broken, or that the job did not reach the checkpointing portion of your "
+                    "training loop."
+                )
+            )
+        return None
+
+    torch_load_kwargs.setdefault("weights_only", False)
+    checkpoint_state: dict = torch.load(checkpoint_file, **torch_load_kwargs)
+
+    # Keys required by RunState that are absent from the loaded checkpoint.
+    missing_keys = RunState.__required_keys__ - set(checkpoint_state.keys())
+    if missing_keys:
+        warnings.warn(
+            RuntimeWarning(
+                f"Checkpoint at {checkpoint_file} is missing the following keys: {missing_keys}. "
+                f"Ignoring this checkpoint."
+            )
+        )
+        return None
+
+    logger.debug(f"Resuming from the checkpoint file at {checkpoint_file}")
+    state: RunState = checkpoint_state  # type: ignore
+    return state
+
+
+def save_checkpoint(checkpoint_dir: Path, is_best: bool, state: RunState):
+    """Saves a checkpoint with the current state of the run in the checkpoint dir.
+
+    The best checkpoint is also updated if `is_best` is `True`.
+
+    Parameters:
+        checkpoint_dir (Path): The checkpoint directory.
+        is_best (bool): Whether this is the best checkpoint so far.
+        state (RunState): The dictionary containing all the things to save.
+    """
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_file = checkpoint_dir / CHECKPOINT_FILE_NAME
+
+    # Use a unique ID to avoid any potential collisions.
+    unique_id = uuid.uuid1()
+    temp_checkpoint_file = checkpoint_file.with_suffix(f".temp{unique_id}")
+
+    torch.save(state, temp_checkpoint_file)
+    os.replace(temp_checkpoint_file, checkpoint_file)
+
+    if is_best:
+        best_checkpoint_file = checkpoint_file.with_name("model_best.pth")
+        temp_best_checkpoint_file = best_checkpoint_file.with_suffix(
+            f".temp{unique_id}"
+        )
+        shutil.copyfile(checkpoint_file, temp_best_checkpoint_file)
+        os.replace(temp_best_checkpoint_file, best_checkpoint_file)
+
+
 if __name__ == "__main__":
     main()
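
The save, load, and resume logic above can be tried without a GPU or SLURM. The following is a minimal sketch that swaps `torch.save`/`torch.load` for `pickle` and replaces the model with a dummy counter; the helper names `save_state`, `load_state`, and `run`, as well as the `stop_after` parameter, are hypothetical and not part of main.py.

```python
import os
import pickle
import tempfile
from pathlib import Path
from typing import Optional


def save_state(path: Path, state: dict) -> None:
    """Write to a temporary file in the same directory, then rename it into place.

    `os.replace` is atomic on POSIX when source and destination are on the
    same filesystem, so a reader never sees a half-written checkpoint.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_name, path)


def load_state(path: Path) -> Optional[dict]:
    """Return the saved state, or None if there is no checkpoint yet."""
    if not path.exists():
        return None
    with path.open("rb") as f:
        return pickle.load(f)


def run(path: Path, epochs: int, stop_after: Optional[int] = None) -> dict:
    """Dummy training loop that resumes from `path` when a checkpoint exists."""
    state = load_state(path) or {"epoch": -1, "total": 0}
    # Start at the epoch after the last one recorded in the checkpoint.
    for epoch in range(state["epoch"] + 1, epochs):
        state = {"epoch": epoch, "total": state["total"] + epoch}  # "training"
        save_state(path, state)
        if stop_after is not None and epoch == stop_after:
            break  # simulate an interruption right after a checkpoint
    return state
```

Calling `run(path, epochs=5, stop_after=1)` and then `run(path, epochs=5)` on the same path ends in the same state as a single uninterrupted `run(path, epochs=5)`, which is the property the real example aims for.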

Running this example

Submit the job script with:

$ sbatch job.sh

To cancel the job manually while allowing it to be requeued, send SIGTERM rather than the default SIGKILL:

$ scancel --signal=TERM <jobid>

To requeue a job with the same job ID and arguments:

$ scontrol requeue <jobid>

Note

Only jobs that are currently running, pending in the queue, or recently completed can be requeued.

Warning

A job cancelled with plain scancel (no --signal flag) is marked as CANCELLED by SLURM and cannot be requeued. Always pass --signal=TERM when cancelling a job you intend to resume.
