
Track Experiments with Weights & Biases (WandB)

Weights & Biases is an experiment tracking platform for logging metrics, hyperparameters, and artifacts from training runs. Mila members supervised by core professors can access the shared Mila organization on wandb.ai for team-level project visibility and collaboration.

Before you begin

Request access to the Mila WandB organization

Students supervised by core professors are eligible for the Mila organization on wandb.ai. Write to it-support@mila.quebec to request access.

What this guide covers

  • Sign in with single sign-on (SSO) using your @mila.quebec address.
  • Authenticate the WandB CLI on the cluster.
  • Initialize a run and log metrics in a training script.
  • Name runs, attach tags, and group related runs.
  • Configure Slurm job scripts for reliable WandB logging and run resumption.
  • Identify whether a training job is I/O-bound or compute-bound using WandB system metrics and step timing.

Sign in for the first time

WandB uses Mila's SSO provider. Signing in with your @mila.quebec address the first time links the account to the Mila organization.

Migrate an existing WandB account

Add your Mila email first to avoid a duplicate account

Before following the steps below, add your @mila.quebec address to your existing WandB account and make it the primary email; this avoids creating a duplicate account. See the WandB documentation on managing email addresses. Then log out of WandB before proceeding.

  1. Go to wandb.ai and click Sign in.
  2. Enter your @mila.quebec email address. The password field will disappear once a recognized SSO domain is detected.
  3. Click Log in — the browser will redirect you to the Mila SSO page.
  4. Select the mila.quebec identity provider. WandB will offer to link the existing account to the Mila organization.

Create a new account

Follow the same SSO steps above. At the account creation prompt, select Professional.

Which account type to select?

Select Professional at the account creation prompt. This unlocks team features required for the Mila organization. The Mila IT team manages organization-level billing, so no personal plan upgrade is required.

Authenticate the CLI on the cluster

Most WandB Python API calls require a valid API key, which wandb reads from ~/.netrc or the WANDB_API_KEY environment variable.

Log in interactively

Install WandB as a tool on a login node, then authenticate:

uv tool install --upgrade wandb
wandb login
Resolved 20 packages in 371ms
Prepared 20 packages in 1.87s
Installed 20 packages in 2.16s
 + annotated-types==0.7.0
 + certifi==2026.2.25
 [...]
 + urllib3==2.6.3
 + wandb==0.25.1
Installed 2 executables: wandb, wb
wandb: Logging into https://api.wandb.ai.
wandb: Create a new API key at: https://wandb.ai/authorize?ref=models
wandb: Store your API key securely and do not share it.
wandb: Paste your API key and hit enter:
wandb: Appending key for api.wandb.ai to your netrc file: /home/mila/u/username/.netrc
wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin

Tip

uv tool install --upgrade wandb is a one-time step per cluster. The API key is stored in ~/.netrc after wandb login; subsequent jobs do not need to re-run wandb login unless the key is rotated.

Offline mode

On clusters without outbound internet access on compute nodes, run in offline mode and sync after the job completes:

export WANDB_MODE=offline
srun uv run --offline python main.py

After the job finishes, sync the run from the login node:

wandb sync --sync-all

Tip

On DRAC clusters, loading the httpproxy module before the srun line is an alternative to offline mode:

module load httpproxy
srun uv run python main.py

Initialize and log a training run

Every WandB run starts with wandb.init(). Call wandb.finish() at the end of the script — it is called automatically at exit, but an explicit call is safer in Slurm jobs.

Initialize a run

import os
import wandb

wandb.init(
    project="wandb-example",                # (1)!
    name=os.environ.get("SLURM_JOB_ID"),    # (2)!
    id=os.environ.get("SLURM_JOB_ID"),      # (3)!
    resume="allow",                         # (4)!
    config=vars(args)
    | {f"env/{k}": v for k, v in os.environ.items() if k.startswith("SLURM")},  # (5)!
)

  1. Groups runs in the WandB UI. Use one project per research question.
  2. Human-readable display name shown in the WandB Runs table.
  3. Sets the unique run ID. resume="allow" uses this to find and resume the run if it was preempted. Setting it to the Slurm job ID also links the run to its log file (slurm-<JOBID>.out).
  4. Creates a new run if the ID does not exist, or resumes it if it does. Useful when combined with checkpointing to recover from preemption.
  5. Pass the full argparse namespace and SLURM environment variables for easier debugging. Every key becomes a searchable, filterable column in the WandB Runs table under Config.
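Annotation 4 mentions combining resume="allow" with checkpointing. A minimal sketch of the checkpointing side, using a plain JSON file for the training state (the file name, location, and format here are illustrative assumptions; real training code would also save model and optimizer state, as shown in the checkpointing example):

```python
import json
import os
from pathlib import Path

# Hypothetical checkpoint location; $SCRATCH persists across jobs,
# unlike $SLURM_TMPDIR, which is wiped when the job ends.
ckpt_path = Path(os.environ.get("SCRATCH", ".")) / "checkpoint.json"


def load_checkpoint() -> dict:
    """Return the saved training state, or a fresh state if none exists."""
    if ckpt_path.exists():
        return json.loads(ckpt_path.read_text())
    return {"epoch": 0}


def save_checkpoint(state: dict) -> None:
    """Persist the training state so a resubmitted job can pick it up."""
    ckpt_path.write_text(json.dumps(state))


state = load_checkpoint()
start_epoch = state["epoch"]
# After wandb.init(id=..., resume="allow"), the loop starts at start_epoch
# and logging continues in the same WandB run instead of creating a new one.
```

With this in place, a preempted job resubmitted with the same Slurm job ID resumes both the training state and the WandB run.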

Log metrics

Call wandb.log() inside the batch loop to record training metrics at each step:

wandb.log(
    {
        "train/accuracy": accuracy,
        "train/loss": loss
    }
)

Log validation metrics once per epoch, after the validation loop:

wandb.log(
    {
        "val/accuracy": val_accuracy,
        "val/loss": val_loss,
        "epoch": epoch
    }
)

Tip

Prefix metric names with train/ and val/. WandB groups metrics with matching prefixes automatically in the Charts panel.

Complete example

See the WandB setup example for a complete single-GPU training script with WandB logging integrated.

Organize runs

Three arguments help keep runs organized as a project grows. name= and tags= make individual runs easy to identify and filter in the Runs table, while group= clusters related runs — such as multi-seed runs or ablations — under a single expandable row.

wandb.init(
    project="wandb-example",
    name=os.environ.get("SLURM_JOB_ID"),
    id=os.environ.get("SLURM_JOB_ID"),
    group=os.environ.get("SLURM_ARRAY_JOB_ID"),     # (1)!
    resume="allow",
    tags=["example", "resnet18"],                   # (2)!
    config=vars(args)
    | {f"env/{k}": v for k, v in os.environ.items() if k.startswith("SLURM")},
)

  1. Using SLURM_ARRAY_JOB_ID as the group automatically clusters all tasks of a job array into a single expandable row.
  2. Labels runs for filtering in the Runs table.
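For example, a multi-seed study can be launched as a Slurm job array (e.g. sbatch --array=0-4 job.sh), with each task deriving its own seed from the array task ID while SLURM_ARRAY_JOB_ID groups all tasks together. A sketch, assuming the base seed and the kwargs dict are illustrative choices (the wandb.init call itself is the same as above):

```python
import os

# SLURM_ARRAY_TASK_ID is "0", "1", ... within an array job;
# fall back to "0" when not running inside an array.
task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))
seed = 42 + task_id  # illustrative: one distinct seed per array task

init_kwargs = {
    "project": "wandb-example",
    "name": os.environ.get("SLURM_JOB_ID"),
    "id": os.environ.get("SLURM_JOB_ID"),
    # Same value for every task in the array, so all runs land in one group.
    "group": os.environ.get("SLURM_ARRAY_JOB_ID"),
    "resume": "allow",
}
# wandb.init(**init_kwargs, config={"seed": seed})
```

Each task then appears as its own run in the Runs table, collapsed under a single row named after the array job.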

Diagnose training bottlenecks

WandB automatically records GPU utilization, CPU usage, and memory under the System tab of every run; no extra code is required. These metrics are the first place to check when a training job is slower than expected.

Read system metrics

Open a run in the WandB UI and select the System tab. The GPU Utilization chart shows the fraction of time the GPU spent on active compute during each sampling interval.

Two patterns indicate different root causes:

  • Sustained utilization near 100% — the job is compute-bound. The GPU is the bottleneck; this is the expected state for well-configured training.
  • Low or oscillating utilization — the GPU idles while waiting for the next batch. The data pipeline cannot deliver batches fast enough; the job is I/O-bound.

Tip

Common fixes for an I/O bottleneck: increase num_workers in the DataLoader, enable pin_memory=True, or copy the dataset to $SLURM_TMPDIR before the job starts.
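To pin down where each step's time goes, step timing can be logged alongside the training metrics using the perf/ prefix convention from the key concepts below. A minimal sketch (the helper name and metric keys are illustrative, not part of the WandB API); the returned dict would be passed to wandb.log() inside the batch loop:

```python
import time


def timed_step(load_batch, compute):
    """Run one training step, timing data loading and compute separately."""
    t0 = time.perf_counter()
    batch = load_batch()      # e.g. next(dataloader_iter)
    t1 = time.perf_counter()
    result = compute(batch)   # forward/backward/optimizer step
    t2 = time.perf_counter()
    timings = {
        "perf/data_load_s": t1 - t0,
        "perf/compute_s": t2 - t1,
    }
    return result, timings


# Inside the batch loop:
#     loss, timings = timed_step(load_batch, compute)
#     wandb.log(timings | {"train/loss": loss})
# If perf/data_load_s dominates perf/compute_s, the job is I/O-bound.
```

Plotting perf/data_load_s against perf/compute_s in the WandB UI makes the bottleneck visible directly, without inferring it from GPU utilization alone.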

Full job scripts

The wandb_setup example provides a complete job script and training script with data staging and WandB integration:

job.sh
#!/bin/bash
#SBATCH --ntasks=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=4
#SBATCH --gpus-per-task=l40s:1
## Or --gpus-per-task=rtx8000:1    to request a different GPU model
## Or --gpus-per-task=1            for any GPU model
#SBATCH --mem-per-gpu=16G
#SBATCH --time=00:15:00

# Exit on error
set -e

# Echo time and hostname into log
echo "Date:     $(date)"
echo "Hostname: $(hostname)"

# To make your code as reproducible as possible with
# `torch.use_deterministic_algorithms(True)`, uncomment the following block:
## === Reproducibility ===
## Be warned that this can make your code slower. See
## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations
## for more details.
# export CUBLAS_WORKSPACE_CONFIG=:4096:8
## === Reproducibility (END) ===

# Stage dataset into $SLURM_TMPDIR
mkdir -p $SLURM_TMPDIR/data
cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
# General-purpose alternatives combining copy and unpack:
#     unzip   /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
#     tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/

## On DRAC or PAICE clusters you can load this module to log to wandb:
# module load httpproxy
## Otherwise you can also do this:
# export WANDB_MODE=offline

# Execute Python script
# Use the `--offline` option of `uv run` on clusters without internet access on compute nodes.
# Using the `--locked` option can help make your experiments easier to reproduce (it forces
# your uv.lock file to be up to date with the dependencies declared in pyproject.toml).
srun uv run python main.py
main.py
"""Example job that uses Weights & Biases (wandb.ai)."""

import argparse
import logging
import os
import random
import sys
from pathlib import Path

import numpy as np
import rich.logging
import torch
import wandb
from torch import Tensor, nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from tqdm import tqdm


# To make your code as reproducible as possible, uncomment the following
# block:
## === Reproducibility ===
## Be warned that this can make your code slower. See
## https://pytorch.org/docs/stable/notes/randomness.html#cublas-and-cudnn-deterministic-operations
## for more details.
# torch.use_deterministic_algorithms(True)
## === Reproducibility (END) ===


def main():
    # Use an argument parser so we can pass hyperparameters from the command line.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--learning-rate", type=float, default=5e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    epochs: int = args.epochs
    learning_rate: float = args.learning_rate
    weight_decay: float = args.weight_decay
    batch_size: int = args.batch_size
    seed: int = args.seed

    # Seed the random number generators as early as possible for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Check that the GPU is available
    assert torch.cuda.is_available() and torch.cuda.device_count() > 0
    device = torch.device("cuda", 0)

    # Setup logging (optional, but much better than using print statements)
    # Uses the `rich` package to make logs pretty.
    logging.basicConfig(
        level=logging.INFO,
        format="%(message)s",
        handlers=[
            rich.logging.RichHandler(
                markup=True,
                console=rich.console.Console(
                    # Allow wider log lines in sbatch output files than on the terminal.
                    width=120 if not sys.stdout.isatty() else None
                ),
            )
        ],
    )

    logger = logging.getLogger(__name__)

    # To properly resume experiments with WandB, the code must also handle
    # checkpointing (see the minimalist "checkpointing" example). Setting the
    # run `id` below lets WandB resume the run after preemption, but this
    # example does not save or restore checkpoints.

    # Setup Wandb
    wandb.init(
        # Set the project where this run will be logged
        project="wandb-example",
        name=os.environ.get("SLURM_JOB_ID"),
        id=os.environ.get("SLURM_JOB_ID"),
        group=os.environ.get("SLURM_ARRAY_JOB_ID"),
        resume="allow",  # See https://docs.wandb.ai/guides/runs/resuming
        tags=["example", "resnet18"],
        # Track hyperparameters and run metadata
        config=vars(args)
        | {f"env/{k}": v for k, v in os.environ.items() if k.startswith("SLURM")},
    )

    # Create a model and move it to the GPU.
    model = resnet18(num_classes=10)
    model.to(device=device)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    # Setup CIFAR10
    num_workers = get_num_workers()
    dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data"
    train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path))
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    _test_dataloader = DataLoader(  # NOTE: Not used in this example.
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

    # Checkout the "checkpointing and preemption" example for more info!
    logger.debug("Starting training from scratch.")

    for epoch in range(epochs):
        logger.debug(f"Starting epoch {epoch}/{epochs}")

        # Set the model in training mode (important for e.g. BatchNorm and Dropout layers)
        model.train()

        # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
        progress_bar = tqdm(
            total=len(train_dataloader),
            desc=f"Train epoch {epoch}",
            disable=not sys.stdout.isatty(),  # Disable progress bar in non-interactive environments.
        )

        # Training loop
        for batch in train_dataloader:
            # Move the batch to the GPU before we pass it to the model
            batch = tuple(item.to(device) for item in batch)
            x, y = batch

            # Forward pass
            logits: Tensor = model(x)

            loss = F.cross_entropy(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Calculate some metrics:
            n_correct_predictions = logits.detach().argmax(-1).eq(y).sum()
            n_samples = y.shape[0]
            accuracy = n_correct_predictions / n_samples

            logger.debug(f"Accuracy: {accuracy.item():.2%}")
            logger.debug(f"Average Loss: {loss.item()}")

            # Log metrics with wandb
            wandb.log(
                {
                    "train/accuracy": accuracy,
                    "train/loss": loss,
                }
            )

            # Advance the progress bar one step and update the progress bar text.
            progress_bar.update(1)
            progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
        progress_bar.close()

        val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
        logger.info(
            f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}"
        )
        wandb.log(
            {
                "val/accuracy": val_accuracy,
                "val/loss": val_loss,
                "epoch": epoch,
            }
        )

    wandb.finish()
    print("Done!")


@torch.no_grad()
def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
    model.eval()

    total_loss = 0.0
    n_samples = 0
    correct_predictions = 0

    for batch in dataloader:
        batch = tuple(item.to(device) for item in batch)
        x, y = batch

        logits: Tensor = model(x)
        loss = F.cross_entropy(logits, y)

        batch_n_samples = x.shape[0]
        batch_correct_predictions = logits.argmax(-1).eq(y).sum()

        total_loss += loss.item()
        n_samples += batch_n_samples
        correct_predictions += batch_correct_predictions

    accuracy = correct_predictions / n_samples
    # Average the per-batch losses so the value is comparable across
    # validation sets of different sizes.
    return total_loss / len(dataloader), accuracy


def make_datasets(
    dataset_path: str,
    val_split: float = 0.1,
    val_split_seed: int = 42,
):
    """Returns the training, validation, and test splits for CIFAR10.

    NOTE: We don't use image transforms here for simplicity.
    Having different transformations for train and validation would complicate things a bit.
    Later examples will show how to do the train/val/test split properly when using transforms.
    """
    train_dataset = CIFAR10(
        root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
    )
    test_dataset = CIFAR10(
        root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
    )
    # Split the training dataset into a training and validation set.
    n_samples = len(train_dataset)
    n_valid = int(val_split * n_samples)
    n_train = n_samples - n_valid
    train_dataset, valid_dataset = random_split(
        train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed)
    )
    return train_dataset, valid_dataset, test_dataset


def get_num_workers() -> int:
    """Gets the optimal number of DatLoader workers to use in the current job."""
    if "SLURM_CPUS_PER_TASK" in os.environ:
        return int(os.environ["SLURM_CPUS_PER_TASK"])
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return torch.multiprocessing.cpu_count()


if __name__ == "__main__":
    main()
pyproject.toml
[project]
name = "wandb-setup"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11,<3.14"
dependencies = [
    "rich>=14.0.0",
    "torch>=2.7.1",
    "torchvision>=0.22.1",
    "tqdm>=4.67.1",
    "wandb>=0.22.3",
]

Example run

Running these scripts produces a run like this wandb-example run.


Key concepts

wandb.init()
Starts a WandB run. The most commonly used arguments are project, name, id, group, config and resume.
config= (in wandb.init())
Stores a dictionary of hyperparameters alongside the run. Each key becomes a searchable, filterable column in the WandB Runs table. Pass config=vars(args) | {f"env/{k}": v for k, v in os.environ.items() if k.startswith("SLURM")} to capture the full argparse namespace and Slurm environment variables for easier debugging.
group= (in wandb.init())
Groups related runs under a single expandable row in the WandB Runs table. Useful for multi-seed runs or ablations. Pass SLURM_ARRAY_JOB_ID to group all tasks in a job array automatically.
wandb.log()
Records a dictionary of metric values at the current step. Call once per iteration or epoch.
WANDB_MODE
Controls logging mode. Set to offline on clusters without outbound internet access. Sync with wandb sync --sync-all after the job completes.
System metrics
GPU utilization, CPU usage, and memory stats collected automatically by WandB during a run. Visible under the System tab of a run in the WandB UI.
perf/ prefix
Convention for logging performance timing metrics (e.g., perf/data_load_s, perf/compute_s) separately from training metrics to support bottleneck diagnosis.
