import torch
import torch.nn as nn
import torch.nn.functional as F
print('PyTorch version:', torch.__version__)Chapter 2.3 PyTorch 中的梯度记录与控制
在 2.1 节中,我们已经弄清楚了梯度的来源:前向传播负责把计算过程记录下来,反向传播再沿着这些记录回溯,Autograd 借助计算图把梯度一步步传回去。
不过真正写 PyTorch 代码时,很快会碰到另一个更具体的问题:这次计算到底有没有必要被记录?
训练阶段当然需要记录,因为后面还要反向传播。但如果是在验证、推理、特征提取,或者只是临时跑一遍模型看看输出,继续记账就不太划算了。它会保存中间结果、构造计算图、占用显存,还可能把一段原本只想做数值计算的代码意外纳入梯度链条。
因此这一节换一个角度:先不继续追问“梯度怎么算”,而是关心“哪些计算会进入 Autograd 的记录,哪些不会”。PyTorch 为此提供了几个直接的控制开关:torch.no_grad()、torch.enable_grad(),以及面向推理优化的 torch.inference_mode()。这些模式不会改变前向计算的数值,却会决定是否构建计算图、是否允许反传,以及这段代码会消耗多少额外内存和运行开销。
这也体现了 PyTorch 一个很重要的分工:算子负责完成数值计算,Autograd 负责决定这次计算要不要留下可求导记录。下面就从最常见的 no_grad() 开始,看这些梯度模式分别做了什么。
2.3.1 torch.no_grad():暂停记账
默认情况下,只要参与运算的张量带有 requires_grad=True,PyTorch 就会为相关计算建立计算图。换句话说,只要计算发生在可求导环境里,Autograd 就会自动把这段过程记录下来。但很多时候,这份记录并不是我们想要的。
比如验证模型性能时,我们通常不会反向传播,也就不需要梯度。推理阶段也是类似,我们关心的是模型给出的结果,而不是这个结果之后还能不能继续反传。如果这时仍然让 Autograd 构建计算图,就会白白增加内存占用和运行开销。
因此,PyTorch 提供了 torch.no_grad() 这个上下文管理器(也可以作为装饰器使用),用来明确告诉 Autograd:这个代码块里的计算不需要记录。
先看一个直接的对比。默认模式下:
model = nn.Linear(6, 4)
x = torch.randn(10, 6)
y = torch.randn(10, 4)
y_pred = model(x)
print('`y_pred.requires_grad` before `no_grad()`:', y_pred.requires_grad)这里会输出 True,原因是模型参数默认需要梯度,因此前向结果也会被纳入计算图。
现在把同样的前向过程放进 no_grad():
with torch.no_grad():
y_pred = model(x)
print('`y_pred.requires_grad` inside `no_grad()`:', y_pred.requires_grad)这次得到的就是 False。
需要注意的是,no_grad() 并不会影响前向传播本身,模型照样计算输出;变化只在于这些新结果不会被 Autograd 跟踪。一个张量一旦没有被跟踪,后续依赖它得到的结果通常也不会进入计算图。如果此时对这类结果调用 backward(),PyTorch 会报错,因为它根本没有可回溯的计算图。
loss = F.mse_loss(y_pred, y)
try:
loss.backward()
except RuntimeError as err:
print('RuntimeError:', err)这里的 loss 虽然是在 no_grad() 外面算的,但它依赖的 y_pred 已经不带梯度记录,另一个输入 y 也没有请求梯度。因此 loss 本身也不会属于计算图,对它调用 backward() 就会触发错误。
有一个常见误解是:no_grad() 会把张量自己的 requires_grad 改成 False。实际上并不是这样。no_grad() 控制的是当前代码块里的计算是否被记录,并不改写已有张量的属性。一个在外部创建、并且 requires_grad=True 的张量,进入 no_grad() 后仍然保留这个属性,只是基于它产生的新计算不会被记录。
x = torch.randn(10, 6, requires_grad=True)
with torch.no_grad():
print('`x.requires_grad` inside `no_grad()`:', x.requires_grad)
y_pred = model(x)
print('`y_pred.requires_grad` inside `no_grad()`:', y_pred.requires_grad)所以,no_grad() 不是取消张量本身的“可导资格”,而是让当前上下文里产生的新计算不被登记。可以把 requires_grad 理解为一种能力声明:这个张量有资格进入梯度系统;而 no_grad() 更像行为开关:这一段计算暂时不要记录。两者并不冲突。
另外,如果在 no_grad() 中创建了一个新张量,之后又希望它参与自动微分,仍然可以通过 requires_grad_() 把它打开:
with torch.no_grad():
x = torch.randn(10, 6)
print('`x.requires_grad` inside `no_grad()`:', x.requires_grad)
x.requires_grad_()
print('`x.requires_grad` after `requires_grad_()`:', x.requires_grad)换句话说,no_grad() 做的是临时暂停记录,而不是永久切断张量进入梯度系统的可能性。为了之后能够恢复梯度模式,PyTorch 内部仍会维护一些状态和检查机制,这也意味着它并不是完全零开销。这个点会和后面的 inference_mode() 形成鲜明对比:inference_mode() 不只是“不记录”,还会关闭更多 Autograd 相关机制,使得其中创建的张量不能再靠 requires_grad_() 回到梯度跟踪里。
从更底层的角度看,PyTorch 中“算出数值”和“留下求导记录”是两件事。no_grad() 只影响后者,不影响前者。所以它才会频繁出现在模型验证、推理部署和参数更新这类场景中。
接下来就会自然引出另一个问题:既然梯度记录可以关掉,那能不能只在局部重新打开?如果推理流程中的某一小段又临时需要梯度,该怎么办?这就要用到 torch.enable_grad()。
2.3.2 torch.enable_grad():重新开始记账
上一节看到,no_grad() 可以让 Autograd 暂时停止记录。那如果外层已经处在 no_grad() 里,我们能不能只给其中一小段计算重新打开梯度记录?
答案是可以,这正是 enable_grad() 的用途。
反过来也可以:外层保持梯度开启,内层再用 no_grad() 临时关闭。这些上下文都可以嵌套。只是默认模式本来就开启梯度记录,所以在最外层额外写一个 enable_grad() 通常没有意义。
还是先来看一个简单的例子:
x = torch.randn(10, 6, requires_grad=True)
with torch.no_grad():
y = x * 3 # Does not record computation graph
print('`y.requires_grad` in `no_grad()`:', y.requires_grad)
with torch.enable_grad():
z = x * 4 # Enables gradient tracking
print('`z.requires_grad` in `enable_grad()`:', z.requires_grad)
# Only z will have gradients tracked
z.backward(gradient=torch.ones_like(z))这个例子的关键在于:外层 no_grad() 先关闭记录,内层 enable_grad() 又把记录能力临时恢复。离开内层之后,外层的 no_grad() 依然生效,后面的计算又会回到不被跟踪的状态。这说明梯度模式是按照栈来管理的:进入上下文时压入一个模式,退出时恢复到之前的模式。
那么,这有什么意义呢?
很多工程代码会复用同一条执行路径。比如推理阶段的大多数前向计算不需要梯度,但某个中间步骤要做敏感性分析;又或者调试时只想临时算一次梯度。如果没有 enable_grad(),我们可能要拆开流程,或者在外层反复切换状态。有了它,就可以只在需要的位置打开记录,而不扰动整体推理逻辑。
还有一个更通用的接口叫 torch.set_grad_enabled()。它接收一个布尔值,直接决定当前是否启用梯度记录。no_grad() 和 enable_grad() 可以看作它的两个常用特例。
x = torch.randn(10, 6)
is_training = False
with torch.set_grad_enabled(is_training):
y_pred = model(x)当 is_training=True 时,它相当于 enable_grad();当 is_training=False 时,它相当于 no_grad()。这种写法很适合把训练和评估流程放在同一段逻辑中管理。
到这里,我们已经见过两个常用的梯度控制上下文:no_grad() 用来关闭记录,enable_grad() 用来恢复记录,并且它们可以嵌套形成灵活的栈式状态。接下来再看一个更偏向推理优化的上下文:torch.inference_mode()。它比 no_grad() 更进一步,换来的也是更强的限制。
2.3.3 torch.inference_mode():干脆以后都别记账了
前两节已经给了我们一套很灵活的控制方式:
no_grad()可以关闭梯度记录;enable_grad()可以局部恢复梯度记录;set_grad_enabled()是一个更通用的接口,可以直接设置当前的梯度模式;- 梯度模式是可嵌套、可恢复的。
表面上看,这已经够用了。那 PyTorch 为什么还要单独提供 inference_mode()?
原因在于一个更强的前提:如果我们不仅知道当前不需要梯度,而且确定这段计算之后也绝不会参与反向传播,那么框架是否可以进一步省掉所有梯度相关的维护成本?
这就是 inference_mode() 的设计动机1。
在 no_grad() 里,PyTorch 虽然不建计算图,但仍会维护版本计数器(version counter)、视图追踪(view tracking),以及一些保证梯度正确性的内部检查。训练时这些机制很重要,因为它们能防止原地操作破坏图结构,也能避免共享内存视图带来的梯度错误。但在纯推理阶段,这些维护就成了额外负担。既然这段计算的结果永远不会用于求导,框架就可以跳过更多与 Autograd 相关的检查和追踪,从而得到更激进的内存与性能优化。因此,inference_mode() 通常会比 no_grad() 更快,也更省显存。
但是,它是不可逆的。
前面已经看到,在 no_grad() 中创建的张量,之后仍然可以重新打开梯度:
with torch.no_grad():
x = torch.randn(10, 6)
x.requires_grad_() # we can still enable gradients for x
print('`x.requires_grad` after `requires_grad_`:', x.requires_grad)但是如果张量是在 inference_mode() 中创建的,再尝试设置 requires_grad=True 就会直接报错:
with torch.inference_mode():
x = torch.randn(10, 6)
try:
x.requires_grad_()
except RuntimeError as err:
print('RuntimeError:', err)原因是 inference_mode() 并不是简单暂停记录,而是会创建一种特殊的推理张量(inference tensor)。这类张量被标记为不会进入自动微分系统。即使之后重新开启梯度模式,它们也不会被加入计算图。因此,no_grad() 是临时关闭,inference_mode() 更接近永久关闭。只有当我们确信某段代码只服务于推理时,才适合使用它。
2.3.4 不同梯度模式下的行为对比
到这里,三种梯度语义已经比较清楚了:默认模式、no_grad() 模式和 inference_mode() 模式。它们表达的是不同强度的承诺,也对应不同的灵活性和性能取舍。
默认模式下,Autograd 必须假设当前任何计算都有可能被拿去反向传播。所以它会:
- 构建完整计算图
- 保存反向传播所需的中间结果
- 维护版本计数器和视图一致性检查
这是最灵活的模式,但代价也最高,通常用于训练阶段的前向传播。
进入 no_grad() 时,我们表达的是一个阶段性声明:这段计算现在不参与反向传播。
在这个前提下,Autograd 可以做一些优化:
- 不再构建计算图
- 不再保存中间结果
- 但仍然保留 Autograd 的内部一致性机制
- 退出该上下文后可以恢复正常梯度模式
这种模式属于临时关闭。它保留了之后恢复梯度的灵活性,同时能明显减少开销,常用于验证和模型评估。
而 inference_mode() 给出的承诺更强:这段计算永远不会参与梯度计算。基于这个前提,Autograd 可以进一步优化:
- 不构建计算图
- 跳过与梯度相关的版本检查与视图追踪
- 在该模式下创建的张量无法再重新加入自动微分系统
这是不可逆意味更强的关闭方式。它的优化最充分,限制也最多,适合纯推理、模型评估和数据处理等不会再求导的场景。
脚注
inference_mode()是在 PyTorch 1.9 版本引入的,专门针对推理阶段的性能优化。关于它的具体实现,可以参考 RFC-0011-InferenceMode。↩︎