Train Your First Model

This guide walks you through training a small model (ResNet18) on CIFAR-10 using a single GPU on the Mila cluster. You will use Mila's CIFAR-10 dataset, stage it into fast local storage, and run a Slurm batch job.

Prerequisites

What you will do

  • Train a ResNet18 on CIFAR-10 using a single GPU.
  • Use Mila's CIFAR-10 dataset in /network/datasets/cifar10/.
  • Stage data into $SLURM_TMPDIR for fast I/O during training.
  • Submit and monitor a batch job with sbatch.

Open VSCode on a compute node

Create the project directory on the cluster

From your local machine, create the project directory on the cluster so that mila code can open it (the path is relative to your home directory on the cluster):

ssh mila 'mkdir -p CODE/train_first_model'
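
If you want to confirm the directory exists before opening it, you can list it over SSH (optional):

ssh mila 'ls -d CODE/train_first_model'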

Start VSCode on a CPU node

For this step, we're only preparing code and job scripts, not actually running training, so we'll use a CPU node: it is faster to allocate and makes for a less resource-intensive editor session.

mila code CODE/train_first_model --alloc --cpus-per-task=2 --mem=16G --time=01:00:00
[17:35:21] Checking disk quota on $HOME...                                                                                                     disk_quota.py:31
[17:35:27] Disk usage: 85.34 / 100.00 GiB and 794022 / 1048576 files                                                                           disk_quota.py:211
[17:35:29] (mila) $ cd $SCRATCH && salloc --cpus-per-task=2 --mem=16G --time=01:00:00 --job-name=mila-code                                     compute_node.py:293
salloc: --------------------------------------------------------------------------------------------------
salloc: # Using default long partition
salloc: --------------------------------------------------------------------------------------------------
salloc: Granted job allocation 8888888
[17:35:30] Waiting for job 8888888 to start.                                                                                                   compute_node.py:315
[17:35:31] (localhost) $ code --new-window --wait --remote ssh-remote+cn-a003.server.mila.quebec /home/mila/u/username/CODE/train_first_model  local_v2.py:55

Create the project files

The job script does three things:

  1. #SBATCH directives: request 1 GPU, 4 CPUs, 16G of memory, and 15 minutes.
  2. Data staging: copy CIFAR-10 from /network/datasets/ into $SLURM_TMPDIR/data. Compute nodes read from $SLURM_TMPDIR much faster than from network storage.
  3. Run the training script: srun uv run python main.py runs your script inside the allocation.
job.sh
#!/bin/bash
#SBATCH --ntasks=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=4
#SBATCH --gpus-per-task=l40s:1
#SBATCH --mem-per-gpu=16G
#SBATCH --time=00:15:00

# Exit on error
set -e

# Echo time and hostname into log
echo "Date:     $(date)"
echo "Hostname: $(hostname)"

# To make your code as reproducible as possible with
# `torch.use_deterministic_algorithms(True)`, uncomment the following block:
## === Reproducibility ===
## Be warned that this can make your code slower. See
## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations
## for more details.
# export CUBLAS_WORKSPACE_CONFIG=:4096:8
## === Reproducibility (END) ===

# Stage dataset into $SLURM_TMPDIR
mkdir -p $SLURM_TMPDIR/data
cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
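# NOTE: copying just the archive is enough here: torchvision's CIFAR10 with
# download=True finds the archive in its root directory, verifies its
# checksum, and extracts it instead of downloading anything.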
# General-purpose alternatives combining copy and unpack:
#     unzip   /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
#     tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/

# Execute Python script
# Use the `--offline` option of `uv run` on clusters without internet access on compute nodes.
# Using the `--locked` option can help make your experiments easier to reproduce (it forces
# your uv.lock file to be up to date with the dependencies declared in pyproject.toml).
srun uv run python main.py
pyproject.toml
[project]
name = "single-gpu-example"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11,<3.14"
dependencies = [
    "rich>=14.0.0",
    "torch>=2.7.1",
    "torchvision>=0.22.1",
    "tqdm>=4.67.1",
]
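
The first uv run in the job will create the virtual environment and install the dependencies declared above (you can see this happening in the sample job output further below). If you'd rather not spend job time on installation, you can optionally set up the environment beforehand from the VSCode terminal, assuming uv is available on your PATH:

uv sync    # resolves pyproject.toml, writes uv.lock, creates .venv, installs dependencies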

The training script uses PyTorch to load CIFAR-10 from $SLURM_TMPDIR/data, train ResNet18, and log validation loss and accuracy per epoch. Create a file main.py with the following content:

main.py
"""Single-GPU training example."""

import argparse
import logging
import os
import random
import sys
from pathlib import Path

import numpy as np
import rich.logging
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from tqdm import tqdm


# To make your code as reproducible as possible, uncomment the following
# block:
## === Reproducibility ===
## Be warned that this can make your code slower. See
## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations
## for more details.
# torch.use_deterministic_algorithms(True)
## === Reproducibility (END) ===


def main():
    # Use an argument parser so we can pass hyperparameters from the command line.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--learning-rate", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    epochs: int = args.epochs
    learning_rate: float = args.learning_rate
    weight_decay: float = args.weight_decay
    batch_size: int = args.batch_size
    seed: int = args.seed

    # Seed the random number generators as early as possible for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Check that the GPU is available
    assert torch.cuda.is_available() and torch.cuda.device_count() > 0
    device = torch.device("cuda", 0)

    # Setup logging (optional, but much better than using print statements)
    # Uses the `rich` package to make logs pretty.
    logging.basicConfig(
        level=logging.INFO,
        format="%(message)s",
        handlers=[
            rich.logging.RichHandler(
                markup=True,
                console=rich.console.Console(
                    # Allow wider log lines in sbatch output files than on the terminal.
                    width=120 if not sys.stdout.isatty() else None
                ),
            )
        ],
    )

    logger = logging.getLogger(__name__)

    # Create a model and move it to the GPU.
    model = resnet18(num_classes=10)
    model.to(device=device)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    # Setup CIFAR10
    num_workers = get_num_workers()
    dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data"
    train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path))
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    _test_dataloader = DataLoader(  # NOTE: Not used in this example.
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

    # Check out the "checkpointing and preemption" example for more info!
    logger.debug("Starting training from scratch.")

    for epoch in range(epochs):
        logger.debug(f"Starting epoch {epoch}/{epochs}")

        # Set the model in training mode (important for e.g. BatchNorm and Dropout layers)
        model.train()

        # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
        progress_bar = tqdm(
            total=len(train_dataloader),
            desc=f"Train epoch {epoch}",
            disable=not sys.stdout.isatty(),  # Disable progress bar in non-interactive environments.
        )

        # Training loop
        for batch in train_dataloader:
            # Move the batch to the GPU before we pass it to the model
            batch = tuple(item.to(device) for item in batch)
            x, y = batch

            # Forward pass
            logits: Tensor = model(x)

            loss = F.cross_entropy(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Calculate some metrics:
            n_correct_predictions = logits.detach().argmax(-1).eq(y).sum()
            n_samples = y.shape[0]
            accuracy = n_correct_predictions / n_samples

            logger.debug(f"Accuracy: {accuracy.item():.2%}")
            logger.debug(f"Average Loss: {loss.item()}")

            # Advance the progress bar one step and update the progress bar text.
            progress_bar.update(1)
            progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
        progress_bar.close()

        val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
        logger.info(
            f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}"
        )

    print("Done!")


@torch.no_grad()
def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
    model.eval()

    total_loss = 0.0
    n_samples = 0
    correct_predictions = 0

    for batch in dataloader:
        batch = tuple(item.to(device) for item in batch)
        x, y = batch

        logits: Tensor = model(x)
        loss = F.cross_entropy(logits, y)

        batch_n_samples = x.shape[0]
        batch_correct_predictions = logits.argmax(-1).eq(y).sum()

        total_loss += loss.item()
        n_samples += batch_n_samples
        correct_predictions += batch_correct_predictions

    accuracy = correct_predictions / n_samples
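    # NOTE: total_loss is the sum of per-batch losses (not the mean), which is
    # why the validation losses in the sample job output are well above 1.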
    return total_loss, accuracy


def make_datasets(
    dataset_path: str,
    val_split: float = 0.1,
    val_split_seed: int = 42,
):
    """Returns the training, validation, and test splits for CIFAR10.

    NOTE: We don't use image transforms here for simplicity.
    Having different transformations for train and validation would complicate things a bit.
    Later examples will show how to do the train/val/test split properly when using transforms.
    """
    train_dataset = CIFAR10(
        root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
    )
    test_dataset = CIFAR10(
        root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
    )
    # Split the training dataset into a training and validation set.
    n_samples = len(train_dataset)
    n_valid = int(val_split * n_samples)
    n_train = n_samples - n_valid
    train_dataset, valid_dataset = random_split(
        train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed)
    )
    return train_dataset, valid_dataset, test_dataset


def get_num_workers() -> int:
    """Gets the optimal number of DatLoader workers to use in the current job."""
    if "SLURM_CPUS_PER_TASK" in os.environ:
        return int(os.environ["SLURM_CPUS_PER_TASK"])
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return torch.multiprocessing.cpu_count()


if __name__ == "__main__":
    main()
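
Because main.py parses its hyperparameters with argparse, you can override the defaults without touching the code. For example, to train longer with a different learning rate, you could change the last line of job.sh as follows (the values are illustrative, not tuned):

srun uv run python main.py --epochs 20 --learning-rate 1e-3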

Submit the job

Using the VSCode terminal, submit the job:

sbatch job.sh
sbatch: --------------------------------------------------------------------------------------------------
sbatch: # Using default long partition
sbatch: --------------------------------------------------------------------------------------------------
Submitted batch job 8888888

You will see something like Submitted batch job 8888888. Note the job ID.
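
If you want to reuse the job ID in later commands, sbatch's --parsable option prints only the ID, so you can capture it in a shell variable (an optional convenience):

JOBID=$(sbatch --parsable job.sh)
echo "Submitted job $JOBID"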

Monitor the job

  • Queue status: squeue --me

    squeue --me
    
       JOBID     USER    PARTITION           NAME  ST START_TIME             TIME NODES CPUS TRES_PER_N MIN_MEM NODELIST (REASON) COMMENT
     8888888 username         long         job.sh   R 2026-03-16T19:48       1:08     1    4        N/A     16G cn-l016 (None) (null)
    

  • Watch the output file: Once the job starts, Slurm creates a file slurm-<JOBID>.out containing the job's log. Open it to watch the model being trained (or stream it from a terminal, as shown after this list):

    slurm-JOBID.out
    Date:     Tue Mar 17 02:55:56 PM EDT 2026
    Hostname: cn-l023.server.mila.quebec
    Using CPython 3.12.11
    Creating virtual environment at: .venv
    Installed 36 packages in 1m 56s
    [03/17/26 14:58:37] INFO     Epoch 0: Val loss: 59.757 accuracy: 46.80%                                      main.py:152
    [03/17/26 14:58:40] INFO     Epoch 1: Val loss: 42.472 accuracy: 62.78%                                      main.py:152
    [03/17/26 14:58:42] INFO     Epoch 2: Val loss: 41.969 accuracy: 63.88%                                      main.py:152
    [03/17/26 14:58:45] INFO     Epoch 3: Val loss: 47.068 accuracy: 60.84%                                      main.py:152
    [03/17/26 14:58:48] INFO     Epoch 4: Val loss: 38.446 accuracy: 67.44%                                      main.py:152
    [03/17/26 14:58:50] INFO     Epoch 5: Val loss: 39.577 accuracy: 67.66%                                      main.py:152
    [03/17/26 14:58:53] INFO     Epoch 6: Val loss: 40.347 accuracy: 68.30%                                      main.py:152
    [03/17/26 14:58:56] INFO     Epoch 7: Val loss: 46.434 accuracy: 67.30%                                      main.py:152
    [03/17/26 14:58:58] INFO     Epoch 8: Val loss: 44.021 accuracy: 69.44%                                      main.py:152
    [03/17/26 14:59:01] INFO     Epoch 9: Val loss: 47.403 accuracy: 69.00%                                      main.py:152
    Done!
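
You can also stream the log from a terminal while the job runs, or query Slurm's accounting records once it has finished (if accounting is enabled on the cluster). Replace the example job ID with your own:

tail -f slurm-8888888.out
sacct -j 8888888 --format=JobID,State,Elapsed,MaxRSS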
    

Key concepts

  • Data staging to $SLURM_TMPDIR: network storage is shared and slower. Copying the dataset into $SLURM_TMPDIR at the start of the job gives the compute node fast local access for the rest of the run.
  • srun: runs a command inside the allocated resources. In our script, srun uv run python main.py runs the training on the GPU node.
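
If you want to see what staging buys you on a given node, a rough spot check is to time reading the same file from network storage and from node-local storage inside an allocation (illustrative only; results vary with node, load, and caching):

time cat /network/datasets/cifar10/cifar-10-python.tar.gz > /dev/null
time cat $SLURM_TMPDIR/data/cifar-10-python.tar.gz > /dev/null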
