import torch
x = torch.randn(4, 3)
weight = torch.randn(2, 3, requires_grad=True)
bias = torch.randn(2, requires_grad=True)
y = torch.addmm(bias, x, weight.T)Chapter 2.5 PyTorch 中的 nn.Module:组织模型、参数与状态
前面几节已经说明了 PyTorch 如何记录计算图、如何反向传播梯度,以及训练数据如何通过 Dataset 和 DataLoader 组织成 mini-batch。
但到目前为止,还有一个问题没有处理:模型本身应该怎么组织?
最简单的线性模型可以直接写成张量运算:
这段代码可以工作。weight 和 bias 会参与计算图,loss.backward() 之后也可以得到梯度。
但如果模型稍微复杂一点,我们马上会遇到很多工程问题:
- 哪些张量是需要优化的参数?
- 怎么把许多网络层组织成一个模型?
- 怎么把整个模型搬到 GPU?
- 怎么保存和加载模型状态?
- 怎么区分训练模式和评估模式?
如果所有东西都只是散落在外面的张量,这些问题都要手动管理。nn.Module 的作用,就是把这些职责收拢到一个对象里。它不仅是一个可以前向计算的对象,也是 PyTorch 用来组织计算、参数、缓冲区和子模块的基本容器。
本节从最小的线性层开始,逐步说明 nn.Module 背后的几个核心概念。
from pprint import pprint
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
print('PyTorch version:', torch.__version__)2.5.1 为什么需要 nn.Module
我们先回到刚才的线性变换:
\[ y = xW^\top + b \]
如果直接用张量写,代码大概是这样:
in_features = 3
out_features = 2
weight = torch.randn(out_features, in_features, requires_grad=True)
bias = torch.randn(out_features, requires_grad=True)
x = torch.randn(4, in_features)
y = torch.addmm(bias, x, weight.T)
print(y.shape)这段代码里,weight 和 bias 是模型参数。但是 PyTorch 并不知道它们属于同一个模型,因为它们只是两个普通变量。这会带来一个问题:如果以后参数变多了,我们需要手动收集它们:
parameters = [weight, bias]
pprint(parameters)对于一个只有两个参数的线性模型,这还可以接受。但如果模型有几十层、上百个参数张量,手动维护列表很容易出错。
我们真正想要的是:
- 把参数放进一个容器里;
- 让这个容器知道哪些张量是参数;
- 让优化器可以自动拿到这些参数;
- 让这个容器可以被保存、加载、移动到 GPU。
这正是 nn.Module 要解决的问题。
我们先手写一个最小版本的线性层:
class SimpleLinear(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features))
self.bias = nn.Parameter(torch.randn(out_features))
def forward(self, x: Tensor) -> Tensor:
return torch.addmm(self.bias, x, self.weight.T)现在,weight 和 bias 不再是散落在外面的变量,而是属于 SimpleLinear 这个模块:
linear = SimpleLinear(3, 2)
for name, param in linear.named_parameters():
print(f'{name}: {param.shape}')这里最重要的变化是:只要把 nn.Parameter 赋值给 nn.Module 的属性,PyTorch 就会自动把它注册成该模块的参数。获取模型参数时,只需要调用:
params = list(linear.named_parameters())
pprint(params)有了 parameters(),就不需要手动维护参数列表。Module 会把模型内部所有参数递归地找出来。这是 nn.Module 的第一层含义:它是参数的组织方式。
2.5.2 forward():Module 负责组织计算
一个 nn.Module 通常会实现 forward() 方法,用来描述输入如何变成输出。例如前面的 SimpleLinear:
class SimpleLinear(nn.Module):
...
def forward(self, x: Tensor) -> Tensor:
return torch.addmm(self.bias, x, self.weight.T)使用时,我们通常写:
x = torch.randn(4, 3)
y = linear(x)
print(y.shape)注意,我们调用的是:
linear(x)而不是:
linear.forward(x)这是因为 nn.Module 的 __call__ 方法会在内部调用 forward(),同时还会处理 hooks、参数检查等额外逻辑。平时写模型时,我们应该调用 module(input),而不是直接调用 module.forward(input)。
从这个角度看,Module 不只是参数容器,它也描述了一段可复用的计算:
Module = parameters + forward computation
不过,这里还有一个容易混淆的问题:既然 nn.Module 里有计算,那 nn.functional 又是什么?
2.5.3 nn.Module 和 nn.functional 的关系
我们在 PyTorch 里经常会看到两种写法。
第一种是模块写法:
linear = nn.Linear(3, 2)
y1 = linear(x)
print(y1.shape)第二种是函数式写法:
weight = torch.randn(2, 3)
bias = torch.randn(2)
y2 = F.linear(x, weight, bias)
print(y2.shape)它们都能完成线性变换,但语义不一样。
nn.Linear 是一个 nn.Module。它内部持有自己的 weight 和 bias,并且这些参数会被自动注册:
for name, param in linear.named_parameters():
print(f'{name}: {param}')而 F.linear 是一个函数。它不会保存参数,也不会注册任何状态。你必须显式把 weight 和 bias 传进去:
F.linear(input, weight, bias)所以,二者的关系可以简单理解为:
nn.Module = 带状态的层或模型
nn.functional = 不保存状态的函数
很多 nn.Module 的 forward() 内部,本质上就是调用对应的 nn.functional 函数。例如,nn.Linear 内部大致可以理解成:
def forward(self, input: Tensor) -> Tensor:
return F.linear(input, self.weight, self.bias)这也是为什么我们在自定义模块时,经常会混合使用两者:
class SimpleMLP(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x这里的 self.fc1 和 self.fc2 有可学习参数,所以适合写成 nn.Module;而 relu 没有需要保存的参数,直接用 F.relu 就很自然。
也可以写成:
self.relu = nn.ReLU()这同样正确。区别主要在于:如果一个操作没有状态,用 functional 写法更轻;如果希望它出现在模块结构里,或者它本身有训练/评估行为,用 nn.Module 更清晰。
例如,Dropout 虽然没有可学习参数,但它在训练和评估时行为不同,所以通常写成:
self.dropout = nn.Dropout(p=0.5)这样它就能跟随整个模型一起切换 train() 和 eval() 模式。
2.5.4 Parameter:需要被优化器更新的张量
接下来具体看 Parameter。
普通张量即使设置了 requires_grad=True,赋值给 Module 后也不会自动成为模型参数:
class BadLinear(nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.randn(2, 3, requires_grad=True)
bad = BadLinear()
for name, param in bad.named_parameters():
print(f'{name}: {param}')这里 self.weight 确实是一个需要梯度的张量,但它没有被注册成 Module 的参数。因此 parameters() 不会返回它,优化器也不会自动更新它。
如果希望一个张量成为模型参数,需要使用 nn.Parameter:
class GoodLinear(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(2, 3))
good = GoodLinear()
for name, param in good.named_parameters():
print(f'{name}: {param}')nn.Parameter 可以理解成一种特殊的 Tensor。它的特殊之处不在于数学计算,而在于:当它被赋值给 nn.Module 的属性时,会被自动注册为模型参数。
所以,Parameter 的含义是:
这是模型的一部分,并且通常需要被优化器更新。
例如线性层的权重、卷积核、词嵌入表,都是典型的 parameter。
除了直接赋值,我们也可以用 register_parameter() 显式注册参数:
class ExplicitLinear(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
weight = nn.Parameter(torch.randn(out_features, in_features))
bias = nn.Parameter(torch.randn(out_features))
self.register_parameter('weight', weight)
self.register_parameter('bias', bias)
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)大多数时候,直接写:
self.weight = nn.Parameter(...)就够了。register_parameter() 更常见于参数名需要动态生成,或者需要显式控制某个名字是否注册参数的情况。例如,有些模块可以选择是否使用 bias:
class OptionalBiasLinear(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features))
if bias:
self.bias = nn.Parameter(torch.randn(out_features))
else:
self.register_parameter('bias', None)
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)这样,即使没有 bias,这个模块的结构也很明确:它有一个名为 bias 的位置,只是当前没有参数。
2.5.5 Buffer:属于模型状态,但不是可学习参数
并不是模型里的所有张量都应该被优化器更新。
例如,BatchNorm 里有 running mean 和 running variance。它们会随着训练数据更新,用于评估阶段的归一化,但它们不是通过梯度下降学出来的参数。再比如,位置编码中预先计算好的 sinusoidal table,也可能希望跟随模型一起保存、加载和移动设备,但不希望优化器更新它。这类张量就适合放在 buffer 里。
Buffer 可以理解成:
属于模型状态、需要跟着模型走,但不是可学习参数的张量。
看一个简单例子。假设有一个标准化模块,它用固定的均值和标准差对输入做归一化:
class Normalize(nn.Module):
def __init__(self, mean: Tensor, std: Tensor):
super().__init__()
self.register_buffer('mean', mean)
self.register_buffer('std', std)
def forward(self, x: Tensor) -> Tensor:
return (x - self.mean) / self.std这里的 mean 和 std 不应该出现在 parameters() 里:
mean = torch.tensor([0.5, 0.5, 0.5])
std = torch.tensor([0.2, 0.2, 0.2])
normalize = Normalize(mean, std)
print('Parameters:')
for name, param in normalize.named_parameters():
print(f'{name}: {param}')
print('Buffers:')
for name, buffer in normalize.named_buffers():
print(f'{name}: {buffer}')但它们会出现在 state_dict() 里:
pprint(normalize.state_dict())这说明它们会随着模型状态一起保存。
Buffer 还有一个重要特点:当我们调用 model.to(device) 时,buffer 会和 parameter 一起被移动到对应设备。比如,如果我们只是把它们保存成普通属性:
self.mean = mean
self.std = std那么它们既不会出现在 state_dict() 里,也不会作为模型状态被统一管理。用 register_buffer() 注册后,它们会被视为模型状态的一部分,跟随模型一起移动到对应设备。
print('Buffers before moving to device:', normalize.mean.device)
device = torch.accelerator.current_accelerator(check_available=True)
normalize.to(device)
print('Buffers after moving to device:', normalize.mean.device)判断一个张量要不要注册成 buffer,可以看三个问题:
- 它是不是模型的一部分?
- 它要不要保存和加载?
- 它要不要跟随
model.to(device)移动?
如果答案都是肯定的,但它又不是优化器要更新的参数,那么它很可能应该是 buffer。
register_buffer() 还有一个参数叫 persistent。默认情况下,buffer 是 persistent 的,也就是会保存到 state_dict() 里:
self.register_buffer('mean', mean, persistent=True)如果设置成 persistent=False,这个 buffer 仍然会跟随设备移动,也能通过 buffers() 找到,但不会保存到 state_dict() 里。这适合一些可以重新生成的缓存,例如某些中间 mask、临时查表结果等。
class MaskCache(nn.Module):
def __init__(self, max_len: int):
super().__init__()
mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
self.register_buffer('causal_mask', mask, persistent=False)
cache = MaskCache(max_len=4)
print('Buffers:')
for name, buffer in cache.named_buffers():
print(f'{name}: {buffer}')
print('State dict keys:', cache.state_dict())总结一下,Parameter 和 Buffer 的区别是:
| 类型 | 是否是模型状态 | 是否被优化器更新 | 是否进入 state_dict | 是否跟随 model.to(device) |
|---|---|---|---|---|
| Parameter | 是 | 通常是 | 是 | 是 |
| Buffer | 是 | 否 | 默认是 | 是 |
2.5.6 子模块:Module 可以嵌套 Module
我们知道,神经网络通常不是一个单独的层,而是很多层组合起来的结构。在 nn.Module 里,一个模块可以包含另一个模块。
比如前面的 MLP:
class SimpleMLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3, 8)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(8, 2)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
model = SimpleMLP()
print(model)这里 fc1、relu 和 fc2 都是 model 的子模块。只要把一个 nn.Module 赋值给另一个 nn.Module 的属性,它也会被自动注册。
因此,model.parameters() 会递归地找到所有子模块里的参数:
for name, param in model.named_parameters():
print(f'{name}: {param.size()}')参数名里的点号表示模块层级。例如:
fc1.weight
fc1.bias
fc2.weight
fc2.bias
说明 weight 和 bias 属于子模块 fc1 或 fc2。
PyTorch 也提供了几组常用方法来遍历模块结构。
children() 只返回当前模块的直接子模块:
for child in model.children():
print(child)named_children() 会同时返回名字:
for name, child in model.named_children():
print(f'{name}: {child}')modules() 会递归返回当前模块和所有子模块:
for module in model.modules():
print(type(module).__name__)named_modules() 会递归返回模块名和模块对象:
for name, module in model.named_modules():
print(f'{repr(name)} -> {type(module).__name__}')这里第一个名字是空字符串,表示模型本身。
这些方法在调试模型结构、冻结部分层、替换某些子模块时都很常用。
例如,我们可以找到所有线性层:
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
print('Linear layer:', name)但是,如果我们把子模块放进普通 Python list 里,PyTorch 不会自动注册 list 里面的模块。
class BadStack(nn.Module):
def __init__(self):
super().__init__()
self.layers = [nn.Linear(3, 3), nn.Linear(3, 3)]
bad_stack = BadStack()
for name, param in bad_stack.named_parameters():
print(f'{name}: {param.size()}')如果想保存一组子模块,应该使用 nn.ModuleList:
class GoodStack(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList(
[
nn.Linear(3, 3),
nn.Linear(3, 3),
]
)
def forward(self, x: Tensor) -> Tensor:
for layer in self.layers:
x = layer(x)
return x
good_stack = GoodStack()
for name, param in good_stack.named_parameters():
print(f'{name}: {param.size()}')对于 dict 也是:
class BadDict(nn.Module):
def __init__(self):
super().__init__()
self.layers = {'layer1': nn.Linear(3, 3), 'layer2': nn.Linear(3, 3)}
def forward(self, x: Tensor) -> Tensor:
for layer in self.layers.keys():
x = self.layers[layer](x)
return x
bad_dict = BadDict()
for name, param in bad_dict.named_parameters():
print(f'{name}: {param.size()}')如果想保存一组命名的子模块,应该使用 nn.ModuleDict:
class GoodDict(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleDict(
{
'layer1': nn.Linear(3, 3),
'layer2': nn.Linear(3, 3),
}
)
def forward(self, x: Tensor) -> Tensor:
for layer in self.layers.keys():
x = self.layers[layer](x)
return x
good_dict = GoodDict()
for name, param in good_dict.named_parameters():
print(f'{name}: {param.size()}')如果模块之间是严格顺序执行,也可以使用 nn.Sequential:
sequential_model = nn.Sequential(
nn.Linear(3, 8),
nn.ReLU(),
nn.Linear(8, 2),
)
print(sequential_model)ModuleList 和 ModuleDict 更像是注册过的模块列表和字典,forward 怎么写由我们决定;Sequential 则直接定义了一条顺序计算链。
2.5.7 state_dict:模型状态的字典
每次我们训练完模型,我们要把训练好的模型保存下来。保存模型时,我们真正想保存的通常不是整个 Python 对象,而是模型里的状态。
在 PyTorch 中,这个状态由 state_dict() 表示:
state = model.state_dict()
for key, value in state.items():
print(f'{key}: {value.size()}')state_dict 是一个从名字到张量的字典。它包含所有的 parameter 和 persistent buffer。比如我们前面的 SimpleMLP,只有线性层参数,没有 buffer,所以它的 state_dict 主要是 fc1 和 fc2 的 weight 和 bias。
如果模块里有 buffer,它也会保存进去:
state = normalize.state_dict()
for key, value in state.items():
print(f'{key}: {value.size()}')这也是为什么前面说,buffer 虽然不是参数,但仍然是模型状态。
通常保存模型参数可以这样写:
torch.save(model.state_dict(), 'model.pt')加载时,先重新创建同样结构的模型,再调用 load_state_dict():
model = SimpleMLP()
state_dict = torch.load('model.pt')
flag = model.load_state_dict(state_dict)
print(flag)这里有一个很重要的思想:
模型结构由 Python 代码定义,模型状态由 state_dict 保存。
state_dict 只保存权重、bias、buffer 等张量,不保存 forward() 的 Python 逻辑。因此加载参数前,我们需要先创建一个结构匹配的模型对象,并将其实例化。如果结构不匹配,例如网络层的名字不同、参数形状不同,load_state_dict() 就会报错或者返回 IncompatibleKeys 对象。这个对象通常会告诉我们 missing keys 或 unexpected keys,也就是当前模型需要但文件里没有的参数,或者文件里有但当前模型用不到的参数。
在实际模型训练中,state_dict 非常重要,因为它让模型保存变得更加清晰:保存的是状态,而不是整个运行环境。
2.5.8 train() 和 eval():切换模块的行为
前面 2.2 节里,我们讨论过 torch.no_grad() 和 torch.inference_mode()。它们控制的是:是否记录计算图。但是,model.train() 和 model.eval() 控制的是另一件事:模块处于训练模式还是评估模式。这两个概念很容易混在一起,但它们不是一回事。
我们先看一个例子:
dropout = nn.Dropout(p=0.5)
x = torch.ones(5)
dropout.train()
print('Train mode:', dropout(x))
dropout.eval()
print('Eval mode:', dropout(x))在训练模式下,Dropout 会随机丢弃一部分元素;在评估模式下,Dropout 不再随机丢弃,而是直接返回输入。这说明 train() 和 eval() 会影响某些模块的 forward 行为。
最常见受影响的模块是:
- Dropout:训练时随机丢弃,评估时关闭随机丢弃;
- BatchNorm:训练时使用当前 batch 统计量并更新 running statistics,评估时使用保存的 running statistics。
我们可以看一下 training 属性:
model = SimpleMLP()
print(f'Initial training mode: {model.training}')
model.eval()
print(f'After calling eval(): {model.training}')
model.train()
print(f'After calling train(): {model.training}')model.train() 会把模型及其所有子模块都设置为训练模式;model.eval() 会递归地把它们设置为评估模式。它本质上等价于 model.train(False)。
但是,eval() 不会关闭自动微分。也就是说,下面这段代码虽然处在 eval 模式,但如果没有 no_grad(),PyTorch 仍然会记录计算图:
model.eval()
x = torch.randn(4, 3, requires_grad=True)
y = model(x)
print(f'y.requires_grad: {y.requires_grad}')因此,验证或推理时通常需要同时写:
model.eval()
with torch.no_grad():
y_pred = model(x)或者在纯推理场景中写:
model.eval()
with torch.inference_mode():
y_pred = model(x)所以,一个简单的区分方式是:
train()/eval()控制模块的行为,例如 Dropout 和 BatchNorm 在训练和评估时的不同计算逻辑;no_grad()/inference_mode()控制 Autograd 是否记录计算图,例如在评估阶段我们通常不需要梯度。
所以,eval() 只是告诉模块,现在请使用评估阶段的行为。至于是否记录梯度,还要由 no_grad() 或 inference_mode() 来控制。这一点很重要。
请务必确保在训练阶段调用 model.train(),在评估阶段调用 model.eval()。即使你的网络中没有使用 Dropout 或 BatchNorm,养成这个习惯也能避免未来添加这些层时忘记切换模式导致错误。
2.5.9 Lazy Module:推迟确定输入维度
前面我们创建 nn.Linear 时,都需要显式写出 in_features:
nn.Linear(in_features=3, out_features=2)这很合理,因为线性层的权重形状是:
\[ W \in \mathbb{R}^{\text{out\_features} \times \text{in\_features}} \]
如果不知道输入最后一维是多少,PyTorch 就没法提前创建这个权重矩阵。
但是在实际写模型时,有些输入维度并不总是方便手算。尤其是卷积网络里,经过多层 Conv2d、Pooling 之后,特征图到底会变成多大,有时候要根据输入尺寸一步步推出来。
例如,一个 CNN 最后通常会把特征图展平成向量,再接一个全连接层:
x = self.features(x)
x = x.flatten(start_dim=1)
x = self.classifier(x)这里 classifier 的 in_features 取决于 features 输出的形状。如果每次改卷积结构或者输入图像大小,都要重新手算这个数字,就会有点麻烦。
为了解决这个问题,PyTorch 提供了一类 Lazy Module。它们的核心思想是:
先创建模块,但暂时不创建完整参数;等第一次看到真实输入时,再根据输入形状初始化参数。
最常用的是 nn.LazyLinear。它不需要我们提前指定 in_features,只需要指定 out_features:
lazy_linear = nn.LazyLinear(out_features=2)
print(lazy_linear)此时,这个模块还没有真正知道输入维度。它的参数是未初始化状态:
for name, param in lazy_linear.named_parameters():
print(f'{name}: {type(param).__name__}')这种参数不是普通的 Parameter,而是 UninitializedParameter。它表示这个参数属于模型,但目前还不知道完整形状。因此,如果我们尝试访问它的形状,就会报错:
try:
print(lazy_linear.weight.shape)
except RuntimeError as err:
print('RuntimeError:', err)当我们第一次把输入传进去时,LazyLinear 会根据输入最后一维推断 in_features,并把参数真正初始化出来:
x = torch.randn(4, 3)
y = lazy_linear(x)
print(lazy_linear)
print('Output shape:', y.shape)现在再看它的参数形状:
for name, param in lazy_linear.named_parameters():
print(f'{name}: {param.size()}')可以看到,weight 的形状已经变成了:
\[ (\text{out\_features}, \text{in\_features}) = (2, 3) \]
也就是说,LazyLinear 第一次看到形状为 (4, 3) 的输入后,自动推断出 in_features = 3,并完成了参数初始化。这就是 lazy module 的 lazy 所在:不是模块不计算,而是参数初始化被推迟到了第一次 forward。
Lazy module 在卷积网络里尤其方便。比如我们可以先写卷积特征提取部分,然后用 nn.LazyLinear 自动适配展平后的维度:
class LazyCNN(nn.Module):
def __init__(self, num_classes: int):
super().__init__()
self.features = nn.Sequential(
nn.LazyConv2d(8, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.LazyConv2d(16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.classifier = nn.LazyLinear(num_classes)
def forward(self, x: Tensor) -> Tensor:
x = self.features(x)
x = x.flatten(start_dim=1)
x = self.classifier(x)
return x创建模型时,我们不需要知道 classifier 的输入维度:
lazy_cnn = LazyCNN(num_classes=10)
print(lazy_cnn)第一次 forward 之后,classifier 才会被具体初始化:
x = torch.randn(4, 1, 28, 28)
y = lazy_cnn(x)
print(lazy_cnn.classifier)这个例子里,输入图像大小是 \(28 \times 28\)。经过两次 MaxPool2d(kernel_size=2) 之后,空间尺寸会从 \(28 \times 28\) 变成 \(7 \times 7\),通道数变成 16。所以展平后的维度是:
\[ 16 \times 7 \times 7 = 784 \]
nn.LazyLinear 正是根据第一次 forward 时的真实输入,自动得到了这个 784。
除了 LazyLinear,PyTorch 里也有 lazy 版本的卷积层,例如:
nn.LazyConv1d(out_channels, kernel_size)
nn.LazyConv2d(out_channels, kernel_size)
nn.LazyConv3d(out_channels, kernel_size)普通卷积层需要指定 in_channels:
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)而 LazyConv2d 可以把 in_channels 推迟到第一次 forward 时再确定:
lazy_conv = nn.LazyConv2d(out_channels=16, kernel_size=3, padding=1)
print(lazy_conv)
x = torch.randn(4, 3, 32, 32)
y = lazy_conv(x)
print(lazy_conv)
print('Output shape:', y.shape)第一次看到输入 (4, 3, 32, 32) 后,LazyConv2d 就知道输入通道数是 3,因此它会把权重初始化成对应形状:
for name, param in lazy_conv.named_parameters():
print(f'{name}: {param.size()}')Lazy module 很方便,但使用时也要注意几点。
第一,在第一次 forward 之前,lazy module 的参数还没有真实形状。因此,有些依赖参数形状的操作不能太早做。例如,你不能在参数初始化前根据 weight.shape 手动写复杂逻辑。
第二,保存模型前,最好先用一个真实 batch 或 dummy batch 跑一次 forward,让所有 lazy 参数都完成初始化。否则 state_dict() 里会包含尚未初始化的参数,后续加载和使用都会更麻烦。
model = LazyCNN(num_classes=10)
x = torch.randn(1, 1, 28, 28)
y = model(x)
for key, value in model.state_dict().items():
print(f'{key}: {value.size()}')第三,lazy module 主要是为了减少手动计算输入维度的负担,而不是为了改变模型结构。第一次 forward 之后,它就会变成一个已经确定形状的普通模块。后续输入的相关维度必须和第一次推断出来的维度匹配。
例如,前面的 LazyLinear 第一次看到的输入最后一维是 3,所以它之后就只能接收最后一维为 3 的输入:
x1 = torch.randn(4, 3)
y1 = lazy_linear(x1)
try:
x2 = torch.randn(4, 5)
y2 = lazy_linear(x2)
except RuntimeError as err:
print('RuntimeError:', err)所以,lazy module 最适合用在这种场景:
- 你知道输出维度,比如分类类别数、隐藏层维度、卷积输出通道数;
- 但输入维度不想手算,或者输入维度由前面的模块自然决定;
- 你愿意在真正训练、保存之前,先让模型跑一次 forward 完成初始化。
总结一下,lazy module 不是新的计算方式,而是一种更方便的模块初始化方式。它把一部分形状信息从 __init__() 推迟到第一次 forward(),让模型代码少写一些硬编码的输入维度。
2.5.10 把这些概念放到一个模型里
现在我们把前面的概念放到一个稍微完整一点的模型里。
这个模型包含:
- 两个线性层,作为可学习的子模块;
- 一个 Dropout,用来展示训练/评估行为;
- 一个输入归一化的 mean 和 std,作为 buffer;
- 一个不可持久化的缓存 mask,作为 non-persistent buffer。
class DemoNet(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
self.register_buffer('mean', torch.zeros(input_dim))
self.register_buffer('std', torch.ones(input_dim))
self.register_buffer(
'cache_mask',
torch.ones(input_dim, dtype=torch.bool),
persistent=False,
)
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.dropout = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: Tensor) -> Tensor:
x = (x - self.mean) / self.std
x = self.fc1(x)
x = F.relu(x)
x = self.dropout(x)
x = self.fc2(x)
return x创建模型:
demo = DemoNet(input_dim=3, hidden_dim=8, output_dim=2)查看参数:
print('Parameters:')
for name, param in demo.named_parameters():
print(f'{name}: {param.size()}')查看 buffer:
print('Buffers:')
for name, buffer in demo.named_buffers():
print(f'{name}: {buffer.size()}')查看 state dict:
print('State dict:')
for key, value in demo.state_dict().items():
print(f'{key}: {value.size()}')注意,cache_mask 是 buffer,但因为 persistent=False,所以它不会出现在 state_dict() 中。
再查看子模块:
print('Submodules:')
for name, module in demo.named_modules():
print(repr(name), '->', type(module).__name__)这几个输出合起来,就展示了 nn.Module 的核心管理能力:
parameters() -> 找到可学习参数
buffers() -> 找到非参数状态
modules() -> 找到子模块结构
state_dict() -> 导出可保存的模型状态
train()/eval() -> 切换模块行为
到这里,我们就能更完整地理解 nn.Module:它不是某个单独 API,而是 PyTorch 模型系统的中心。只要一个对象继承了 nn.Module,它就进入了 PyTorch 的模型管理体系。
2.5.11 本章小结
这一节我们从手写线性变换出发,理解了为什么需要 nn.Module。
nn.Module 不只是一个带 forward() 的对象,它还负责管理模型中的参数、缓冲区和子模块。nn.Parameter 表示需要被优化器更新的模型参数;Buffer 则表示属于模型状态、需要保存或移动设备,但不应该被优化器更新的张量。
nn.Module 和 nn.functional 的关系可以理解为:前者是带状态的层或模型,后者是不保存状态的函数。很多模块的 forward() 内部都会调用对应的 functional 操作。
我们还看到,Module 可以嵌套 Module,并且 PyTorch 会递归地管理这些子模块中的参数和 buffer。通过 parameters()、buffers()、children() 和 modules(),我们可以查看模型内部结构;通过 state_dict(),可以保存和加载模型状态。
最后,train() 和 eval() 控制的是模块行为,而不是 Autograd 是否记录计算图。验证和推理时,通常既要调用 model.eval(),也要配合 torch.no_grad() 或 torch.inference_mode()。
所以,nn.Module 的核心作用是把散落的张量计算组织成一个真正的模型:它知道哪些东西要学习,哪些东西只是状态,哪些层属于自己,以及这些状态应该如何保存、加载和切换行为。