Chapter 2.8 PyTorch 中的 Checkpoint:中断训练后如何继续

作者

Brench

发布于

2026-06-19

修改于

2026-06-19

上一节整理了一套后面会反复使用的训练模板。它负责把模型、数据、优化器和评估指标组织起来,让后续实验不用反复写同一批样板代码。

不过,这个模板还没有处理一个实际问题:训练过程中可能会中断。

比如,模型训练到一半,可能会遇到这些情况:

如果每次中断都要重新开始训练,时间成本会很高。因此,实际训练中通常会定期保存模型检查点(checkpoint)

Checkpoint 的作用不是单纯保存一个已经训练好的模型,而是保存当前训练现场。程序中断之后,可以重新创建模型和优化器,再把保存的状态加载回来,接着之前的进度继续训练。

本节用 MNIST 做一个小实验:先训练几轮,模拟程序突然崩溃,然后重新加载 checkpoint 继续训练,观察 checkpoint 如何恢复训练现场。

import signal
from pathlib import Path

import deep_learning
import deep_learning.trainingtools as dt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
import torchvision.datasets as datasets
import torchvision.transforms.v2 as v2
from torch import Tensor
from torchmetrics.classification import MulticlassAccuracy

print('PyTorch version:', torch.__version__)

2.8.1 为什么只保存模型参数还不够

前面我们已经见过 state_dict。对于一个 nn.Module,它的 state_dict 保存了模型里所有的参数和 buffer。例如,线性层的权重和偏置,BatchNorm 里的 running mean 和 running variance,都会出现在模型的 state_dict 里。

所以,如果我们只关心推理,可以只保存模型的 state_dict

checkpoint = model.state_dict()
torch.save(checkpoint, 'model.pt')

之后重新创建同样结构的模型,再加载 state_dict

model = MyModel()
model.load_state_dict(torch.load('model.pt'))

这适合已经完成训练、只想用模型做预测的场景。

但是,如果我们想从中断处继续训练,只保存模型参数通常还不够。因为训练现场里不只有模型,还有优化器。比如 Adam 优化器会维护一阶矩估计和二阶矩估计,带 momentum 的 SGD 会维护动量缓冲区。这些状态不会保存在 model.state_dict() 里,而是保存在 optimizer.state_dict() 里。

因此,恢复训练时,我们通常至少要保存:

  • 模型参数和 buffer;
  • 优化器内部状态;
  • 当前训练到第几个 epoch。

也就是:

checkpoint = {
    'epoch': epoch,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}

这样保存下来的 checkpoint 才更像一个“训练现场”。

2.8.2 训练一个简单的 MLP

为了演示 checkpoint 的用法,我们先写一个简单的 MLP 来训练 MNIST。

首先,和上一节一样,我们设置随机数种子,并获取当前可用的设备:

deep_learning.set_seed(42)
device = deep_learning.get_default_device()
print('Using device:', device)

然后,我们加载 MNIST 数据集,并划分出训练集和验证集:

root = deep_learning.get_data_root()
transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
ds_rng = torch.Generator().manual_seed(42)

train_ds = datasets.MNIST(root, train=True, transform=transform, download=True)
train_ds, val_ds = utils.random_split(train_ds, [50000, 10000], generator=ds_rng)
test_ds = datasets.MNIST(root, train=False, transform=transform, download=True)

train_dl = utils.DataLoader(train_ds, batch_size=64, shuffle=True)
val_dl = utils.DataLoader(val_ds, batch_size=128, shuffle=False)
test_dl = utils.DataLoader(test_ds, batch_size=128, shuffle=False)

这里依然给 random_split 单独传入了一个 Generator。这样数据划分不会受其他随机操作影响。

接着,我们定义一个很小的 MLP,演示 checkpoint 的保存和恢复:

class MLP(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

为了避免后面重复写初始化代码,我们再写一个工具函数,用来创建模型、损失函数、优化器和指标:

def create_training_objects(
    lr: float = 1e-3,
) -> tuple[
    nn.Module,
    nn.Module,
    optim.Optimizer,
    MulticlassAccuracy,
    MulticlassAccuracy,
]:
    model = MLP(num_classes=10).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    metric = MulticlassAccuracy(num_classes=10).to(device)
    val_metric = MulticlassAccuracy(num_classes=10).to(device)
    return model, loss_fn, optimizer, metric, val_metric

注意这里用的是 Adam。这样后面更容易说明为什么恢复训练时不应该只加载模型参数,还要加载优化器状态。

2.8.3 保存 checkpoint

现在开始写 checkpoint 相关的函数。保存 checkpoint 本质上就是把需要恢复的状态放进一个字典,然后用 torch.save 写到文件里。

def save_checkpoint(
    path: str | Path,
    epoch: int,
    model: nn.Module,
    optimizer: optim.Optimizer,
    history: list[dict[str, float]],
) -> None:
    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'history': history,
    }
    torch.save(checkpoint, path)

这里保存了四样东西:

  • epoch:当前已经训练完的 epoch;
  • model:模型参数和 buffer;
  • optimizer:优化器内部状态;
  • history:已经记录下来的训练日志。

其中最关键的是 modeloptimizer。如果只保存模型参数,恢复之后虽然模型权重是对的,但是优化器会像刚创建一样,从零开始积累自己的状态。对于 Adam 这类优化器来说,这会改变后续训练轨迹。

2.8.4 加载 checkpoint

加载 checkpoint 的过程和保存过程相反:先用 torch.load 读出字典,再分别把状态加载回模型和优化器。

def load_checkpoint(
    path: str | Path,
    model: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device | None = None,
) -> tuple[int, list[dict[str, float]]]:
    checkpoint = torch.load(
        path,
        map_location=device,
        weights_only=True,
    )

    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    epoch = checkpoint['epoch']
    history = checkpoint['history']
    return epoch, history

这里有两个小细节。

第一个是:

map_location=device

它的作用是把 checkpoint 里的张量加载到当前设备上。例如,checkpoint 可能是在 GPU 上保存的,但现在这台机器只有 CPU,那么 map_location 可以避免设备不匹配带来的问题。

第二个是:

weights_only=True

新版 PyTorch 更推荐用 weights_only=True 加载只包含张量和简单 Python 对象的 checkpoint。这样可以减少 pickle 反序列化带来的安全风险。我们的 checkpoint 只保存了模型状态、优化器状态、epoch 和 history,所以可以使用这个设置。如果 checkpoint 里保存了自定义类对象,那么 weights_only=True 可能无法加载。这也是为什么我们通常更推荐保存 state_dict,而不是直接保存整个模型对象。

2.8.5 第一次训练:模拟程序突然崩溃

现在我们先训练几个 epoch,并在中途保存 checkpoint。

checkpoint_path = Path(deep_learning.get_data_root()) / 'mnist-mlp-checkpoint.pt'
device = deep_learning.get_default_device()

model, loss_fn, optimizer, metric, val_metric = create_training_objects(lr=0.001)
history = []

假设我们本来打算训练 5 个 epoch,但训练到第 3 个 epoch 之后,程序突然崩溃了。我们可以在训练循环里用 signal.SIGINT 模拟这个崩溃:

num_epochs = 5
crash_after_epoch = 3

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = dt.train(
        model=model,
        dataloader=train_dl,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metric=metric,
        device=device,
    )
    val_loss, val_acc = dt.evaluate(
        model=model,
        dataloader=val_dl,
        loss_fn=loss_fn,
        metric=val_metric,
        device=device,
    )

    record = {
        'epoch': epoch,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
    }
    history.append(record)

    width = len(str(num_epochs))
    print(
        f'Epoch [{epoch:0{width}d}/{num_epochs:0{width}d}] '
        f'| train_loss: {train_loss:.4f} '
        f'| train_acc: {train_acc:.4f} '
        f'| val_loss: {val_loss:.4f} '
        f'| val_acc: {val_acc:.4f}'
    )

    save_checkpoint(
        checkpoint_path,
        epoch=epoch,
        model=model,
        optimizer=optimizer,
        history=history,
    )

    if epoch == crash_after_epoch:
        try:
            signal.raise_signal(signal.SIGINT)
        except KeyboardInterrupt:
            print('ERROR - Kernel died while waiting for execute reply.')
            break

这里每个 epoch 结束后都会保存一次 checkpoint。这样即使程序中断,我们也至少能恢复到最近一个完整 epoch 之后的状态。当然,在真实训练中,我们不会主动让程序崩溃。这里这样写只是为了演示 checkpoint 的作用。

2.8.6 重新创建模型和优化器

程序崩溃之后,内存里的 modeloptimizer 都没了。所以恢复训练时,第一步不是直接继续调用训练函数,而是重新创建一份同样结构的模型和优化器:

model, loss_fn, optimizer, metric, val_metric = create_training_objects(lr=0.001)

此时这个 model 是随机初始化的,optimizer 也是全新的。它们还没有恢复到崩溃前的状态。

接下来加载 checkpoint:

last_epoch, history = load_checkpoint(
    checkpoint_path,
    model=model,
    optimizer=optimizer,
    device=device,
)

print('Last finished epoch:', last_epoch)
print('Number of history records:', len(history))

加载之后,模型参数恢复到了第 last_epoch 个 epoch 结束时的状态,优化器状态也恢复了。所以继续训练时,应该从 last_epoch + 1 开始。

2.8.7 从 checkpoint 继续训练

现在继续完成剩下的 epoch:

for epoch in range(last_epoch + 1, num_epochs + 1):
    train_loss, train_acc = dt.train(
        model=model,
        dataloader=train_dl,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metric=metric,
        device=device,
    )
    val_loss, val_acc = dt.evaluate(
        model=model,
        dataloader=val_dl,
        loss_fn=loss_fn,
        metric=val_metric,
        device=device,
    )

    record = {
        'epoch': epoch,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
    }
    history.append(record)

    width = len(str(num_epochs))
    print(
        f'Epoch [{epoch:0{width}d}/{num_epochs:0{width}d}] '
        f'| train_loss: {train_loss:.4f} '
        f'| train_acc: {train_acc:.4f} '
        f'| val_loss: {val_loss:.4f} '
        f'| val_acc: {val_acc:.4f}'
    )

    save_checkpoint(
        checkpoint_path,
        epoch=epoch,
        model=model,
        optimizer=optimizer,
        history=history,
    )

这样,虽然中间模拟了一次程序崩溃,但最终训练仍然可以继续完成。

我们也可以在测试集上看一下最后的效果:

test_metric = MulticlassAccuracy(num_classes=10).to(device)
test_loss, test_acc = dt.evaluate(
    model=model,
    dataloader=test_dl,
    loss_fn=loss_fn,
    metric=test_metric,
    device=device,
)
model_name = type(model).__name__
print(f'[{model_name}] | test_loss: {test_loss:.4f} | test_acc: {test_acc:.4f}')

这就是 checkpoint 最基本的作用:

程序可以中断,但训练状态不一定丢失。

2.8.8 如果只加载模型,不加载优化器会怎样

为了更清楚地理解 optimizer state 的作用,我们可以想象另一种恢复方式:

model.load_state_dict(checkpoint['model'])

但是不执行:

optimizer.load_state_dict(checkpoint['optimizer'])

这样做不是完全错误。模型参数确实恢复了,代码也可以继续训练。但是优化器内部状态丢失了。

对于普通 SGD,如果没有 momentum,影响可能不大;但对于 Adam、AdamW、带 momentum 的 SGD,优化器内部会维护和过去梯度有关的状态。如果这些状态丢掉了,那么恢复之后的训练轨迹就不再等价于从原来的地方继续训练。

所以更准确地说:

  • 只恢复 model:恢复了模型参数,但优化器状态丢失了,训练轨迹会改变;
  • 恢复 model + optimizer:恢复了更完整的训练状态,训练轨迹更接近于从原来地方继续训练。

如果只是加载预训练模型,然后开始一个新的训练任务,通常只加载模型参数就够了。但如果目标是中断后接着训练,就应该一起保存和加载优化器状态。

2.8.9 还可以保存什么

这一节我们只保存了最基本的内容:

checkpoint = {
    'epoch': epoch,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'history': history,
}

在更完整的训练项目里,checkpoint 还可能保存:

  • Learning rate scheduler 的状态;
  • GradScaler 的状态,也就是自动混合精度训练中的缩放因子;
  • 当前 global step;
  • 当前最好的验证指标;
  • 实验配置,例如学习率、batch size、模型超参数;
  • 随机数生成器状态;
  • Dataloader 的状态。

不过,这些都属于更复杂的训练工程问题。对于后面多数示例来说,保存 modeloptimizerepoch 已经够用。尤其是 dataloader 状态,只有当我们希望在一个 epoch 中途恢复训练时才比较重要。这里采用更简单的策略:每个 epoch 结束后保存一次 checkpoint,所以恢复时从下一个 epoch 开始。

2.8.10 本章小结

这一节我们用 MNIST 演示了 checkpoint 的基本用法。

如果只是为了推理,通常只需要保存模型的 state_dict。因为推理只依赖模型参数,不依赖优化器状态。

但是,如果要在训练中断之后继续训练,只保存模型参数通常还不够。优化器本身也有状态,例如 Adam 的一阶矩和二阶矩估计、SGD 的 momentum buffer 等。因此,一个最小的训练 checkpoint 通常会包含:

checkpoint = {
    'epoch': epoch,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}

恢复训练时,我们需要先重新创建同样结构的模型和优化器,再分别加载它们的状态:

model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

最后,从 epoch + 1 开始继续训练。

到这里,PyTorch 基础部分已经形成了一条完整的训练链条:自动微分负责计算梯度,nn.Module 负责组织模型,loss function 负责定义优化目标,optimizer 负责更新参数,训练模板负责把这些组件连接起来,而 checkpoint 负责在训练中断时保存和恢复训练现场。

二次使用