Chapter 2.6 PyTorch 中的优化器:从手动更新到参数组与状态管理

作者

Brench

发布于

2026-06-18

修改于

2026-06-18

前面几节已经说明:只要计算图被正确记录,调用 loss.backward() 之后,模型参数就会得到梯度。

但是,梯度本身不会自动修改参数。它只说明:如果想让 loss 变小,参数应该往哪个方向移动。真正根据梯度更新参数的是优化器(optimizer)

比如最简单的梯度下降可以写成:

\[ \theta \leftarrow \theta - \eta \nabla_\theta L \]

其中,\(\theta\) 是参数,\(\nabla_\theta L\) 是参数的梯度,\(\eta\) 是学习率。

不使用优化器时,也可以手动更新参数:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(8, 3)
y = torch.randn(8, 1)

linear = nn.Linear(3, 1)

pred = linear(x)
loss = F.mse_loss(pred, y)
loss.backward()

lr = 0.1
with torch.no_grad():
    for param in linear.parameters():
        param -= lr * param.grad

这段代码可以工作,但后续会遇到几个工程问题:

torch.optim 要解决的就是这些问题。

本节从最基本的 optimizer.step() 开始,逐步说明 PyTorch 优化器在管理什么。

from pprint import pprint

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor

print('PyTorch version:', torch.__version__)

2.6.1 从手动更新到 optimizer.step()

我们先用一个简单的线性模型作为例子:

model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

创建优化器时,最重要的是告诉它两件事:

  • 我们要更新哪些参数;
  • 我们想用什么规则更新这些参数。

这里的:

optimizer = optim.SGD(model.parameters(), lr=0.1)

这行代码表示:把 model.parameters() 返回的所有参数交给 SGD,并用学习率 0.1 更新它们。

一个最小的训练步骤通常长这样:

x = torch.randn(8, 3)
y = torch.randn(8, 1)

pred = model(x)
loss = F.mse_loss(pred, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print('Loss:', loss.item())

其中,下面三行最常见:

optimizer.zero_grad()
loss.backward()
optimizer.step()

它们分别对应训练中的三个动作:

  • 清空旧梯度,避免累加;
  • 根据当前 loss 计算新梯度,写到参数的 .grad 中;
  • 根据新梯度和优化器状态更新参数。

注意,backward() 只负责把梯度写到参数的 .grad 属性里,真正修改参数值的是 optimizer.step()

我们可以直接看一下更新前后参数有没有变化:

model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

before = model.weight.detach().clone()

pred = model(x)
loss = F.mse_loss(pred, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

after = model.weight.detach().clone()

flag = before.allclose(after)
max_err = (before - after).abs().max().item()
print('Is the parameter unchanged?', flag)
print('Max absolute difference:', max_err)

这说明参数确实被优化器修改了。

2.6.2 为什么每次更新前要 zero_grad

一个容易忽略的细节是:PyTorch 中的梯度默认是累加的,而不是覆盖。也就是说,如果连续调用两次 backward(),第二次得到的梯度不会覆盖第一次,而是加到原来的 .grad 上。

用一个很小的例子观察这个行为:

w = torch.tensor([1.0], requires_grad=True)

loss1 = w ** 2
loss1.backward()
print('After first backward:', w.grad)

loss2 = w ** 2
loss2.backward()
print('After second backward:', w.grad)

第一次反向传播后,梯度是 2;第二次反向传播后,梯度变成了 4。不是因为新的梯度是 4,而是因为旧的梯度 2 和新的梯度 2 累加起来的和是 4。

这就是为什么训练循环里通常要写:

optimizer.zero_grad()
loss.backward()
optimizer.step()

如果不清空梯度,每个 batch 的梯度都会叠加到前面的 batch 上,参数更新就不再对应当前 batch 的 loss

当然,梯度累加并不总是错误。有时候我们会故意累加多个 mini-batch 的梯度,再统一更新一次参数。这叫做梯度累积(gradient accumulation)

例如:

model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad()

for i in range(4):
    x = torch.randn(8, 3)
    y = torch.randn(8, 1)

    pred = model(x)
    loss = F.mse_loss(pred, y)
    loss = loss / 4
    loss.backward()

optimizer.step()

这里我们故意不在每个 mini-batch 后面清空梯度,而是让 4 个 mini-batch 的梯度累加起来,再调用一次 step()

所以,更准确地说,zero_grad() 不是 backward() 的必要搭配,而是针对 PyTorch 默认累加梯度这一行为的处理。如果本次更新只想使用当前 batch 的梯度,就要在 backward() 之前清空旧梯度。

2.6.3 set_to_none 是什么

optimizer.zero_grad() 默认会把参数的 .grad 置为 None

model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 3)
y = torch.randn(8, 1)

loss = F.mse_loss(model(x), y)
loss.backward()

flag1 = model.weight.grad is None
print('Whether grad is None before zero_grad?', flag1)

optimizer.zero_grad(set_to_none=True)
flag2 = model.weight.grad is None
print('Whether grad is None after zero_grad?', flag2)

有时也会看到:

optimizer.zero_grad(set_to_none=False)

这种写法会把梯度设成 0,而不是 None

model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

loss = F.mse_loss(model(x), y)
loss.backward()

optimizer.zero_grad(set_to_none=False)
print(model.weight.grad)

两者都可以达到清空旧梯度的目的,但语义稍微不同:

  • .grad = None:表示这个参数目前还没有梯度;
  • .grad = 0:表示这个参数有梯度,只是梯度值为 0。

在大多数训练代码里,使用默认的 set_to_none=True 即可。它通常更省内存,也可以让 PyTorch 在下一次反向传播时重新分配梯度张量。不过,如果代码假设 .grad 一定是张量,而不是 None,就需要注意这个区别。

2.6.4 参数组:不同参数可以有不同学习率

到目前为止,我们把模型的所有参数都交给了同一个优化器,并使用同一个学习率:

optimizer = optim.SGD(model.parameters(), lr=1e-3)

但实际训练中,不同部分经常需要使用不同超参数。

例如,在微调预训练模型时,backbone 可能用较小学习率,最后的分类头用较大学习率。因为 backbone 已经学过很多通用特征,不想改得太快;而 head 是随机初始化的,需要更快学习当前任务。

这时可以使用参数组(parameter groups)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(10, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
        )
        self.head = nn.Linear(32, 2)

    def forward(self, x: Tensor) -> Tensor:
        x = self.backbone(x)
        return self.head(x)


model = TinyModel()
optimizer = optim.SGD(
    [
        {'params': model.backbone.parameters(), 'lr': 1e-4},
        {'params': model.head.parameters(), 'lr': 1e-3},
    ],
    weight_decay=1e-2,
)

for i, group in enumerate(optimizer.param_groups):
    print(
        f'Parameter group {i}: '
        f'learning rate = {group["lr"]}, '
        f'weight decay = {group["weight_decay"]}'
    )

这里传给优化器的不再是一个简单的参数迭代器,而是一个由字典组成的列表。每个字典描述一组参数,以及这一组参数自己的优化超参数。

如果某个参数组没有显式设置某个超参数,就会使用优化器构造函数里的默认值。比如上面的两个参数组都没有单独设置 weight_decay,所以它们都会使用外层的 weight_decay=1e-2

参数组也常用于设置某些参数不做 weight decay。例如,很多训练代码会对 bias 和 normalization 层的参数关闭 weight decay:

def split_weight_decay_params(model: nn.Module):
    decay = []
    no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name.endswith('bias') or 'norm' in name.lower():
            no_decay.append(param)
        else:
            decay.append(param)

    return [
        {'params': decay, 'weight_decay': 1e-2},
        {'params': no_decay, 'weight_decay': 0.0},
    ]


model = nn.Sequential(
    nn.Linear(10, 32),
    nn.LayerNorm(32),
    nn.Linear(32, 2),
)
optimizer = optim.SGD(split_weight_decay_params(model), lr=1e-3)

for i, group in enumerate(optimizer.param_groups):
    print(
        f'Parameter group {i}: '
        f'weight decay = {group["weight_decay"]}, '
        f'number of params = {len(group["params"])}'
    )

因此,优化器并不只能用同一套超参数更新整个模型。它内部维护的是一组一组的参数,每一组都可以有自己的学习率、权重衰减、动量等配置。根据需要,可以把模型参数拆成不同的组,再交给优化器统一管理。

2.6.5 优化器不只是公式,也有状态

如果使用最简单的 SGD,并且没有 momentum,那么参数更新只依赖当前梯度 \(g\)

\[ \theta \leftarrow \theta - \eta g \]

但是,很多优化器不仅依赖当前梯度,还会保存某些历史信息。例如带 momentum 的 SGD 会维护一个动量缓冲区:

\[ v_t = \mu v_{t-1} + g_t \]

Adam 和 AdamW 会维护梯度的一阶矩和二阶矩估计。

这些历史信息不是模型参数,但它们会影响后续更新。因此,优化器内部也有状态(state)

我们可以直接看一个优化器的 state_dict()

model = nn.Linear(3, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

pprint(optimizer.state_dict(), sort_dicts=False)

刚创建优化器时,state 通常是空的,因为还没有执行过任何一步更新。

执行一次更新后,再看:

x = torch.randn(8, 3)
y = torch.randn(8, 1)

loss = F.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

state_dict = optimizer.state_dict()
pprint(state_dict, sort_dicts=False)

一个优化器的 state_dict 里通常有两部分:

  1. state:每个参数对应的优化器内部状态;
  2. param_groups:参数组配置,比如学习率、weight decay、betas 等。

对于 AdamW,我们通常会看到类似 stepexp_avgexp_avg_sq 这样的状态。这些就是 AdamW 更新时需要用到的历史信息。它们的具体含义会在后续介绍 AdamW 时展开。

这也解释了为什么恢复训练时,只保存模型参数还不够。如果只保存 model.state_dict(),模型权重可以恢复,但 SGD 的动量、Adam/AdamW 的一阶矩和二阶矩都会丢失。

因此,更完整的 checkpoint 通常会包含:

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')

对应地,加载时也要恢复两部分:

model = nn.Linear(3, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

这样训练才能尽量从中断的位置继续,而不是只拿到同一组模型参数重新开始优化。

2.6.6 foreach 和 fused:同一个优化器的不同实现方式

有时候我们会在优化器里看到这样的参数:

optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    foreach=True,
)

或者:

optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    fused=True,
)

这些参数不是在改变 AdamW 的数学目标,而是在选择怎么执行参数更新

在 PyTorch 里,同一个优化器大致可以有三种实现路线:

  1. for-loop:最传统的方式,逐个参数张量更新;
  2. foreach:把一组张量打包,调用批量张量操作;
  3. fused:把多个更新操作融合到更少的 kernel 里执行。

最容易理解的是 for-loop。它像这样逐个参数处理:

for param in params:
    update(param)

这种方式简单、通用,但如果模型有很多参数张量,就会产生很多小操作,尤其在 GPU 上可能效率不高。

foreach 的思路是:不要一个张量一个张量地更新,而是把很多张量作为一个列表,一起交给底层实现处理。它通常比普通 for-loop 更快,尤其是在 GPU 上有很多参数张量时。但 foreach 也不是完全免费的,因为它经常需要保存中间的张量列表,所以峰值显存可能会更高。

fused 则更进一步。它希望把优化器更新里的多个操作融合起来,减少 kernel 的反复启动和中间读写。直观地说,foreach 更像是一次处理很多张量,而 fused 更像是把一次更新里的多个操作合并执行。

因此,在支持良好的 CUDA 场景下,fused=True 可能更快。但它对设备、数据类型和优化器实现的支持要求更高。

我们可以先把它们理解为三种执行方式:

  1. for-loop:最朴素,兼容性最好;
  2. foreach:通常更快,但可能多占显存;
  3. fused:更激进,可能最快,但支持范围更有限。

实际使用时,如果没有特别需求,通常可以先让 PyTorch 使用默认选择。只有在优化大模型训练性能、显存,或遇到兼容性问题时,才需要手动指定 foreachfused

提示

关于 PyTorch 的 foreachfused 在不同优化器里的支持情况,官方文档有更详细的说明和兼容性列表,可以参考torch.optim - Algorithms

下面这段代码只是演示参数如何传入。不同机器、不同 PyTorch 版本和不同设备上,是否支持 fused=True 可能不一样。

model = nn.Linear(10, 2)
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    foreach=False,
    fused=False,
)

print(optimizer)

2.6.7 optimizer.step() 默认不会记录计算图

前面我们手动更新参数时,写了:

with torch.no_grad():
    param -= lr * param.grad

这是因为普通训练中,参数更新本身通常不需要被 Autograd 记录。

换句话说,我们一般只关心 loss 如何对参数求梯度,而不关心 optimizer.step() 这个更新过程本身如何再被求导。

PyTorch 的优化器也是这个默认逻辑。optimizer.step() 默认会在不记录梯度的上下文里更新参数。

这对普通训练是合理的,因为如果每一步参数更新都被记录进计算图,显存会迅速增长,训练也会变得复杂。但是,有些更高级的场景确实需要对优化过程求导。例如元学习,可微分优化,学习学习率,或者把若干步梯度更新当成计算图的一部分。

这时需要让优化器的更新过程也参与 Autograd。PyTorch 优化器里对应的参数叫 differentiable。设置为 True 后,优化器的 step 过程会被追踪,允许继续对更新后的参数求导。

例如:

model = nn.Linear(3, 1)
optimizer = optim.SGD(
    model.parameters(),
    lr=0.1,
    differentiable=True,
)

flag = optimizer.defaults['differentiable']
print('Is optimizer step differentiable?', flag)

不过,differentiable=True 不是常规训练需要打开的选项。它会让优化器 step 的计算也被追踪,通常会带来更多内存开销,也可能要求更谨慎地写代码。

因此,对于大多数训练场景,使用默认的 differentiable=False 即可;只有在需要对参数更新过程继续求导的特殊场景下,才考虑设置 differentiable=True

2.6.8 一个完整的优化步骤

现在可以把前面的内容连起来,写一个完整的优化步骤。

model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
)
optimizer = optim.AdamW(
    [
        {'params': model[0].parameters(), 'lr': 1e-3},
        {'params': model[2].parameters(), 'lr': 1e-2},
    ],
    weight_decay=1e-2,
)

x = torch.randn(16, 10)
y = torch.randn(16, 1)

model.train()
pred = model(x)
loss = F.mse_loss(pred, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Loss:', loss.item())

这段代码背后发生了几件事:

  • model(x):执行前向传播,计算预测值 pred,并构建计算图;
  • loss.backward():根据当前的 loss 计算梯度,并把结果累加到模型参数的 .grad 属性中;
  • optimizer.step():根据参数的 .grad 和优化器内部状态,更新参数值;
  • optimizer.zero_grad():清空旧梯度,避免下一次 backward() 继续累加。

其中,优化器不仅保存了超参数,也可能保存了历史状态。参数组决定了不同参数如何被更新,foreachfused 决定了更新过程如何被高效执行,而 differentiable=True 则决定了更新过程本身是否进入计算图。

2.6.9 本章小结

本节从最简单的手动梯度下降出发,说明了 PyTorch 优化器的作用。backward() 负责计算梯度,把结果写到参数的 .grad 中;optimizer.step() 才真正根据这些梯度修改参数。

由于 PyTorch 默认会累加梯度,所以常规训练中需要在每次反向传播前调用 optimizer.zero_grad()。如果故意不清空梯度,也可以实现梯度累积。

优化器接收的不一定是一组统一参数,也可以是多个参数组。不同参数组可以设置不同学习率、weight decay 等超参数,这在微调和大模型训练中非常常见。

同时,优化器自己也有状态。像 AdamW 这样的优化器会保存历史梯度统计量,因此恢复训练时通常要同时保存 model.state_dict()optimizer.state_dict()

最后,我们还区分了优化器的几种执行实现。foreachfused 不改变优化算法的数学含义,而是改变更新过程的执行方式。普通训练时一般使用默认设置即可;只有在关心性能、显存或可微分优化时,才需要进一步控制这些选项。

下一节,我们就把从数据加载、模型定义、损失计算到优化器更新的完整训练循环写出来,看看它们是如何协同工作的。

二次使用