Chapter 2.7 PyTorch 中的训练循环:把数据、模型和优化器连接起来

作者

Brench

发布于

2026-06-19

修改于

2026-06-19

前面几节里,我们已经分别认识了 PyTorch 训练中最重要的几个组件:

真正写实验代码时,我们不会每次都从零开始临时拼一遍训练循环。训练代码里有很多固定动作:设置随机种子、选择设备、把数据搬到设备上、切换训练和评估模式、统计 loss 和 metric。这些操作本身不复杂,但如果没有一套稳定模板,后续实验很容易在细节上出错。

因此,本节不再从”什么是训练”讲起,而是整理一套后面会反复使用的训练模板。它不是唯一正确的写法,但可以作为后续实验的默认约定。

import random

import deep_learning
import numpy as np
import torch
import torch.accelerator as accl
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
import torchvision.datasets as datasets
import torchvision.transforms.v2 as v2
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.classification import MulticlassAccuracy

print('PyTorch version:', torch.__version__)

2.7.1 为什么需要训练模板

一个最小的 PyTorch 训练步骤通常长这样:

logits = model(X)
loss = loss_fn(logits, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

这几行当然是核心。但是,只靠它们还不够。实际实验中,我们还需要处理很多边界问题。例如:

  • 每次运行实验,随机初始化和数据打乱能不能尽量一致?
  • 当前机器有 GPU、MPS 或其他加速器时,代码应该怎么自动选择设备?
  • 训练时怎么统计整个 epoch 的 loss,而不是只看最后一个 batch?
  • 验证时怎么关闭梯度记录,避免浪费显存?
  • Accuracy、F1 这类指标要怎么跨 batch 累积?

这些问题和模型结构本身没有关系,但几乎每个实验都会遇到。

因此,我们先把这些约定整理成一个固定模板。后面训练 MLP、CNN、Transformer 或其他模型时,只需要替换模型、数据集和指标,整体训练框架可以保持不变。

2.7.2 固定随机数种子

深度学习实验里有很多随机性来源。例如:

  • 模型参数的随机初始化;
  • 数据集的随机划分;
  • DataLoader 的随机打乱;
  • Dropout 等随机层;
  • 某些底层算子的非确定性实现。

如果不控制这些随机性,同一段代码运行两次,结果可能会有一些差异。这在做实验对比时会比较麻烦,因为我们很难判断结果变化到底来自模型改动,还是来自随机因素。

所以后面的实验里,我们通常先固定随机种子:

def set_seed(
    seed: int = 42,
    *,
    deterministic: bool = True,
    benchmark: bool = False,
    warn_only: bool = True,
) -> torch.Generator:
    random.seed(seed)
    np.random.seed(seed)
    torch_rng = torch.manual_seed(seed)

    torch.use_deterministic_algorithms(deterministic, warn_only=warn_only)
    # These two lines will explain in the future chapter
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = benchmark

    return torch_rng

这里分别设置了 Python、NumPy 和 PyTorch 的随机种子。

这一行:

torch.use_deterministic_algorithms(True, warn_only=True)

表示尽量使用确定性算法。这里设置 warn_only=True,是因为某些操作可能没有确定性实现。对于教学代码,我们用 warning 提醒我们就够了。如果需要严格确保确定性,可以把 warn_only 设置为 False,这样一旦遇到非确定性操作就直接报错。

不过要注意,固定随机种子并不意味着所有环境下结果都能完全一模一样。不同硬件、不同 PyTorch 版本、不同底层库,都可能带来细微差异。所以更准确地说,固定随机种子是为了让实验尽量可复现,而不是保证数学意义上的完全相同。

先调用一次:

torch_rng = set_seed(42)

2.7.3 选择计算设备

以前很多 PyTorch 教程会这样选择设备:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

这个写法很常见,也很好理解。但是它只考虑了 CUDA 和 CPU。

现在 PyTorch 提供了更统一的 torch.accelerator 接口。它的目标是把 CUDA、MPS、XPU、MTIA 等不同加速器统一到一个入口下面。这样我们写代码时,就不用只盯着 CUDA。

我们可以写一个小函数:

def get_default_device() -> torch.device:
    device = accl.current_accelerator(check_available=True)
    if device is not None:
        return device
    return torch.device('cpu')


device = get_default_device()
print('Using device:', device)

拿到设备之后,有一句话一定要记住:

模型在哪里,数据就要在哪里。

也就是说,如果模型在 GPU 上,输入数据也必须在同一个 GPU 上;如果模型在 CPU 上,输入数据也要在 CPU 上。否则 PyTorch 不知道应该在哪里完成计算,就会报错。

常见写法是:

model = model.to(device)
X = X.to(device)
y = y.to(device)

这一点后面的训练模板会直接在函数内部处理。

2.7.4 准备一个很小的分类任务

为了把注意力放在训练模板本身,这里直接使用 MNIST 数据集。

root = deep_learning.get_data_root()
transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
ds_rng = torch.Generator().manual_seed(42)

train_ds = datasets.MNIST(root, train=True, transform=transform, download=True)
train_ds, val_ds = utils.random_split(train_ds, [50000, 10000], generator=ds_rng)

train_dl = utils.DataLoader(train_ds, batch_size=64, shuffle=True)
val_dl = utils.DataLoader(val_ds, batch_size=128, shuffle=False)

这里有一个小细节:

generator=torch.Generator().manual_seed(42)

random_split 本身也有随机性。虽然我们前面已经用 torch.manual_seed() 固定了全局随机种子,但它依赖全局 RNG 当前走到了哪里。为了让数据划分不受前面模型初始化、随机张量、数据增强等操作影响,我们给它单独传入了一个 Generator

接着我们定义一个很小的 MLP:

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

模型最后输出 10 个 logit,对应 10 个数字。因为是多分类任务,所以损失函数使用 nn.CrossEntropyLoss

2.7.5 用 TorchMetrics 统计指标

训练时,loss 是优化器真正优化的目标。但是观察模型效果时,我们经常还需要别的指标,例如 accuracy、precision、recall、F1、AUROC 等。

最简单的 accuracy 当然可以自己写:

y_pred = logits.argmax(dim=1)
accuracy = (y_pred == y).float().mean()

但是这样写只适合非常简单的情况。实际任务中,指标可能需要跨 batch 累积,也可能需要在分布式训练时跨进程同步。这个时候,手写指标就容易变得麻烦。

TorchMetrics (Detlefsen 等 2022年) 就是用来解决这个问题的。它把指标组织成类似 nn.Module 的对象,并且提供统一的接口:

metric.update(...)   -> 把当前 batch 的结果累积到内部状态里
metric.compute()     -> 根据累积状态计算最终指标
metric.reset()       -> 清空状态,开始下一轮统计

例如,多分类 accuracy 可以这样定义:

metric = MulticlassAccuracy(num_classes=3).to(device)

TorchMetrics 的另一个重要特点是:metric 是有状态的。

比如 accuracy 内部会维护类似“预测正确的样本数”和“总样本数”这样的状态。每处理一个 batch,update() 就会把当前 batch 的统计量累积进去。最后调用 compute() 时,才得到整个 epoch 的 accuracy。

所以,使用 metric 时一定要记得在合适的时候 reset()。否则,这个 epoch 的统计结果可能会混进上一个 epoch 的状态。

2.7.6 单轮训练:train

现在我们把训练一个 epoch 的逻辑写成函数。

def train(
    model: nn.Module,
    dataloader: utils.DataLoader[tuple[Tensor, Tensor]],
    loss_fn: nn.Module,
    optimizer: optim.Optimizer,
    metric: Metric,
    device: torch.device,
) -> tuple[float, float]:
    model.train()
    metric.reset()

    total_loss = 0.0

    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)

        logits = model(X)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # We use `detach()` here to avoid keeping the computation graph
        # for the metric update.
        metric.update(logits.detach(), y)

    avg_loss = total_loss / len(dataloader)
    avg_metric = metric.compute().item()
    return avg_loss, avg_metric

这个函数里有几个地方值得注意。

首先,函数一开始调用:

model.train()

这是为了让模型进入训练模式。对于 LinearReLU,训练模式和评估模式没有区别。但对于 DropoutBatchNorm 这类模块,训练模式和评估模式的行为是不一样的。因此,把 model.train() 放进训练函数内部会更保险。

其次,每个 batch 里都要先把数据移动到前面拿到的设备上:

X = X.to(device)
y = y.to(device)

这样可以保证输入和模型在同一个设备,避免后续计算出错。

然后是最核心的三行:

optimizer.zero_grad()
loss.backward()
optimizer.step()

前面已经讲过,PyTorch 中梯度默认会累积到参数的 .grad 上,所以每次反向传播之前都要先清空旧梯度。否则,当前 batch 的梯度会和上一个 batch 的梯度混在一起。

最后,统计 loss 时我们写的是:

total_loss += loss.item()

这里对 total_loss 按 batch 做了简单平均。这个写法有一个近似:最后一个 batch 的样本数可能和前面不一样。比如一共有 100 个训练样本,batch size 是 32,那么前三个 batch 是 32 个样本,最后一个 batch 只有 4 个样本。直接平均会给最后一个 batch 过高权重。更严谨的做法是把每个 batch 的 loss 乘以当前 batch 的样本数,累积到 total_loss 里,最后再除以总样本数,或者在使用 DataLoader 加载数据集时设置 drop_last=True。这里先采用简单平均,保持模板简洁。

2.7.7 单轮验证:evaluate

验证循环和训练循环很像,但有三个关键区别:

  • 使用 model.eval(),让模型进入评估模式;
  • 不调用 backward()optimizer.step(),因为验证不更新参数;
  • 使用 torch.inference_mode(),避免记录计算图。
def evaluate(
    model: nn.Module,
    dataloader: utils.DataLoader[tuple[Tensor, Tensor]],
    loss_fn: nn.Module,
    metric: Metric,
    device: torch.device,
) -> tuple[float, float]:
    model.eval()
    metric.reset()

    total_loss = 0.0

    with torch.inference_mode():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)

            logits = model(X)
            loss = loss_fn(logits, y)

            total_loss += loss.item()
            metric.update(logits, y)

    avg_loss = total_loss / len(dataloader)
    avg_metric = metric.compute().item()
    return avg_loss, avg_metric

这里使用的是 torch.inference_mode(),因为我们确保后续不再需要梯度。

2.7.8 把训练和验证连起来

最后,我们写一个简单的 train_and_evaluate 函数,把多轮训练和验证连起来。

def train_and_evaluate(
    model: nn.Module,
    train_dl: utils.DataLoader[tuple[Tensor, Tensor]],
    val_dl: utils.DataLoader[tuple[Tensor, Tensor]],
    loss_fn: nn.Module,
    optimizer: optim.Optimizer,
    metric: Metric,
    val_metric: Metric,
    num_epochs: int,
    device: torch.device | None = None,
) -> None:
    if device is None:
        device = get_default_device()

    model.to(device)
    metric.to(device)
    val_metric.to(device)

    for epoch in range(1, num_epochs + 1):
        loss, score = train(
            model=model,
            dataloader=train_dl,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metric=metric,
            device=device,
        )
        val_loss, val_score = evaluate(
            model=model,
            dataloader=val_dl,
            loss_fn=loss_fn,
            metric=val_metric,
            device=device,
        )

        w = len(str(num_epochs))
        print(
            f'Epoch [{epoch:{w}d}/{num_epochs:{w}d}] '
            f'| loss: {loss:.4f} '
            f'| metric: {score:.4f} '
            f'| val_loss: {val_loss:.4f} '
            f'| val_metric: {val_score:.4f}'
        )

这里我给训练和验证分别传入了两个 metric:

metric: Metric
val_metric: Metric

这是因为 metric 有自己的内部状态。虽然我们每次都会调用 reset(),但训练和验证本来就是两段不同的统计过程,把它们分开会更清晰。

现在运行一下完整模板:

torch_rng = set_seed(42)
device = get_default_device()

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

metric = MulticlassAccuracy(num_classes=10)
val_metric = MulticlassAccuracy(num_classes=10)

train_and_evaluate(
    model=model,
    train_dl=train_dl,
    val_dl=val_dl,
    loss_fn=loss_fn,
    optimizer=optimizer,
    metric=metric,
    val_metric=val_metric,
    num_epochs=5,
    device=device,
)

如果一切正常,我们应该能看到训练 loss 下降,accuracy 上升。

这就是后面很多实验会使用的基本结构。我们不会每次都重新解释它,但大体流程会保持一致:

  1. 先设置随机种子,选择设备;
  2. 定义模型、损失函数、优化器和 metric;
  3. 调用 train_and_evaluate() 进行训练和验证。

2.7.9 这套模板没有处理什么

这一节的模板故意保持简单。它已经覆盖了单机、单设备、普通监督学习任务里最常见的结构,但还没有处理更复杂的训练工程问题,例如:

  • 自动混合精度训练;
  • 梯度裁剪;
  • 学习率调度器;
  • Checkpoint 保存与恢复;
  • 多 GPU / 多进程训练;
  • Profiling 与性能优化;
  • torch.compile 编译优化。

这些内容都很重要,但它们不是最小训练模板的一部分。

下一节我们会先讨论 checkpoint:如果训练到一半中断了,应该保存哪些状态,才能继续接着训练。更高级的训练工程内容,可以等模型结构更复杂以后再展开。

2.7.10 本章小结

这一节我们没有从零讲如何训练一个模型,而是整理了一套后面会反复使用的 PyTorch 训练模板。

首先,我们用 set_seed 统一控制随机性,并用 torch.use_deterministic_algorithms 尽量减少非确定性来源。然后,我们用 torch.accelerator 选择当前可用的加速器,让代码不只局限于 CUDA。接着,我们介绍了 torchmetrics,把 accuracy 这类评价指标也当成有状态对象来管理。

在训练函数中,核心流程仍然是 zero_grad()backward()step()。不过,完整训练代码还需要处理模型模式、设备迁移、loss 加权平均和 metric 状态重置。验证函数则需要切换到 eval 模式,并用 torch.inference_mode() 关闭不必要的梯度记录。

理解这套模板之后,后面再训练更复杂的模型时,我们就不用反复纠结训练循环本身,而可以把重点放在模型结构、损失函数和实验设计上。

参考文献

Detlefsen, Nicki Skafte, Jiri Borovec, Justus Schock, 等. 2022年. 《TorchMetrics - Measuring Reproducibility in PyTorch》. Journal of Open Source Software 7 (70): 4101. https://doi.org/10.21105/joss.04101.

二次使用