Chapter 2.8 Checkpoints in PyTorch: Resuming Training After Interruption

Author: Brench
Published: 2026-06-19
Modified: 2026-06-19

The previous section built a reusable training template. It connects the model, data, optimizer, and evaluation metrics, so later experiments do not need to repeat the same boilerplate code.

However, that template does not yet handle a practical issue: training can be interrupted.

For example, training may stop halfway because the program or the notebook kernel crashes.

Restarting from scratch after every interruption is expensive. In real training runs, we therefore save model checkpoints at regular intervals.

A checkpoint is more than a saved trained model: it is a saved training state. After the program is interrupted, we can recreate the model and optimizer, load the saved state back into them, and continue training from where it stopped.

This section uses MNIST for a small experiment: train for a few epochs, simulate a program crash, then reload the checkpoint and continue training. The goal is to see how checkpoints recover the training state.

import signal
from pathlib import Path

import deep_learning
import deep_learning.trainingtools as dt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
import torchvision.datasets as datasets
import torchvision.transforms.v2 as v2
from torch import Tensor
from torchmetrics.classification import MulticlassAccuracy

print('PyTorch version:', torch.__version__)

2.8.1 Why Saving Only Model Parameters Is Not Enough

We have already seen state_dict. For an nn.Module, state_dict stores all parameters and buffers inside the model. For example, the weight and bias of a linear layer, and the running mean and running variance in BatchNorm, all appear in the model’s state_dict.
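A quick way to see this is to list the entries of a small model's state_dict. The sketch below uses a throwaway nn.Sequential (not a model from this chapter) with one Linear layer and one BatchNorm1d layer, and compares the keys with the model's parameters and buffers:

```python
import torch.nn as nn

# Throwaway model: one Linear layer followed by one BatchNorm1d layer
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

# Each state_dict entry maps a dotted name to a tensor
for name, tensor in model.state_dict().items():
    print(f'{name:22s} {tuple(tensor.shape)}')

# Parameters are updated by the optimizer; buffers are updated in forward()
param_names = {name for name, _ in model.named_parameters()}
buffer_names = {name for name, _ in model.named_buffers()}
print('parameters:', sorted(param_names))
print('buffers:   ', sorted(buffer_names))
```

The buffers are BatchNorm's running_mean, running_var, and num_batches_tracked. They receive no gradients, but they are still part of the model's state and are saved with it.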

If we only care about inference, saving the model’s state_dict is enough:

checkpoint = model.state_dict()
torch.save(checkpoint, 'model.pt')

Later, recreate the same model architecture and load the state_dict:

model = MyModel()
model.load_state_dict(torch.load('model.pt'))

This is suitable when training is finished and the model is only needed for prediction.
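The round trip can be checked directly. The sketch below uses a small hypothetical TinyNet in place of MyModel and an in-memory io.BytesIO buffer in place of model.pt. After loading, the restored copy should produce exactly the same predictions as the original:

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for MyModel."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


torch.manual_seed(0)
trained = TinyNet()  # pretend this model has finished training

# Save only the state_dict (to memory instead of 'model.pt')
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Recreate the same architecture, then load the weights into it
restored = TinyNet()
restored.load_state_dict(torch.load(buffer, weights_only=True))
restored.eval()  # inference mode for layers such as Dropout and BatchNorm

x = torch.randn(2, 4)
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))
print(same_output)  # True
```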

However, if we want to resume training from an interruption, saving only model parameters is usually not enough. The training state contains not only the model, but also the optimizer. For example, Adam maintains first-moment and second-moment estimates, and SGD with momentum maintains momentum buffers. These states are not stored in model.state_dict(). They are stored in optimizer.state_dict().
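The separation is easy to observe. In this minimal sketch (a throwaway nn.Linear and one step on random data), Adam's moment estimates appear in optimizer.state_dict() only after the first step, and they never appear in model.state_dict():

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Adam creates its per-parameter state lazily, on the first step
state_before = optimizer.state_dict()['state']
print(state_before)  # {}

# One training step on random data
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

state_after = optimizer.state_dict()['state']
# Parameters are referred to by integer index: 0 is weight, 1 is bias
print(sorted(state_after[0].keys()))    # ['exp_avg', 'exp_avg_sq', 'step']
print(state_after[0]['exp_avg'].shape)  # torch.Size([2, 4])
print(list(model.state_dict().keys()))  # ['weight', 'bias']
```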

Therefore, when resuming training, we usually need to save at least:

  • model parameters and buffers;
  • internal optimizer state;
  • the epoch that has already been completed.

In code:

checkpoint = {
    'epoch': epoch,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}

This kind of checkpoint is closer to a saved training state.

2.8.2 Training a Simple MLP

To demonstrate checkpoints, first define a simple MLP and train it on MNIST.

As in the previous section, set the random seed and get the current default device:

deep_learning.set_seed(42)
device = deep_learning.get_default_device()
print('Using device:', device)

Then load the MNIST dataset and split out a validation set:

root = deep_learning.get_data_root()
transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
ds_rng = torch.Generator().manual_seed(42)

train_ds = datasets.MNIST(root, train=True, transform=transform, download=True)
train_ds, val_ds = utils.random_split(train_ds, [50000, 10000], generator=ds_rng)
test_ds = datasets.MNIST(root, train=False, transform=transform, download=True)

train_dl = utils.DataLoader(train_ds, batch_size=64, shuffle=True)
val_dl = utils.DataLoader(val_ds, batch_size=128, shuffle=False)
test_dl = utils.DataLoader(test_ds, batch_size=128, shuffle=False)

Here, random_split still receives a separate Generator, so the data split is not affected by other random operations.
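To see why the dedicated Generator matters, the sketch below splits a toy TensorDataset of ten items twice, consuming numbers from the global random generator in between. The helper split_indices is hypothetical:

```python
import torch
import torch.utils.data as utils

dataset = utils.TensorDataset(torch.arange(10))


def split_indices(seed: int = 42) -> tuple[list[int], list[int]]:
    # A fresh Generator with a fixed seed drives the permutation
    rng = torch.Generator().manual_seed(seed)
    train_part, val_part = utils.random_split(dataset, [8, 2], generator=rng)
    return list(train_part.indices), list(val_part.indices)


first = split_indices()
torch.rand(1000)  # consume numbers from the global generator
second = split_indices()
print(first == second)  # True: the split does not depend on the global RNG
```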

Next, define a small MLP for demonstrating checkpoint saving and loading:

class MLP(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)
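Before training, a quick sanity check confirms the output shape and the parameter count. The sketch rebuilds the same layers as MLP.net directly with nn.Sequential, so it runs on its own, and feeds it a random batch shaped like MNIST images:

```python
import torch
import torch.nn as nn

# Same layers as MLP.net, rebuilt here so the snippet is self-contained
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

x = torch.randn(64, 1, 28, 28)  # fake batch: 64 grayscale 28x28 images
logits = net(x)
print(logits.shape)  # torch.Size([64, 10])

num_params = sum(p.numel() for p in net.parameters())
print(num_params)  # 203530 = (784*256 + 256) + (256*10 + 10)
```

Adam keeps two extra moment tensors with the same shapes as the parameters, so for this MLP the optimizer state in a checkpoint holds roughly twice as many values as the model itself.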

To avoid repeating initialization code later, define a helper function that creates the model, loss function, optimizer, and metrics:

def create_training_objects(
    lr: float = 1e-3,
) -> tuple[
    nn.Module,
    nn.Module,
    optim.Optimizer,
    MulticlassAccuracy,
    MulticlassAccuracy,
]:
    model = MLP(num_classes=10).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    metric = MulticlassAccuracy(num_classes=10).to(device)
    val_metric = MulticlassAccuracy(num_classes=10).to(device)
    return model, loss_fn, optimizer, metric, val_metric

Notice that this example uses Adam. This makes it easier to explain later why resumed training should load optimizer state, not only model parameters.

2.8.3 Saving a Checkpoint

Now define the checkpoint functions. Saving a checkpoint is essentially putting the states needed for recovery into a dictionary and writing that dictionary to disk with torch.save.

def save_checkpoint(
    path: str | Path,
    epoch: int,
    model: nn.Module,
    optimizer: optim.Optimizer,
    history: list[dict[str, float]],
) -> None:
    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'history': history,
    }
    torch.save(checkpoint, path)

This saves four items:

  • epoch: the epoch that has already finished;
  • model: model parameters and buffers;
  • optimizer: internal optimizer state;
  • history: training logs recorded so far.

The most important items are model and optimizer. If only model parameters are saved, the recovered weights are correct, but the optimizer behaves like a newly created optimizer and starts accumulating its state from zero. For optimizers such as Adam, this changes the subsequent training trajectory.
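The optimizer's state_dict also stores hyperparameters in its param_groups entry, so loading it overrides the values passed to the constructor. A minimal sketch with a throwaway nn.Linear:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
saved = optimizer.state_dict()
print(saved['param_groups'][0]['lr'], saved['param_groups'][0]['betas'])  # 0.001 (0.9, 0.999)

# Recreate the optimizer with a different learning rate on purpose
new_optimizer = optim.Adam(model.parameters(), lr=0.5)
new_optimizer.load_state_dict(saved)
print(new_optimizer.param_groups[0]['lr'])  # 0.001, taken from the checkpoint
```

If you deliberately want a different learning rate after resuming, set it on optimizer.param_groups after calling load_state_dict rather than in the constructor.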

2.8.4 Loading a Checkpoint

Loading a checkpoint is the reverse process: use torch.load to read the dictionary, then load each state back into the model and optimizer.

def load_checkpoint(
    path: str | Path,
    model: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device | None = None,
) -> tuple[int, list[dict[str, float]]]:
    checkpoint = torch.load(
        path,
        map_location=device,
        weights_only=True,
    )

    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    epoch = checkpoint['epoch']
    history = checkpoint['history']
    return epoch, history

There are two details here.

The first one is:

map_location=device

It tells torch.load which device to place the stored tensors on. Without it, tensors are restored to the device they were saved from, so a checkpoint saved on a GPU cannot be loaded as-is on a machine that only has a CPU. map_location avoids this kind of device mismatch.
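A minimal sketch of the effect, using an in-memory buffer instead of a file: the tensor is saved from a GPU when one is available, and map_location='cpu' brings it back onto the CPU either way:

```python
import io

import torch

# Save from whatever device is available (GPU if present)
source_device = 'cuda' if torch.cuda.is_available() else 'cpu'
buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save({'weight': torch.ones(3, device=source_device)}, buffer)
buffer.seek(0)

# map_location places every stored tensor on the CPU
checkpoint = torch.load(buffer, map_location='cpu', weights_only=True)
print(checkpoint['weight'].device)  # cpu
```

map_location also accepts a torch.device object, which is what load_checkpoint passes. When it is None, tensors go back to the device they were saved from.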

The second one is:

weights_only=True

weights_only=True restricts torch.load to tensors and simple Python objects such as dicts, lists, numbers, and strings. This matters because torch.load is built on pickle, and unpickling arbitrary objects can execute arbitrary code. Since PyTorch 2.6, weights_only=True is the default; passing it explicitly keeps the behavior the same across versions. The checkpoint here stores only model state, optimizer state, the epoch, and the history, so this setting works. If a checkpoint stores instances of custom classes, loading with weights_only=True fails. This is another reason to save state_dict rather than the whole model object.
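The sketch below contrasts two checkpoints: one containing only a state_dict, numbers, and lists, and one that also stores an argparse.Namespace as a stand-in for a custom configuration object. The helper try_load is hypothetical, and because the exception type raised on rejection can differ between PyTorch versions, it catches Exception broadly:

```python
import argparse
import io

import torch
import torch.nn as nn


def try_load(obj) -> bool:
    """Save obj to memory, then report whether weights_only loading accepts it."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    buffer.seek(0)
    try:
        torch.load(buffer, weights_only=True)
        return True
    except Exception as err:  # rejected by the restricted unpickler
        print('Rejected:', type(err).__name__)
        return False


model = nn.Linear(2, 2)
plain_ok = try_load({'epoch': 3, 'model': model.state_dict(), 'history': [{'val_acc': 0.9}]})
config_ok = try_load({'epoch': 3, 'config': argparse.Namespace(lr=1e-3)})
print(plain_ok, config_ok)
```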

2.8.5 First Training Run: Simulating a Program Crash

Now train for a few epochs and save a checkpoint during the run.

checkpoint_path = Path(deep_learning.get_data_root()) / 'mnist-mlp-checkpoint.pt'
device = deep_learning.get_default_device()

model, loss_fn, optimizer, metric, val_metric = create_training_objects(lr=0.001)
history = []

Suppose we planned to train for 5 epochs, but the program crashed after epoch 3. We can simulate this crash inside the training loop by raising signal.SIGINT, which Python's default handler turns into a KeyboardInterrupt:

num_epochs = 5
crash_after_epoch = 3

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = dt.train(
        model=model,
        dataloader=train_dl,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metric=metric,
        device=device,
    )
    val_loss, val_acc = dt.evaluate(
        model=model,
        dataloader=val_dl,
        loss_fn=loss_fn,
        metric=val_metric,
        device=device,
    )

    record = {
        'epoch': epoch,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
    }
    history.append(record)

    width = len(str(num_epochs))
    print(
        f'Epoch [{epoch:0{width}d}/{num_epochs:0{width}d}] '
        f'| train_loss: {train_loss:.4f} '
        f'| train_acc: {train_acc:.4f} '
        f'| val_loss: {val_loss:.4f} '
        f'| val_acc: {val_acc:.4f}'
    )

    save_checkpoint(
        checkpoint_path,
        epoch=epoch,
        model=model,
        optimizer=optimizer,
        history=history,
    )

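    # Simulate a crash: SIGINT is turned into a KeyboardInterrupt
    if epoch == crash_after_epoch:
        try:
            signal.raise_signal(signal.SIGINT)
        except KeyboardInterrupt:
            # Print a message imitating a dead notebook kernel
            print('ERROR - Kernel died while waiting for execute reply.')
            break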

Here, a checkpoint is saved after every epoch. Even if the program stops, training can be restored to the state after the most recent completed epoch. In real training, we would not intentionally crash the program. This is only a demonstration of what checkpoints provide.
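One caveat: if the crash happens while torch.save is writing, the checkpoint file itself may be left half written. A common precaution, sketched below with a hypothetical save_atomically helper, is to write to a temporary file and then rename it over the real path. On POSIX systems, os.replace is an atomic rename when both paths are on the same filesystem, which is why the temporary file is created in the same directory:

```python
import os
import tempfile
from pathlib import Path

import torch


def save_atomically(obj: dict, path: Path) -> None:
    """Hypothetical helper: write to a temp file, then rename it over `path`."""
    path = Path(path)
    # Same directory as the target, so the rename stays on one filesystem
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix='.tmp')
    os.close(fd)
    try:
        torch.save(obj, tmp_name)
        os.replace(tmp_name, path)  # swap in the finished file in one step
    except BaseException:
        # Saving failed or was interrupted: drop the partial temp file;
        # any previous checkpoint at `path` is left untouched
        Path(tmp_name).unlink(missing_ok=True)
        raise


with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = Path(tmp_dir) / 'checkpoint.pt'
    save_atomically({'epoch': 1}, ckpt_path)
    save_atomically({'epoch': 2}, ckpt_path)  # replaces the epoch-1 file
    loaded = torch.load(ckpt_path, weights_only=True)
    leftovers = sorted(p.name for p in Path(tmp_dir).iterdir())

print(loaded['epoch'], leftovers)  # 2 ['checkpoint.pt']
```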

2.8.6 Recreating the Model and Optimizer

After a program crash, the in-memory model and optimizer are gone. Therefore, the first step in resuming training is not to call the training function directly, but to recreate a model and optimizer with the same structure:

model, loss_fn, optimizer, metric, val_metric = create_training_objects(lr=0.001)

At this point, model is randomly initialized, and optimizer is also new. They have not yet recovered the pre-crash state.

Next, load the checkpoint:

last_epoch, history = load_checkpoint(
    checkpoint_path,
    model=model,
    optimizer=optimizer,
    device=device,
)

print('Last finished epoch:', last_epoch)
print('Number of history records:', len(history))

After loading, the model parameters are restored to the state at the end of epoch last_epoch, and the optimizer state is restored as well. Continued training should therefore start from last_epoch + 1.

2.8.7 Resuming Training from a Checkpoint

Now finish the remaining epochs:

for epoch in range(last_epoch + 1, num_epochs + 1):
    train_loss, train_acc = dt.train(
        model=model,
        dataloader=train_dl,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metric=metric,
        device=device,
    )
    val_loss, val_acc = dt.evaluate(
        model=model,
        dataloader=val_dl,
        loss_fn=loss_fn,
        metric=val_metric,
        device=device,
    )

    record = {
        'epoch': epoch,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
    }
    history.append(record)

    width = len(str(num_epochs))
    print(
        f'Epoch [{epoch:0{width}d}/{num_epochs:0{width}d}] '
        f'| train_loss: {train_loss:.4f} '
        f'| train_acc: {train_acc:.4f} '
        f'| val_loss: {val_loss:.4f} '
        f'| val_acc: {val_acc:.4f}'
    )

    save_checkpoint(
        checkpoint_path,
        epoch=epoch,
        model=model,
        optimizer=optimizer,
        history=history,
    )

Although a program crash was simulated in the middle, training can still continue to completion.

We can also evaluate the final model on the test set:

test_metric = MulticlassAccuracy(num_classes=10).to(device)
test_loss, test_acc = dt.evaluate(
    model=model,
    dataloader=test_dl,
    loss_fn=loss_fn,
    metric=test_metric,
    device=device,
)
model_name = type(model).__name__
print(f'[{model_name}] | test_loss: {test_loss:.4f} | test_acc: {test_acc:.4f}')

This is the basic purpose of a checkpoint:

A program can stop, but the training state does not have to be lost.

2.8.8 What If Only the Model Is Loaded

To understand the role of optimizer state more clearly, consider another recovery method:

model.load_state_dict(checkpoint['model'])

but without:

optimizer.load_state_dict(checkpoint['optimizer'])

This is not completely wrong. The model parameters are restored, and the code can continue training. However, the optimizer’s internal state is lost.

For plain SGD without momentum, the optimizer keeps no per-parameter state, so little is lost as long as it is recreated with the same hyperparameters. For Adam, AdamW, or SGD with momentum, the optimizer maintains state derived from past gradients. If that state is lost, the resumed training trajectory is no longer equivalent to continuing from the original point.

More precisely:

  • restoring only model: model parameters are restored, but optimizer state is lost, so the training trajectory changes;
  • restoring model and optimizer: a more complete training state is restored, and the trajectory is closer to continuing from the original point.
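Both claims can be checked on a toy problem. The sketch below is an assumption-laden miniature, not the chapter's MNIST run: a linear model, random data, a fixed batch order, and CPU only. It compares six uninterrupted Adam steps with a run interrupted after three steps and resumed from an in-memory checkpoint, once with the optimizer state restored and once without it. The helpers make_objects, train_steps, resume, and max_param_diff are hypothetical. Because the batch order is fixed, any difference comes only from the restored state; the chapter's MNIST loop does not save RNG state, so its shuffle order after resuming need not match an uninterrupted run.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

# Fixed toy data: 6 batches of 16 samples, always visited in the same order
torch.manual_seed(0)
inputs = torch.randn(6, 16, 8)
targets = torch.randint(0, 3, (6, 16))


def make_objects():
    torch.manual_seed(1)  # same initial weights in every run
    model = nn.Linear(8, 3)
    return model, optim.Adam(model.parameters(), lr=1e-2)


def train_steps(model, optimizer, batch_ids):
    loss_fn = nn.CrossEntropyLoss()
    for i in batch_ids:
        optimizer.zero_grad()
        loss_fn(model(inputs[i]), targets[i]).backward()
        optimizer.step()


def max_param_diff(a, b):
    return max((p - q).abs().max().item() for p, q in zip(a.parameters(), b.parameters()))


# Reference: six steps without interruption
ref_model, ref_opt = make_objects()
train_steps(ref_model, ref_opt, range(6))

# Interrupted run: three steps, then an in-memory checkpoint
model, optimizer = make_objects()
train_steps(model, optimizer, range(3))
buffer = io.BytesIO()
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, buffer)


def resume(load_optimizer: bool) -> nn.Module:
    buffer.seek(0)
    checkpoint = torch.load(buffer, weights_only=True)
    new_model, new_opt = make_objects()  # fresh objects, as after a crash
    new_model.load_state_dict(checkpoint['model'])
    if load_optimizer:
        new_opt.load_state_dict(checkpoint['optimizer'])
    train_steps(new_model, new_opt, range(3, 6))
    return new_model


full_diff = max_param_diff(ref_model, resume(load_optimizer=True))
model_only_diff = max_param_diff(ref_model, resume(load_optimizer=False))
print(f'model + optimizer restored: {full_diff:.2e}')
print(f'model only restored:        {model_only_diff:.2e}')
```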

If the goal is to load a pretrained model and start a new training task, loading only model parameters is usually enough. If the goal is to resume interrupted training, optimizer state should be saved and loaded together with the model.

2.8.9 What Else Can Be Saved

This section saves only the most basic fields:

checkpoint = {
    'epoch': epoch,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'history': history,
}

In a more complete training project, a checkpoint may also save:

  • learning rate scheduler state;
  • GradScaler state, meaning the scaling factor used in automatic mixed precision training;
  • current global step;
  • best validation metric so far;
  • experiment configuration, such as learning rate, batch size, and model hyperparameters;
  • random number generator state;
  • dataloader state.

These belong to more complex training engineering. For most later examples, saving model, optimizer, and epoch is enough. Dataloader state matters only when training must resume from the middle of an epoch. Here, we use a simpler strategy: save one checkpoint after each epoch, and resume from the next epoch.
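For reference, a richer checkpoint follows the same pattern: every stateful object contributes its own state_dict. The sketch below assumes SGD with momentum, a StepLR scheduler, and only the CPU torch random generator; the fields global_step and best_val_acc are hypothetical names, and best_val_acc holds a placeholder rather than a real result. A GradScaler can be saved the same way through its state_dict, and CUDA generators through torch.cuda.get_rng_state_all.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim


def make_objects():
    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
    return model, optimizer, scheduler


model, optimizer, scheduler = make_objects()
for _ in range(3):  # pretend three epochs have finished
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()
    scheduler.step()  # halves the learning rate every 2 epochs

checkpoint = {
    'epoch': 3,
    'global_step': 3,
    'best_val_acc': 0.0,  # placeholder, not a real result
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'torch_rng_state': torch.get_rng_state(),
}
expected = torch.rand(3)  # what the run would draw next
buffer = io.BytesIO()
torch.save(checkpoint, buffer)

# After a restart: recreate everything (this consumes random numbers)
new_model, new_optimizer, new_scheduler = make_objects()
buffer.seek(0)
restored = torch.load(buffer, weights_only=True)
new_model.load_state_dict(restored['model'])
new_optimizer.load_state_dict(restored['optimizer'])
new_scheduler.load_state_dict(restored['scheduler'])
torch.set_rng_state(restored['torch_rng_state'])

resumed = torch.rand(3)
print(new_scheduler.get_last_lr())     # [0.05]
print(torch.equal(resumed, expected))  # True
```

Restoring the RNG state makes subsequent draws from the global generator continue from the same point, so the next torch.rand call matches what the uninterrupted run would have produced.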

2.8.10 Summary

This section demonstrated the basic use of checkpoints with MNIST.

For inference, saving the model’s state_dict is usually enough, because inference depends on model parameters but not optimizer state.

When training needs to resume after an interruption, saving only model parameters is usually not enough. The optimizer also has state, such as Adam’s first-moment and second-moment estimates or SGD’s momentum buffer. Therefore, a minimal training checkpoint usually contains:

checkpoint = {
    'epoch': epoch,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}

To resume training, recreate the same model and optimizer first, then load their states separately:

model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

Finally, continue training from checkpoint['epoch'] + 1.

At this point, the PyTorch basics form a complete training chain: autograd computes gradients, nn.Module organizes the model, the loss function defines the optimization objective, the optimizer updates parameters, the training template connects these components, and checkpoints save and restore the training state when training is interrupted.