import random
import deep_learning
import numpy as np
import torch
import torch.accelerator as accl
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
import torchvision.datasets as datasets
import torchvision.transforms.v2 as v2
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.classification import MulticlassAccuracy
print('PyTorch version:', torch.__version__)
Chapter 2.7 Training Loop in PyTorch: Connecting Data, Model, and Optimizer
The previous sections introduced several core components in PyTorch training:
- nn.Module organizes model parameters and forward computation;
- a loss function turns predictions and targets into a scalar loss;
- an optimizer updates parameters according to gradients;
- state_dict saves the state of modules and optimizers.
When writing experiment code, we usually do not rebuild a training loop from scratch every time. Training code contains many fixed operations: setting random seeds, selecting a device, moving data to that device, switching between training and evaluation modes, and tracking loss and metrics. None of these operations is complicated by itself, but without a stable template, later experiments can easily fail on small details.
This section therefore does not start from “what is training”. Instead, it builds a training template that will be reused in later chapters. It is not the only correct implementation, but it gives us a default convention for subsequent experiments.
2.7.1 Why a Training Template Is Needed
A minimal PyTorch training step usually looks like this:
logits = model(X)
loss = loss_fn(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
These lines are the core. However, they are not enough for a real experiment. We also need to handle several boundary issues:
- Can random initialization and data shuffling remain as consistent as possible across runs?
- If the current machine has a GPU, MPS, or another accelerator, how should the code select the device automatically?
- During training, how should we compute loss for the whole epoch instead of only looking at the last batch?
- During validation, how should we disable gradient recording to avoid wasting memory?
- How should metrics such as accuracy and F1 be accumulated across batches?
These questions are not tied to the model architecture, but almost every experiment has to deal with them.
We first organize these conventions into a fixed template. Later, when training an MLP, CNN, Transformer, or another model, we can replace the model, dataset, and metrics while keeping the overall training framework unchanged.
2.7.2 Fixing Random Seeds
Deep learning experiments have many sources of randomness:
- random initialization of model parameters;
- random dataset splitting;
- random shuffling in DataLoader;
- random layers such as Dropout;
- nondeterministic implementations of some low-level operators.
If these sources are not controlled, running the same code twice may produce slightly different results. This makes experimental comparison harder, because it becomes unclear whether the change comes from the model modification or from randomness.
In later experiments, we usually fix the random seed first:
def set_seed(
    seed: int = 42,
    *,
    deterministic: bool = True,
    benchmark: bool = False,
    warn_only: bool = True,
) -> torch.Generator:
    # Seed the Python, NumPy, and PyTorch global random number generators.
    random.seed(seed)
    np.random.seed(seed)
    torch_rng = torch.manual_seed(seed)
    torch.use_deterministic_algorithms(deterministic, warn_only=warn_only)
    # These two cuDNN flags will be explained in a later chapter.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = benchmark
    return torch_rng
This sets the random seeds for Python, NumPy, and PyTorch.
The line:
torch.use_deterministic_algorithms(True, warn_only=True)
means that PyTorch should use deterministic algorithms whenever possible. Here, warn_only=True is used because some operations may not have deterministic implementations. For teaching code, a warning is enough. If strict determinism is required, warn_only can be set to False, so PyTorch raises an error as soon as it encounters a nondeterministic operation.
One caveat remains: fixing random seeds does not guarantee bit-for-bit identical results in every environment. Different hardware, PyTorch versions, and low-level libraries may still introduce small differences. More precisely, fixing random seeds makes experiments as reproducible as practical, rather than mathematically identical in every setting.
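As a minimal illustration of what seeding does (a sketch that is separate from the set_seed helper; the function name draw is hypothetical), re-seeding with the same value reproduces the same draws on the same machine and library versions:

```python
import random

import torch


def draw(seed: int) -> tuple[float, torch.Tensor]:
    """Seed Python and PyTorch, then draw one value from each RNG."""
    random.seed(seed)
    torch.manual_seed(seed)
    return random.random(), torch.rand(3)


py_a, t_a = draw(42)
py_b, t_b = draw(42)  # same seed: the draws repeat
py_c, t_c = draw(43)  # different seed: the draws change

print(py_a == py_b, torch.equal(t_a, t_b))
print(py_a == py_c)
```

The comparison only covers a single process on one machine, which is exactly the scope in which a fixed seed can be relied on.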
Call it once first:
torch_rng = set_seed(42)
2.7.3 Selecting the Compute Device
Many older PyTorch tutorials select the device like this:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
This pattern is common and easy to understand, but it only considers CUDA and CPU.
PyTorch now provides a more unified torch.accelerator interface. Its goal is to expose CUDA, MPS, XPU, MTIA, and other accelerators through one entry point. With this interface, the code does not need to focus only on CUDA.
We can write a small helper:
def get_default_device() -> torch.device:
    device = accl.current_accelerator(check_available=True)
    if device is not None:
        return device
    return torch.device('cpu')

device = get_default_device()
print('Using device:', device)
After obtaining the device, one rule is worth remembering:
Data should live where the model lives.
If the model is on a GPU, the input data must also be on the same GPU. If the model is on the CPU, the input data should also be on the CPU. PyTorch does not move tensors between devices implicitly, so an operation that mixes devices typically fails with a RuntimeError.
The common pattern is:
model = model.to(device)
X = X.to(device)
y = y.to(device)
The training template below will handle this inside the functions.
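On PyTorch builds where torch.accelerator is not available, a fallback along the following lines could be used. This is only a sketch under assumptions: pick_device is a hypothetical name, and the fallback checks only CUDA and MPS before settling on the CPU.

```python
import torch
import torch.nn as nn


def pick_device() -> torch.device:
    """Hypothetical fallback: torch.accelerator first, then CUDA, MPS, CPU."""
    accel = getattr(torch, "accelerator", None)
    if accel is not None and accel.is_available():
        device = accel.current_accelerator()
        if device is not None:
            return device
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
model = nn.Linear(4, 2).to(device)
X = torch.randn(8, 4).to(device)

# Parameters and inputs share a device, so the forward pass is valid.
out = model(X)
print(device, out.device, out.shape)
```

Comparing device types rather than full device objects is deliberate: a device requested as "cuda" is reported by tensors with an explicit index such as "cuda:0".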
2.7.4 Preparing a Small Classification Task
To keep the focus on the training template itself, we use the MNIST dataset directly.
root = deep_learning.get_data_root()
transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
ds_rng = torch.Generator().manual_seed(42)
train_ds = datasets.MNIST(root, train=True, transform=transform, download=True)
train_ds, val_ds = utils.random_split(train_ds, [50000, 10000], generator=ds_rng)
train_dl = utils.DataLoader(train_ds, batch_size=64, shuffle=True)
val_dl = utils.DataLoader(val_ds, batch_size=128, shuffle=False)
There is a small detail here:
generator=torch.Generator().manual_seed(42)
random_split itself is random. Although torch.manual_seed() has already fixed the global random seed, the numbers drawn from the global RNG depend on how far it has already been advanced. To keep the data split independent of earlier model initialization, random tensors, or data augmentation, we pass a separate Generator to random_split.
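The effect can be checked on a toy dataset. The sketch below (split_indices and the ten-element dataset are illustrative assumptions, not part of the MNIST setup) consumes a different amount of the global RNG before each split and still obtains the same indices, because the split draws from its own Generator:

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))


def split_indices(num_prior_draws: int) -> list[int]:
    """Split after consuming a varying amount of the global RNG."""
    torch.manual_seed(0)
    torch.rand(num_prior_draws)  # stands in for model init, augmentation, ...
    split_rng = torch.Generator().manual_seed(42)
    train_subset, _ = random_split(dataset, [8, 2], generator=split_rng)
    return list(train_subset.indices)


# The dedicated generator makes the split independent of earlier draws.
print(split_indices(0) == split_indices(5))
```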
Next, define a small MLP:
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
The model outputs 10 logits, one for each digit. Since this is a multiclass classification task, the loss function is nn.CrossEntropyLoss.
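A quick shape check on a dummy batch can catch wiring mistakes before any training starts. The following is a minimal sketch with made-up inputs; it also shows that nn.CrossEntropyLoss takes raw logits together with integer class labels:

```python
import math

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
loss_fn = nn.CrossEntropyLoss()

X = torch.zeros(4, 1, 28, 28)   # dummy batch shaped like MNIST images
y = torch.tensor([0, 3, 7, 9])  # integer class labels, not one-hot vectors

logits = model(X)
print(logits.shape)  # one row of 10 logits per image

# With all-zero logits every class gets probability 1/10,
# so the cross-entropy equals ln(10).
uniform_loss = loss_fn(torch.zeros(4, 10), y)
print(uniform_loss.item(), math.log(10))
```

The value ln(10), about 2.30, is a useful reference point: a loss that stays near it suggests the model is still predicting close to uniformly.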
2.7.5 Tracking Metrics with TorchMetrics
During training, loss is the objective actually optimized by the optimizer. When observing model behavior, we often need other metrics as well, such as accuracy, precision, recall, F1, and AUROC.
The simplest accuracy can be written by hand:
y_pred = logits.argmax(dim=1)
accuracy = (y_pred == y).float().mean()
This is fine only for very simple cases. In real tasks, metrics may need to accumulate state across batches and may also need synchronization across processes in distributed training. At that point, manually implemented metrics become easy to get wrong.
TorchMetrics (Detlefsen et al. 2022) is designed for this problem. It organizes metrics as objects similar to nn.Module and provides a unified interface:
metric.update(...) -> accumulate results from the current batch
metric.compute() -> compute the final metric from accumulated state
metric.reset() -> clear state before the next round
For example, multiclass accuracy can be defined as:
metric = MulticlassAccuracy(num_classes=10).to(device)
Another important property of TorchMetrics is that a metric is stateful.
For example, accuracy internally maintains state similar to “number of correct predictions” and “total number of samples”. Each time a batch is processed, update() accumulates the current batch statistics. Only when compute() is called do we get the accuracy for the whole epoch.
Therefore, remember to call reset() at the right time. Otherwise, statistics from the current epoch may be mixed with state left over from the previous epoch.
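The update/compute/reset cycle can be sketched in plain Python. RunningAccuracy below is a hypothetical stand-in written for illustration, not TorchMetrics code. It tracks plain count-based accuracy, whereas the library's MulticlassAccuracy averages per-class accuracy by default (average='macro', per its documentation), so the two numbers can differ.

```python
class RunningAccuracy:
    """Hypothetical stand-in for a stateful metric (not TorchMetrics code)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: list[int], targets: list[int]) -> None:
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.correct / self.total


batches = [([0, 1, 2], [0, 1, 1]), ([1], [1])]  # (predictions, targets)

metric = RunningAccuracy()
batch_accuracies = []
for preds, targets in batches:
    metric.update(preds, targets)
    hits = sum(p == t for p, t in zip(preds, targets))
    batch_accuracies.append(hits / len(targets))

epoch_accuracy = metric.compute()                            # 3 correct out of 4
mean_of_batches = sum(batch_accuracies) / len(batch_accuracies)  # mean of 2/3 and 1
print(epoch_accuracy, mean_of_batches)

metric.reset()  # start the next epoch from empty state
```

The two printed values differ because the batches have unequal sizes, which is why accumulating counts is safer than averaging per-batch results.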
2.7.6 One Training Epoch: train
Now write the logic for training one epoch as a function.
def train(
    model: nn.Module,
    dataloader: utils.DataLoader[tuple[Tensor, Tensor]],
    loss_fn: nn.Module,
    optimizer: optim.Optimizer,
    metric: Metric,
    device: torch.device,
) -> tuple[float, float]:
    model.train()
    metric.reset()
    total_loss = 0.0
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        logits = model(X)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Detach the logits so that the metric update does not keep
        # the computation graph alive.
        metric.update(logits.detach(), y)
    avg_loss = total_loss / len(dataloader)
    avg_metric = metric.compute().item()
    return avg_loss, avg_metric
Several details in this function are worth noting.
First, the function starts with:
model.train()
This switches the model to training mode. For Linear and ReLU, training mode and evaluation mode behave the same. For modules such as Dropout and BatchNorm, however, the two modes behave differently. Placing model.train() inside the training function is therefore safer.
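The difference between the two modes is easy to observe on a single Dropout layer. This is a small sketch with an arbitrary input of ones, not part of the template:

```python
import torch
import torch.nn as nn

layer = nn.Dropout(p=0.5)
x = torch.ones(1000)

layer.train()
train_out = layer(x)  # entries are either zeroed or scaled by 1 / (1 - p) = 2.0

layer.eval()
eval_out = layer(x)   # dropout is the identity in evaluation mode

print(sorted(train_out.unique().tolist()))
print(torch.equal(eval_out, x))
```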
Second, each batch is moved to the device selected earlier:
X = X.to(device)
y = y.to(device)
This ensures that the inputs and the model are on the same device, avoiding device mismatch errors during computation.
Then come the three core lines:
optimizer.zero_grad()
loss.backward()
optimizer.step()
As discussed earlier, gradients in PyTorch are accumulated into each parameter’s .grad field by default. Therefore, old gradients must be cleared before each backward pass. Otherwise, the gradient from the current batch will be mixed with the gradient from the previous batch.
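The accumulation behavior can be seen with a single scalar parameter. In this sketch the gradient is cleared by hand with w.grad = None, which mirrors what zero_grad() does by default in recent PyTorch versions:

```python
import torch

w = torch.tensor(1.0, requires_grad=True)

(3.0 * w).backward()
first = w.grad.item()   # d(3w)/dw = 3

(3.0 * w).backward()    # no clearing in between: the gradients add up
second = w.grad.item()

w.grad = None           # clear the stale gradient
(3.0 * w).backward()
third = w.grad.item()

print(first, second, third)  # 3.0 6.0 3.0
```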
Finally, loss is accumulated with:
total_loss += loss.item()
Here, total_loss sums the mean loss of each batch, and dividing it by len(dataloader) at the end averages over batches. This is an approximation: the last batch may contain fewer samples than the previous batches. For example, if there are 100 training samples and the batch size is 32, the first three batches contain 32 samples each, while the last batch contains only 4 samples. Averaging over batches then gives each sample in the last batch more weight than a sample in a full batch. A stricter implementation would multiply each batch loss by its batch size, accumulate that value, and divide by the total number of samples at the end. Another option is to set drop_last=True when constructing the DataLoader. Here we keep the simple average to keep the template concise.
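The size of the discrepancy can be worked out with plain arithmetic. The batch losses below are hypothetical numbers chosen so that the small last batch has a noticeably higher loss:

```python
batch_sizes = [32, 32, 32, 4]        # 100 samples with batch size 32
batch_losses = [0.5, 0.5, 0.5, 2.0]  # hypothetical per-batch mean losses

# Simple average over batches: every batch counts equally.
batch_avg = sum(batch_losses) / len(batch_losses)

# Sample-weighted average: every sample counts equally.
weighted_sum = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
sample_avg = weighted_sum / sum(batch_sizes)

print(batch_avg, sample_avg)  # 0.875 0.56
```

In the training function, the weighted variant would accumulate loss.item() * X.size(0) and divide by the number of samples seen; with MNIST-sized epochs the two averages are usually very close.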
2.7.7 One Validation Epoch: evaluate
The validation loop is similar to the training loop, with three key differences:
- use model.eval() to switch the model to evaluation mode;
- do not call backward() or optimizer.step(), because validation does not update parameters;
- use torch.inference_mode() to avoid recording the computation graph.
def evaluate(
    model: nn.Module,
    dataloader: utils.DataLoader[tuple[Tensor, Tensor]],
    loss_fn: nn.Module,
    metric: Metric,
    device: torch.device,
) -> tuple[float, float]:
    model.eval()
    metric.reset()
    total_loss = 0.0
    # No graph is recorded inside this context, which saves memory.
    with torch.inference_mode():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            logits = model(X)
            loss = loss_fn(logits, y)
            total_loss += loss.item()
            metric.update(logits, y)
    avg_loss = total_loss / len(dataloader)
    avg_metric = metric.compute().item()
    return avg_loss, avg_metric
This function uses torch.inference_mode() because the outputs are never used for backpropagation, so no gradients are needed afterward.
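The effect of the context manager is visible on the output tensors themselves. A minimal sketch with a throwaway linear layer:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
X = torch.randn(8, 4)

tracked = model(X)  # ordinary forward pass: autograd records the graph

with torch.inference_mode():
    untracked = model(X)  # no graph is recorded inside the context

print(tracked.requires_grad, untracked.requires_grad)
print(torch.allclose(tracked, untracked))
```

The values agree, but only the first output carries a grad_fn, which is the bookkeeping that validation does not need.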
2.7.8 Connecting Training and Validation
Finally, write a small train_and_evaluate function that connects multi-epoch training and validation.
def train_and_evaluate(
    model: nn.Module,
    train_dl: utils.DataLoader[tuple[Tensor, Tensor]],
    val_dl: utils.DataLoader[tuple[Tensor, Tensor]],
    loss_fn: nn.Module,
    optimizer: optim.Optimizer,
    metric: Metric,
    val_metric: Metric,
    num_epochs: int,
    device: torch.device | None = None,
) -> None:
    if device is None:
        device = get_default_device()
    model.to(device)
    metric.to(device)
    val_metric.to(device)
    for epoch in range(1, num_epochs + 1):
        loss, score = train(
            model=model,
            dataloader=train_dl,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metric=metric,
            device=device,
        )
        val_loss, val_score = evaluate(
            model=model,
            dataloader=val_dl,
            loss_fn=loss_fn,
            metric=val_metric,
            device=device,
        )
        w = len(str(num_epochs))
        print(
            f'Epoch [{epoch:{w}d}/{num_epochs:{w}d}] '
            f'| loss: {loss:.4f} '
            f'| metric: {score:.4f} '
            f'| val_loss: {val_loss:.4f} '
            f'| val_metric: {val_score:.4f}'
        )
The training and validation loops receive two separate metric objects:
metric: Metric
val_metric: Metric
This is because metrics have their own internal state. Although reset() is called each time, training and validation are two different accumulation processes. Keeping them separate makes the logic clearer.
Now run the complete template:
torch_rng = set_seed(42)
device = get_default_device()
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
metric = MulticlassAccuracy(num_classes=10)
val_metric = MulticlassAccuracy(num_classes=10)
train_and_evaluate(
    model=model,
    train_dl=train_dl,
    val_dl=val_dl,
    loss_fn=loss_fn,
    optimizer=optimizer,
    metric=metric,
    val_metric=val_metric,
    num_epochs=5,
    device=device,
)
If everything works, the training loss should decrease and accuracy should increase.
This is the basic structure used by many later experiments. We will not explain it from scratch each time, but the overall flow will remain the same:
- Set the random seed and select the device;
- Define the model, loss function, optimizer, and metric;
- Call train_and_evaluate() for training and validation.
2.7.9 What This Template Does Not Handle
This template is intentionally simple. It covers the most common structure for single-machine, single-device supervised learning tasks, but it does not handle more complex training engineering topics, such as:
- automatic mixed precision training;
- gradient clipping;
- learning rate schedulers;
- checkpoint saving and resuming;
- multi-GPU or multi-process training;
- profiling and performance optimization;
- torch.compile compilation optimization.
All of these topics are important, but they are not part of the minimal training template.
The next section will discuss checkpoints first: if training is interrupted halfway, which states should be saved so that training can resume correctly. More advanced training engineering topics can be introduced later, after the model architectures become more complex.
2.7.10 Summary
This section did not explain how to train a model from scratch. Instead, it organized a PyTorch training template that will be reused later.
First, set_seed controls randomness, and torch.use_deterministic_algorithms reduces nondeterministic behavior where possible. Then, torch.accelerator selects the available accelerator without limiting the code to CUDA. Next, TorchMetrics treats metrics such as accuracy as stateful objects.
Inside the training function, the core flow remains zero_grad(), backward(), and step(). A complete training loop also needs to handle model mode, device transfer, loss averaging, and metric state reset. The validation function switches to eval mode and uses torch.inference_mode() to disable unnecessary gradient recording.
Once this template is clear, later experiments can focus less on the training loop itself and more on model architecture, loss functions, and experimental design.