
Checkpointing

Prerequisites

Make sure to read the following sections of the documentation before using this example:

The full source code for this example is available on the mila-docs GitHub repository.

job.sh

 # distributed/single_gpu/job.sh -> good_practices/checkpointing/job.sh
 #!/bin/bash
 #SBATCH --ntasks=1
 #SBATCH --ntasks-per-node=1
 #SBATCH --cpus-per-task=4
 #SBATCH --gpus-per-task=l40s:1
 ## Or --gpus-per-task=rtx8000:1    to request a different GPU model
 ## Or --gpus-per-task=1            for any GPU model
 #SBATCH --mem-per-gpu=16G
 #SBATCH --time=00:15:00
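+# --requeue allows Slurm to automatically put this job back in the queue if it
+# gets preempted, so that it can restart and resume from the latest checkpoint.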
+#SBATCH --requeue
+#SBATCH --signal=B:TERM@300 # tells the controller to send SIGTERM to the job 5
+                            # min before its time ends to give it a chance for
+                            # better cleanup. If you cancel the job manually,
+                            # make sure that you specify the signal as TERM like
+                            # so `scancel --signal=TERM <jobid>`.
+                            # https://dhruveshp.com/blog/2021/signal-propagation-on-slurm/

 # Exit on error
 set -e

 # Echo time and hostname into log
 echo "Date:     $(date)"
 echo "Hostname: $(hostname)"
+echo "Job has been preempted $SLURM_RESTART_COUNT times."

 # To make your code as reproducible as possible with
 # `torch.use_deterministic_algorithms(True)`, uncomment the following block:
 ## === Reproducibility ===
 ## Be warned that this can make your code slower. See
 ## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations
 ## for more details.
 # export CUBLAS_WORKSPACE_CONFIG=:4096:8
 ## === Reproducibility (END) ===

 # Stage dataset into $SLURM_TMPDIR
 mkdir -p $SLURM_TMPDIR/data
 cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
 # General-purpose alternatives combining copy and unpack:
 #     unzip   /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
 #     tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/

 # Execute Python script
 # Use the `--offline` option of `uv run` on clusters without internet access on compute nodes.
 # Using the `--locked` option can help make your experiments easier to reproduce (it forces
 # your uv.lock file to be up to date with the dependencies declared in pyproject.toml).
-srun uv run python main.py
+# Here we use `exec` to ensure that the signals are received and handled in the Python process.
+exec srun uv run python main.py

pyproject.toml

[project]
name = "checkpointing-example"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11,<3.14"
dependencies = [
    "rich>=14.0.0",
    "torch>=2.7.1",
    "torchvision>=0.22.1",
    "tqdm>=4.67.1",
]

main.py

 # distributed/single_gpu/main.py -> good_practices/checkpointing/main.py
-"""Single-GPU training example."""
+"""Checkpointing example."""
+
+from __future__ import annotations

 import argparse
 import logging
 import os
 import random
+import shutil
+import signal
 import sys
+import uuid
+import warnings
+from logging import getLogger as get_logger
 from pathlib import Path
+from types import FrameType
+from typing import Any, TypedDict

 import numpy as np
 import rich.logging
 import torch
 from torch import Tensor, nn
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, random_split
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18
 from tqdm import tqdm

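+# NOTE: These environment variables are defined on the cluster inside a Slurm job;
+# running this script outside of a job would raise a KeyError here.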
+SCRATCH = Path(os.environ["SCRATCH"])
+SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"])
+SLURM_JOBID = os.environ["SLURM_JOBID"]
+
+CHECKPOINT_FILE_NAME = "checkpoint.pth"
+

  # To make your code as reproducible as possible, uncomment the following
 # block:
 ## === Reproducibility ===
 ## Be warned that this can make your code slower. See
 ## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations
 ## for more details.
 # torch.use_deterministic_algorithms(True)
 ## === Reproducibility (END) ===


+class RunState(TypedDict):
+    """Typed dictionary containing the state of the training run which is saved at each epoch.
+
+    Using type hints helps prevent bugs and makes your code easier to read for both humans and
+    machines (e.g. Copilot). This leads to less time spent debugging and better code suggestions.
+    """
+
+    epoch: int
+    best_acc: float
+    model_state: dict[str, Tensor]
+    optimizer_state: dict[str, Tensor]
+
+    random_state: tuple[Any, ...]
+    numpy_random_state: dict[str, Any]
+    torch_random_state: Tensor
+    torch_cuda_random_state: list[Tensor]
+
+
 def main():
     # Use an argument parser so we can pass hyperparameters from the command line.
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument("--epochs", type=int, default=10)
     parser.add_argument("--learning-rate", type=float, default=5e-4)
     parser.add_argument("--weight-decay", type=float, default=1e-4)
     parser.add_argument("--batch-size", type=int, default=128)
+    parser.add_argument(
+        "--run-dir", type=Path, default=SCRATCH / "checkpointing_example" / SLURM_JOBID
+    )
     parser.add_argument("--seed", type=int, default=42)
     args = parser.parse_args()

     epochs: int = args.epochs
     learning_rate: float = args.learning_rate
     weight_decay: float = args.weight_decay
     batch_size: int = args.batch_size
+    run_dir: Path = args.run_dir
     seed: int = args.seed

+    checkpoint_dir = run_dir / "checkpoints"
+    start_epoch: int = 0
+    best_acc: float = 0.0
+
     # Seed the random number generators as early as possible for reproducibility
     random.seed(seed)
     np.random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)

     # Check that the GPU is available
     assert torch.cuda.is_available() and torch.cuda.device_count() > 0
     device = torch.device("cuda", 0)

     # Setup logging (optional, but much better than using print statements)
     # Uses the `rich` package to make logs pretty.
     logging.basicConfig(
         level=logging.INFO,
         format="%(message)s",
         handlers=[
             rich.logging.RichHandler(
                 markup=True,
                 console=rich.console.Console(
                     # Allow wider log lines in sbatch output files than on the terminal.
                     width=120 if not sys.stdout.isatty() else None
                 ),
             )
         ],
     )

     logger = logging.getLogger(__name__)

-    # Create a model and move it to the GPU.
+    # Create a model.
     model = resnet18(num_classes=10)
+
+    # Move the model to the GPU.
     model.to(device=device)

     optimizer = torch.optim.AdamW(
         model.parameters(), lr=learning_rate, weight_decay=weight_decay
     )

-    # Setup CIFAR10
+    # Try to resume from a checkpoint, if one exists.
+    checkpoint: RunState | None = load_checkpoint(checkpoint_dir, map_location=device)
+    if checkpoint:
+        start_epoch = checkpoint["epoch"] + 1  # +1 to start at the next epoch.
+        best_acc = checkpoint["best_acc"]
+        model.load_state_dict(checkpoint["model_state"])
+        optimizer.load_state_dict(checkpoint["optimizer_state"])
+        random.setstate(checkpoint["random_state"])
+        np.random.set_state(checkpoint["numpy_random_state"])
+        # NOTE: Need to move those tensors to CPU before they can be loaded.
+        torch.random.set_rng_state(checkpoint["torch_random_state"].cpu())
+        torch.cuda.random.set_rng_state_all(
+            [t.cpu() for t in checkpoint["torch_cuda_random_state"]]
+        )
+        logger.info(
+            f"Resuming training at epoch {start_epoch} (best_acc={best_acc:.2%})."
+        )
+    else:
+        logger.info(f"No checkpoints found in {checkpoint_dir}. Training from scratch.")
+
+    # Setup the dataset
     num_workers = get_num_workers()
-    dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data"
+    dataset_path = SLURM_TMPDIR / "data"
+
     train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path))
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=True,
+        # generator=torch.Generator().manual_seed(seed),
     )
     valid_dataloader = DataLoader(
         valid_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=False,
+        # generator=torch.Generator().manual_seed(seed),
     )
     _test_dataloader = DataLoader(  # NOTE: Not used in this example.
         test_dataset,
         batch_size=batch_size,
         num_workers=num_workers,
         shuffle=False,
     )

-    # Checkout the "checkpointing and preemption" example for more info!
-    logger.debug("Starting training from scratch.")
-
-    for epoch in range(epochs):
+    def signal_handler(signum: int, frame: FrameType | None):
+        """Called before the job gets pre-empted or reaches the time-limit.
+
+        This should run quickly. Performing a full checkpoint here mid-epoch is not recommended.
+        """
+        signal_enum = signal.Signals(signum)
+        logger.error(f"Job received a {signal_enum.name} signal!")
+        # Perform quick actions that will help the job resume later.
+        # If you use Weights & Biases: https://docs.wandb.ai/guides/runs/resuming#preemptible-sweeps
+        # if wandb.run:
+        #     wandb.mark_preempting()
+
+    signal.signal(
+        signal.SIGTERM, signal_handler
+    )  # Before getting pre-empted and requeued.
+    signal.signal(
+        signal.SIGUSR1, signal_handler
+    )  # Before reaching the end of the time limit.
+
+    for epoch in range(start_epoch, epochs):
         logger.debug(f"Starting epoch {epoch}/{epochs}")

-        # Set the model in training mode (important for e.g. BatchNorm and Dropout layers)
+        # Set the model in training mode (this is important for e.g. BatchNorm and Dropout layers)
         model.train()

-        # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
+        # NOTE: using a progress bar from tqdm is much nicer than using `print`s.
         progress_bar = tqdm(
             total=len(train_dataloader),
             desc=f"Train epoch {epoch}",
             disable=not sys.stdout.isatty(),  # Disable progress bar in non-interactive environments.
         )

         # Training loop
+        batch: tuple[Tensor, Tensor]
         for batch in train_dataloader:
             # Move the batch to the GPU before we pass it to the model
             batch = tuple(item.to(device) for item in batch)
             x, y = batch

             # Forward pass
             logits: Tensor = model(x)

             loss = F.cross_entropy(logits, y)

             optimizer.zero_grad()
             loss.backward()
             optimizer.step()

             # Calculate some metrics:
             n_correct_predictions = logits.detach().argmax(-1).eq(y).sum()
             n_samples = y.shape[0]
             accuracy = n_correct_predictions / n_samples

             logger.debug(f"Accuracy: {accuracy.item():.2%}")
             logger.debug(f"Average Loss: {loss.item()}")

-            # Advance the progress bar one step and update the progress bar text.
+            # Advance the progress bar one step, and update the text displayed in the progress bar.
             progress_bar.update(1)
             progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
         progress_bar.close()

         val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
         logger.info(
             f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}"
         )

+        # remember best accuracy and save the current state.
+        is_best = val_accuracy > best_acc
+        best_acc = max(val_accuracy, best_acc)
+
+        if checkpoint_dir is not None:
+            save_checkpoint(
+                checkpoint_dir,
+                is_best,
+                RunState(
+                    epoch=epoch,
+                    model_state=model.state_dict(),
+                    optimizer_state=optimizer.state_dict(),
+                    random_state=random.getstate(),
+                    numpy_random_state=np.random.get_state(legacy=False),
+                    torch_random_state=torch.random.get_rng_state(),
+                    torch_cuda_random_state=torch.cuda.random.get_rng_state_all(),
+                    best_acc=best_acc,
+                ),
+            )
+
     print("Done!")


 @torch.no_grad()
 def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
     model.eval()

     total_loss = 0.0
     n_samples = 0
     correct_predictions = 0

     for batch in dataloader:
         batch = tuple(item.to(device) for item in batch)
         x, y = batch

         logits: Tensor = model(x)
         loss = F.cross_entropy(logits, y)

         batch_n_samples = x.shape[0]
-        batch_correct_predictions = logits.argmax(-1).eq(y).sum()
+        batch_correct_predictions = logits.argmax(-1).eq(y).sum().item()

         total_loss += loss.item()
         n_samples += batch_n_samples
-        correct_predictions += batch_correct_predictions
+        correct_predictions += int(batch_correct_predictions)

     accuracy = correct_predictions / n_samples
     return total_loss, accuracy


 def make_datasets(
     dataset_path: str,
     val_split: float = 0.1,
     val_split_seed: int = 42,
 ):
     """Returns the training, validation, and test splits for CIFAR10.

     NOTE: We don't use image transforms here for simplicity.
     Having different transformations for train and validation would complicate things a bit.
     Later examples will show how to do the train/val/test split properly when using transforms.
     """
     train_dataset = CIFAR10(
         root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
     )
     test_dataset = CIFAR10(
         root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
     )
     # Split the training dataset into a training and validation set.
-    n_samples = len(train_dataset)
-    n_valid = int(val_split * n_samples)
-    n_train = n_samples - n_valid
     train_dataset, valid_dataset = random_split(
-        train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed)
+        train_dataset,
+        ((1 - val_split), val_split),
+        torch.Generator().manual_seed(val_split_seed),
     )
     return train_dataset, valid_dataset, test_dataset


 def get_num_workers() -> int:
-    """Gets the optimal number of DatLoader workers to use in the current job."""
+    """Gets the optimal number of DataLoader workers to use in the current job."""
     if "SLURM_CPUS_PER_TASK" in os.environ:
         return int(os.environ["SLURM_CPUS_PER_TASK"])
     if hasattr(os, "sched_getaffinity"):
         return len(os.sched_getaffinity(0))
     return torch.multiprocessing.cpu_count()


+def load_checkpoint(checkpoint_dir: Path, **torch_load_kwargs) -> RunState | None:
+    """Loads the latest checkpoint if possible, otherwise returns `None`."""
+    logger = logging.getLogger(__name__)
+
+    checkpoint_file = checkpoint_dir / CHECKPOINT_FILE_NAME
+    restart_count = int(os.environ.get("SLURM_RESTART_COUNT", 0))
+    if restart_count:
+        logger.info(
+            f"NOTE: This job has been restarted {restart_count} times by SLURM."
+        )
+
+    if not checkpoint_file.exists():
+        logger.debug(f"No checkpoint found in checkpoints dir ({checkpoint_dir}).")
+        if restart_count:
+            logger.warning(
+                RuntimeWarning(
+                    f"This job has been restarted {restart_count} times by SLURM, but no "
+                    "checkpoint was found! This either means that your checkpointing code is "
+                    "broken, or that the job did not reach the checkpointing portion of your "
+                    "training loop."
+                )
+            )
+        return None
+
+    torch_load_kwargs.setdefault("weights_only", False)
+    checkpoint_state: dict = torch.load(checkpoint_file, **torch_load_kwargs)
+
+    missing_keys = RunState.__required_keys__ - set(checkpoint_state.keys())
+    if missing_keys:
+        warnings.warn(
+            RuntimeWarning(
+                f"Checkpoint at {checkpoint_file} is missing the following keys: {missing_keys}. "
+                f"Ignoring this checkpoint."
+            )
+        )
+        return None
+
+    logger.debug(f"Resuming from the checkpoint file at {checkpoint_file}")
+    state: RunState = checkpoint_state  # type: ignore
+    return state
+
+
+def save_checkpoint(checkpoint_dir: Path, is_best: bool, state: RunState):
+    """Saves a checkpoint with the current state of the run in the checkpoint dir.
+
+    The best checkpoint is also updated if `is_best` is `True`.
+
+    Parameters
+    ----------
+    checkpoint_dir: The checkpoint directory.
+    is_best: Whether this is the best checkpoint so far.
+    state: The dictionary containing all the things to save.
+    """
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_file = checkpoint_dir / CHECKPOINT_FILE_NAME
+
+    # Use a unique ID to avoid any potential collisions.
+    unique_id = uuid.uuid1()
+    temp_checkpoint_file = checkpoint_file.with_suffix(f".temp{unique_id}")
+
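+    # Save to a temporary file first, then atomically swap it in with os.replace.
+    # This way, a preemption in the middle of the save cannot corrupt an existing
+    # checkpoint (the replace is atomic when both paths are on the same filesystem).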
+    torch.save(state, temp_checkpoint_file)
+    os.replace(temp_checkpoint_file, checkpoint_file)
+
+    if is_best:
+        best_checkpoint_file = checkpoint_file.with_name("model_best.pth")
+        temp_best_checkpoint_file = best_checkpoint_file.with_suffix(
+            f".temp{unique_id}"
+        )
+        shutil.copyfile(checkpoint_file, temp_best_checkpoint_file)
+        os.replace(temp_best_checkpoint_file, best_checkpoint_file)
+
+
 if __name__ == "__main__":
     main()
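
If you later want to evaluate the best model saved by save_checkpoint, you can load model_best.pth from the run's checkpoint directory. The sketch below is not part of the example's files; it assumes the default --run-dir layout from main.py, with <jobid> as a placeholder for the id of a finished job:

# Hypothetical evaluation helper, not part of the example repository.
import os
from pathlib import Path

import torch
from torchvision.models import resnet18

# Default --run-dir used in main.py: $SCRATCH/checkpointing_example/<jobid>
run_dir = Path(os.environ["SCRATCH"]) / "checkpointing_example" / "<jobid>"
best_checkpoint_file = run_dir / "checkpoints" / "model_best.pth"

# The checkpoint holds more than just weights, so weights_only=False is needed,
# as in load_checkpoint above. Only load checkpoint files that you trust.
state = torch.load(best_checkpoint_file, map_location="cpu", weights_only=False)

model = resnet18(num_classes=10)
model.load_state_dict(state["model_state"])
model.eval()
print(f"Best accuracy: {state['best_acc']:.2%} (epoch {state['epoch']})")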

Running this example

$ sbatch job.sh
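
While the job is running, you can send it the TERM signal manually with scancel --signal=TERM <jobid> (as noted in job.sh); the signal handler in main.py will log that the signal was received. If the job is requeued (for example after a preemption), it keeps the same job id, so it resumes from the latest checkpoint in $SCRATCH/checkpointing_example/<jobid>/checkpoints instead of starting over.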
