Chapter 2.1 PyTorch 中的自动微分

作者

Brench

发布于

2026-05-10

修改于

2026-05-10

在 1.3 节中,我们曾把计算图理解成一条“责任链”:损失为什么会取到当前这个值,沿着链条往回查,就能看到每个参数承担了多少责任。这一节换到更工程化的视角:框架到底怎样自动搭起这条链,又是怎样在需要时把梯度算出来的?

把问题说得更直接一些:训练时我们需要梯度,可实际写下来的只是普通代码:加法、乘法、卷积、激活函数……这些操作在前向传播中依次执行,最后得到一个 loss。那梯度究竟从哪里来?框架难道真的会推导出一条巨大的符号公式吗?

当然不是。深度学习框架更像是在做两件事:

理解这套机制很重要。它不仅回答“梯度从哪里来”,也会解释后面经常遇到的现象:为什么梯度会累积?为什么中间变量默认没有 .grad?为什么某些操作会断开梯度链?以及为什么显存和计算之间总要做取舍。

import torch
import torch.autograd.functional as AF

print('PyTorch version:', torch.__version__)

2.1.1 计算图不是画出来的,是跑出来的

理解 PyTorch 自动微分时,最好的入口不是先背术语,而是先观察一个事实:你表面上只是在做前向计算,计算图却会随着代码运行自动生成。

假设有这样一个简单函数:

\[ z = \sin(x \cdot y) \]

它可以拆成两个基础运算:

  1. 计算向量内积:\(q = x \cdot y\)
  2. 计算正弦函数:\(z = \sin(q)\)

接着告诉 PyTorch:后续我们希望能得到 zxy 的梯度。

x = torch.arange(1.0, 5.0, requires_grad=True)
y = torch.arange(5.0, 9.0, requires_grad=True)

这里的 requires_grad=True 可以看成一种声明:这些变量之后需要被纳入“追责”范围。只要某个结果由它们参与计算得到,这个结果就会自动具备可导属性,并在背后记录自己是由谁算出来、又依赖了哪些对象。

现在执行两步普通前向计算:先做点积,再取正弦。

q = x.dot(y)
z = q.sin()
print('z.requires_grad:', z.requires_grad)

到这里,表面上看到的仍然只是数值计算,但 PyTorch 已经完成了两件事:

  1. z 会自动成为需要梯度的结果,因为它依赖需要梯度的 xy
  2. qz 的生成关系会被记录下来:z 来自 sinq 来自 dot,而 q 又依赖 xy

先不用急着画计算图。我们先看一个更直观的现象:在调用反向传播之前,梯度不会自动出现在张量上。

print('x.grad:', x.grad)
print('y.grad:', y.grad)

这里显示的是 None,不是 0。原因很简单:梯度是反向回溯之后才产生的结果。只有当你明确启动回溯,比如调用 backward(),PyTorch 才会沿着刚才记录的依赖关系计算梯度,并把结果写回叶子节点。如果没有发起回溯,它就不会计算梯度,自然也不会填入数值。

接下来就从 z 开始做反向传播,看看 .grad 如何出现,以及它是否和手算结果一致。

2.1.2 backward 到底做了什么:从输出往回查账

上一小节里,我们只写了前向计算,但 PyTorch 已经把依赖关系记录了下来。现在真正要看的问题是:调用 backward() 时,框架具体会做什么?它算出来的梯度能不能和手算对上?

还是沿用同一个例子:

\[ q = x \cdot y, \quad z = \sin(q) \]

手动求导可以得到:

\[ \frac{\partial z}{\partial x} = \frac{\partial z}{\partial q} \cdot \frac{\partial q}{\partial x} = \cos(q) \cdot y \] \[ \frac{\partial z}{\partial y} = \frac{\partial z}{\partial q} \cdot \frac{\partial q}{\partial y} = \cos(q) \cdot x \]

现在让 PyTorch 来完成同样的事情。直接从输出 z 发起回溯:

z.backward()
print('x.grad:', x.grad)
print('y.grad:', y.grad)

此时 .grad 不再是 None,梯度已经写回到 xy 这两个叶子节点。直觉上,可以这样理解 backward()

  1. z 出发,默认令 \(\frac{\partial z}{\partial z} = 1\)
  2. 沿着前向传播时记录的依赖关系向前回溯;
  3. 每经过一个算子节点,就使用该算子的局部导数规则,把梯度传给更上游的节点。

可以把结果和手算公式对齐:

# pyright: reportArgumentType=false
assert torch.allclose(x.grad, y * x.dot(y).cos())
assert torch.allclose(y.grad, x * x.dot(y).cos())

到这里,自动微分的核心逻辑已经很清楚了。深度学习框架不需要推导一个庞大的全局导数表达式,只需要掌握每个局部操作的求导规则,再按照计算图把这些规则串起来。

再往里看一点,PyTorch 也会把这条回溯链的一部分暴露给我们。例如:

# pyright: reportOptionalMemberAccess=false
print('z.grad_fn:', z.grad_fn.name())
print('q.grad_fn:', q.grad_fn.name())
print('x.grad_fn:', x.grad_fn)
print('y.grad_fn:', y.grad_fn)

通常会看到类似 SinBackward0 这种带有 Backward 的名字。可以粗略理解为:

  • z 不是凭空出现的,而是由某个算子产生的结果,这里对应 sin
  • grad_fn 就是这个算子在反向传播阶段对应的梯度函数对象。

反向传播时,PyTorch 会从根节点开始,依次调用各个节点对应的导数算子,把梯度传到输入端为止。比如调用 z.backward() 时,会先经过 z 对应的 SinBackward0,得到 \(\frac{\partial z}{\partial q}\);再把梯度传给 q 对应的 DotBackward0,继续得到 \(\frac{\partial q}{\partial x}\)\(\frac{\partial q}{\partial y}\);最终合成 \(\frac{\partial z}{\partial x}\)\(\frac{\partial z}{\partial y}\)。像 xy 这样的叶子节点没有 grad_fn,因为它们就是计算图的起点,不需要再向前追溯。

更关键的是,grad_fn.next_functions 会记录它接下来要回溯到哪些上游依赖:

# pyright: reportOptionalMemberAccess=false
node_q = z.grad_fn.next_functions[0][0]
node_x = node_q.next_functions[0][0]
node_y = node_q.next_functions[1][0]
print('grad_fn of z.child -> q:', node_q.name())
print('grad_fn of q.child -> x:', node_x.name())
print('grad_fn of q.child -> y:', node_y.name())

这些信息描述了为了计算 z 的梯度,反向传播下一步应该访问哪些节点、沿哪些输入继续回溯。例如,SinBackward0 的输入是 q,而 q 来自 DotBackward0,所以它的 next_functions 会指向 DotBackward0。类似地,DotBackward0 的上游会连接到 xy。其中 AccumulateGrad 是一种特殊节点,每个需要梯度的叶子节点前面通常都有一个对应的 AccumulateGrad,负责把传回来的梯度累加到叶子节点的 .grad 属性里。这就是调用 backward() 之后 x.grady.grad 会出现的原因。

2.1.3 为什么非标量不能直接 backward?

前面的例子中,z 是标量,所以可以直接写 z.backward()。但很多人第一次把输出换成向量或矩阵时,都会遇到 PyTorch 一个看起来不太直观的限制:

x = torch.arange(1.0, 5.0, requires_grad=True)
y = torch.arange(5.0, 9.0, requires_grad=True)
Z = x.outer(y)
try:
    Z.backward()  # This will raise an error because z is not a scalar
except RuntimeError as err:
    print('RuntimeError:', err)

这不是 PyTorch 故意设限,而是因为输出不是标量时,反向传播的起点不再唯一。

对于标量 z,我们通常想要的是 \(\frac{\partial z}{\partial x}\)\(\frac{\partial z}{\partial y}\)。反向传播从输出开始,第一步可以自然设定 \(\frac{\partial z}{\partial z} = 1\)。这是没有歧义的,因为标量输出只有一个默认方向。

但如果输出变成向量或矩阵 Z,问题就变成了:我们到底想反传哪个方向?

  • 想要 Z 每个元素分别对 xy 的梯度吗?那会得到更高阶的张量。
  • 还是想要某个标量函数,例如 Z 的和、均值或加权和,对 xy 的梯度?

也就是说,面对非标量输出时,反向传播必须先明确一件事:梯度要从哪个“方向”传回来?

在数学上,这个“方向”就是一个与输出同形状的张量 v,表示从上游传下来的梯度:

\[ v = \frac{\partial L}{\partial Z} \]

然后 PyTorch 实际上计算的是向量-雅可比积(VJP):

\[ \frac{\partial L}{\partial x} = v^\top \left(\frac{\partial Z}{\partial x}\right) \]

标量输出时,v 可以自动取 1(等价于直接把 \(L\) 看成 \(Z\));非标量输出时,这个 v 就需要我们显式提供。

这里就有两种写法。

一种写法是显式传入 gradient,告诉 PyTorch 要沿哪个方向回传梯度:

x = torch.arange(1.0, 5.0, requires_grad=True)
y = torch.arange(5.0, 9.0, requires_grad=True)
Z = x.outer(y)
Z.backward(gradient=torch.ones_like(Z))
print('x.grad:', x.grad)
print('y.grad:', y.grad)

这里的 torch.ones_like(Z) 等价于告诉 PyTorch:我关心的是 \(L = \sum_{i,j} Z_{i,j}\),因为

\[ \frac{\partial L}{\partial Z_{i,j}} = 1 \]

因此传入全 1 的上游梯度,就等价于先对所有元素求和,再调用 backward

另一种写法是先把 Z 规约成标量,再对这个标量调用 backward()

x = torch.arange(1.0, 5.0, requires_grad=True)
y = torch.arange(5.0, 9.0, requires_grad=True)
Z = x.outer(y)
Z = Z.sum()  # Now Z is a scalar
Z.backward()
print('x.grad:', x.grad)
print('y.grad:', y.grad)

很多情况下,这两种方式是等价的。要么显式给出回传方向,要么先把输出变成标量,让 PyTorch 使用标量输出默认的反向传播起点。

2.1.4 高阶导数:让求导过程也变成计算的一部分

到目前为止,我们计算的都是一阶梯度:给定一个标量输出(或可以规约成标量的输出)\(L\),求 \(\nabla_x L\)\(\nabla_y L\)。但有些场景还需要更高阶的信息,比如二阶导数、Hessian 的某些方向、曲率,或者某些正则项。

关键在于:如果你想继续对“梯度”求导,那么“计算这个梯度的过程”本身也必须被记录为可微计算。这就是 create_graph=True 的作用。它不仅返回一阶导数的数值,还会把产生这个导数的过程构造成新的计算图。

这时可能会有疑问:为什么不直接用 backward()?原因是 backward() 更偏向训练流程:它会把梯度累积到叶子张量的 .grad 属性里,并且默认释放计算图来节省内存。而在高阶导数场景下,我们通常更希望:

  • 梯度以张量形式返回,方便继续参与后续计算;
  • 在需要时保留或构建计算图,方便继续求导。

因此,更常用的接口是 torch.autograd.grad

继续沿用前面的例子:\(z = \sin(x \cdot y)\)。先求一阶导数 \(\frac{dz}{dx}\)\(\frac{dz}{dy}\),再对这些导数继续求导,观察二阶导数 \(\frac{d^2 z}{dx^2}\)\(\frac{d^2 z}{dy^2}\) 的结果。

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(4.0, requires_grad=True)
z = torch.sin(x * y)

dzdx, dzdy = torch.autograd.grad(z, (x, y), create_graph=True)
print('dz/dx:', dzdx)
print('dz/dy:', dzdy)

这里最关键的是 create_graph=True。如果没有它,dz/dxdz/dy 会被当作普通数值结果,不再保留它们的来源,因此也无法继续对它们求导。输出中带有 grad_fn,说明这些导数本身仍然处在可求导图中。

计算高阶导数时,有时需要在同一张计算图上对不同变量连续做梯度计算。但 PyTorch 默认会在一次反向传播后释放图中仅供反传使用的中间信息,以节省内存。如果确实要在同一次前向结果上多次回溯,可以设置 retain_graph=True 保留计算图:

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(4.0, requires_grad=True)
z = torch.sin(x * y)

dzdx, dzdy = torch.autograd.grad(z, (x, y), create_graph=True)
print('dz/dx:', dzdx)
print('dz/dy:', dzdy)

(d2zdx2,) = torch.autograd.grad(dzdx, x, retain_graph=True)
(d2zdy2,) = torch.autograd.grad(dzdy, y)
print('d2z/dx2:', d2zdx2)
print('d2z/dy2:', d2zdy2)

不过在实际代码里,更常见的做法是重新执行一次前向传播,得到一张新的计算图。retain_graph=True 通常只在确实需要复用同一张图时使用,比如高阶导数实验或某些正则项计算。

2.1.5 VJP 和 JVP:反向模式与正向模式到底在算什么?

到目前为止,我们一直笼统地说“求梯度”。但严格来说,深度学习中的大多数函数并不是标量到标量,而是:

\[ f: \mathbb{R}^n \to \mathbb{R}^m \]

它的导数是一个雅可比矩阵(Jacobian):

\[ J = \frac{\partial f}{\partial x} \in \mathbb{R}^{m \times n} \]

真正麻烦的是,当 \(m\)\(n\) 都很大时,我们几乎不会显式构造完整的 \(J\)。框架实际更常计算的是 Jacobian 与某个向量的乘积,只是这个向量可能乘在左边,也可能乘在右边。

2.1.5.1 VJP:向量-雅可比积(反向模式)

给定上游梯度向量 \(v \in \mathbb{R}^m\),也就是可以理解为 \(\frac{\partial L}{\partial f}\) 的量,反向模式计算的是:

\[ v^\top J \in \mathbb{R}^n \]

这就是 VJP(vector-Jacobian product)

换成训练里的语言就很熟悉:

  • 有一个标量 loss\(L = \mathcal{L}(f(x))\)
  • 有一个上游梯度:\(v = \frac{\partial L}{\partial f}\)
  • 反向传播要得到:\(\frac{\partial L}{\partial x} = v^\top \frac{\partial f}{\partial x}\)

所以,平时调用 backward(),本质上就是在计算一个特殊形式的 VJP。

def vjp_func(x: torch.Tensor, y: torch.Tensor):
    return x.dot(y).sin()


x = torch.arange(1.0, 5.0)
y = torch.arange(5.0, 9.0)
out = AF.vjp(vjp_func, (x, y))
print('func(x,y):', out[0])
print('VJP output:', out[1])

2.1.5.2 JVP:雅可比-向量积(正向模式)

正向模式的方向相反:给定输入空间里的一个方向 \(u \in \mathbb{R}^n\),它计算:

\[ Ju \in \mathbb{R}^m \]

这就是 JVP(Jacobian-vector product)。直觉上,它回答的是:如果输入沿方向 \(u\) 发生一个很小的扰动,输出会朝哪个方向变化?这种形式常见于敏感性分析、隐式层、部分二阶方法,以及物理/科学计算场景。

def jvp_func(a: torch.Tensor, b: torch.Tensor):
    return a.dot(b).sin()


x = torch.arange(1.0, 5.0)
y = torch.arange(5.0, 9.0)
v_x = torch.full_like(x, 0.1)
v_y = torch.full_like(y, 0.2)
out = AF.jvp(jvp_func, (x, y), (v_x, v_y))
print('func(x,y):', out[0])
print('JVP output:', out[1])

2.1.5.3 为什么深度学习里更常见的是 VJP

这个问题不是在比较谁更高级,而是在看哪种模式更匹配问题规模。

  • 在深度学习训练中,\(n\) 往往是参数维度,可能达到百万或亿级,而 \(m\) 通常是标量或低维输出。
  • 我们真正需要的是 \(\nabla L \in \mathbb{R}^n\)

VJP 的成本大致和一次反向传播同量级,适合输入维度巨大而输出为标量或低维的情况。JVP 则更适合输入维度相对较小、但我们关心输出如何随输入方向变化的场景。因此一个常用判断是:输出标量或低维、输入维度很大时,反向模式(VJP)更合适;输入维度较小、输出维度很大时,正向模式(JVP)可能更占优势。

2.1.6 反向传播中的常见错误

x = torch.arange(1.0, 5.0, requires_grad=True)
y = torch.arange(5.0, 9.0, requires_grad=True)

1. 重复调用 backward()

在同一张计算图上重复调用 backward() 通常会报错。因为第一次反向传播结束后,PyTorch 会释放图里只为反传服务的中间值,以节省显存。第二次再沿同一张图回溯时,所需信息已经被清掉了。如果确实需要多次计算梯度,可以在第一次调用时设置 retain_graph=True

z = x.dot(y).sin()
z.backward()
try:
    z.backward()  # This will raise an error because gradients are already computed
except RuntimeError as err:
    print('RuntimeError:', err)
z = x.dot(y).sin()
z.backward(retain_graph=True)
z.backward()  # This works because we retained the graph

2. 尝试访问中间节点的梯度

默认只有叶子节点,也就是最初创建的变量,会保存梯度。中间节点的梯度不会自动存储,因为如果每个中间变量都保留梯度,显存开销会非常大,而训练通常真正需要的是参数梯度,不是所有中间量的梯度。因此访问中间节点的 .grad 往往会得到 None,并触发 UserWarning。如果确实需要查看某个中间节点的梯度,可以对它调用 q.retain_grad()

import warnings

q = x.dot(y)
z = q.sin()
z.backward()

with warnings.catch_warnings(record=True) as warns:
    print('q.grad:', q.grad)
    if len(warns) > 0:
        for warn in warns:
            print('UserWarning:', warn.message)
q = x.dot(y)
q.retain_grad()
z = q.sin()
z.backward()
print('q.grad after `retain_grad`:', q.grad)  # Now q.grad is available

3. 使用原地操作

PyTorch 中像 x.add_(1)x.relu_() 这种带下划线的操作表示原地修改张量:不创建新张量,而是直接改写 x 自己的内存。这样写看起来方便,但反向传播经常需要用到前向传播保存下来的中间值。如果这些值在前向之后被原地改掉,反向传播就可能丢失求梯度所需的信息。因此,在涉及反向传播的代码里,应尽量避免原地操作,或者确认它们不会覆盖反传需要的中间变量。

z = x.dot(y)
try:
    x.relu_()
except RuntimeError as err:
    print('RuntimeError:', err)
z = x.dot(y)
x = x.relu()
z.backward()

二次使用