Chapter 2.5 PyTorch 中的 nn.Module:组织模型、参数与状态

作者

Brench

发布于

2026-06-18

修改于

2026-06-18

前面几节已经说明了 PyTorch 如何记录计算图、如何反向传播梯度,以及训练数据如何通过 Dataset 和 DataLoader 组织成 mini-batch。

但到目前为止,还有一个问题没有处理:模型本身应该怎么组织?

最简单的线性模型可以直接写成张量运算:

import torch

x = torch.randn(4, 3)
weight = torch.randn(2, 3, requires_grad=True)
bias = torch.randn(2, requires_grad=True)

y = torch.addmm(bias, x, weight.T)

这段代码可以工作。weightbias 会参与计算图,loss.backward() 之后也可以得到梯度。

但如果模型稍微复杂一点,我们马上会遇到很多工程问题:

如果所有东西都只是散落在外面的张量,这些问题都要手动管理。nn.Module 的作用,就是把这些职责收拢到一个对象里。它不仅是一个可以前向计算的对象,也是 PyTorch 用来组织计算、参数、缓冲区和子模块的基本容器。

本节从最小的线性层开始,逐步说明 nn.Module 背后的几个核心概念。

from pprint import pprint

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor

print('PyTorch version:', torch.__version__)

2.5.1 为什么需要 nn.Module

我们先回到刚才的线性变换:

\[ y = xW^\top + b \]

如果直接用张量写,代码大概是这样:

in_features = 3
out_features = 2

weight = torch.randn(out_features, in_features, requires_grad=True)
bias = torch.randn(out_features, requires_grad=True)

x = torch.randn(4, in_features)
y = torch.addmm(bias, x, weight.T)

print(y.shape)

这段代码里,weightbias 是模型参数。但是 PyTorch 并不知道它们属于同一个模型,因为它们只是两个普通变量。这会带来一个问题:如果以后参数变多了,我们需要手动收集它们:

parameters = [weight, bias]
pprint(parameters)

对于一个只有两个参数的线性模型,这还可以接受。但如果模型有几十层、上百个参数张量,手动维护列表很容易出错。

我们真正想要的是:

  • 把参数放进一个容器里;
  • 让这个容器知道哪些张量是参数;
  • 让优化器可以自动拿到这些参数;
  • 让这个容器可以被保存、加载、移动到 GPU。

这正是 nn.Module 要解决的问题。

我们先手写一个最小版本的线性层:

class SimpleLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x: Tensor) -> Tensor:
        return torch.addmm(self.bias, x, self.weight.T)

现在,weightbias 不再是散落在外面的变量,而是属于 SimpleLinear 这个模块:

linear = SimpleLinear(3, 2)

for name, param in linear.named_parameters():
    print(f'{name}: {param.shape}')

这里最重要的变化是:只要把 nn.Parameter 赋值给 nn.Module 的属性,PyTorch 就会自动把它注册成该模块的参数。获取模型参数时,只需要调用:

params = list(linear.named_parameters())
pprint(params)

有了 parameters(),就不需要手动维护参数列表。Module 会把模型内部所有参数递归地找出来。这是 nn.Module 的第一层含义:它是参数的组织方式。

2.5.2 forward():Module 负责组织计算

一个 nn.Module 通常会实现 forward() 方法,用来描述输入如何变成输出。例如前面的 SimpleLinear

class SimpleLinear(nn.Module):
    ...

    def forward(self, x: Tensor) -> Tensor:
        return torch.addmm(self.bias, x, self.weight.T)

使用时,我们通常写:

x = torch.randn(4, 3)
y = linear(x)
print(y.shape)

注意,我们调用的是:

linear(x)

而不是:

linear.forward(x)

这是因为 nn.Module 的 __call__ 方法会在内部调用 forward(),同时还会处理 hooks、参数检查等额外逻辑。平时写模型时,我们应该调用 module(input),而不是直接调用 module.forward(input)

从这个角度看,Module 不只是参数容器,它也描述了一段可复用的计算:

Module = parameters + forward computation

不过,这里还有一个容易混淆的问题:既然 nn.Module 里有计算,那 nn.functional 又是什么?

2.5.3 nn.Module 和 nn.functional 的关系

我们在 PyTorch 里经常会看到两种写法。

第一种是模块写法:

linear = nn.Linear(3, 2)
y1 = linear(x)
print(y1.shape)

第二种是函数式写法:

weight = torch.randn(2, 3)
bias = torch.randn(2)
y2 = F.linear(x, weight, bias)
print(y2.shape)

它们都能完成线性变换,但语义不一样。

nn.Linear 是一个 nn.Module。它内部持有自己的 weightbias,并且这些参数会被自动注册:

for name, param in linear.named_parameters():
    print(f'{name}: {param}')

F.linear 是一个函数。它不会保存参数,也不会注册任何状态。你必须显式把 weightbias 传进去:

F.linear(input, weight, bias)

所以,二者的关系可以简单理解为:

nn.Module      = 带状态的层或模型
nn.functional  = 不保存状态的函数

很多 nn.Module 的 forward() 内部,本质上就是调用对应的 nn.functional 函数。例如,nn.Linear 内部大致可以理解成:

def forward(self, input: Tensor) -> Tensor:
    return F.linear(input, self.weight, self.bias)

这也是为什么我们在自定义模块时,经常会混合使用两者:

class SimpleMLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

这里的 self.fc1self.fc2 有可学习参数,所以适合写成 nn.Module;而 relu 没有需要保存的参数,直接用 F.relu 就很自然。

也可以写成:

self.relu = nn.ReLU()

这同样正确。区别主要在于:如果一个操作没有状态,用 functional 写法更轻;如果希望它出现在模块结构里,或者它本身有训练/评估行为,用 nn.Module 更清晰。

例如,Dropout 虽然没有可学习参数,但它在训练和评估时行为不同,所以通常写成:

self.dropout = nn.Dropout(p=0.5)

这样它就能跟随整个模型一起切换 train()eval() 模式。

2.5.4 Parameter:需要被优化器更新的张量

接下来具体看 Parameter

普通张量即使设置了 requires_grad=True,赋值给 Module 后也不会自动成为模型参数:

class BadLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(2, 3, requires_grad=True)


bad = BadLinear()
for name, param in bad.named_parameters():
    print(f'{name}: {param}')

这里 self.weight 确实是一个需要梯度的张量,但它没有被注册成 Module 的参数。因此 parameters() 不会返回它,优化器也不会自动更新它。

如果希望一个张量成为模型参数,需要使用 nn.Parameter:

class GoodLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2, 3))


good = GoodLinear()
for name, param in good.named_parameters():
    print(f'{name}: {param}')

nn.Parameter 可以理解成一种特殊的 Tensor。它的特殊之处不在于数学计算,而在于:当它被赋值给 nn.Module 的属性时,会被自动注册为模型参数。

所以,Parameter 的含义是:

这是模型的一部分,并且通常需要被优化器更新。

例如线性层的权重、卷积核、词嵌入表,都是典型的 parameter。

除了直接赋值,我们也可以用 register_parameter() 显式注册参数:

class ExplicitLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        weight = nn.Parameter(torch.randn(out_features, in_features))
        bias = nn.Parameter(torch.randn(out_features))

        self.register_parameter('weight', weight)
        self.register_parameter('bias', bias)

    def forward(self, x: Tensor) -> Tensor:
        return F.linear(x, self.weight, self.bias)

大多数时候,直接写:

self.weight = nn.Parameter(...)

就够了。register_parameter() 更常见于参数名需要动态生成,或者需要显式控制某个名字是否注册参数的情况。例如,有些模块可以选择是否使用 bias:

class OptionalBiasLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

        if bias:
            self.bias = nn.Parameter(torch.randn(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, x: Tensor) -> Tensor:
        return F.linear(x, self.weight, self.bias)

这样,即使没有 bias,这个模块的结构也很明确:它有一个名为 bias 的位置,只是当前没有参数。

2.5.5 Buffer:属于模型状态,但不是可学习参数

并不是模型里的所有张量都应该被优化器更新。

例如,BatchNorm 里有 running mean 和 running variance。它们会随着训练数据更新,用于评估阶段的归一化,但它们不是通过梯度下降学出来的参数。再比如,位置编码中预先计算好的 sinusoidal table,也可能希望跟随模型一起保存、加载和移动设备,但不希望优化器更新它。这类张量就适合放在 buffer 里。

Buffer 可以理解成:

属于模型状态、需要跟着模型走,但不是可学习参数的张量。

看一个简单例子。假设有一个标准化模块,它用固定的均值和标准差对输入做归一化:

class Normalize(nn.Module):
    def __init__(self, mean: Tensor, std: Tensor):
        super().__init__()
        self.register_buffer('mean', mean)
        self.register_buffer('std', std)

    def forward(self, x: Tensor) -> Tensor:
        return (x - self.mean) / self.std

这里的 meanstd 不应该出现在 parameters() 里:

mean = torch.tensor([0.5, 0.5, 0.5])
std = torch.tensor([0.2, 0.2, 0.2])
normalize = Normalize(mean, std)

print('Parameters:')
for name, param in normalize.named_parameters():
    print(f'{name}: {param}')

print('Buffers:')
for name, buffer in normalize.named_buffers():
    print(f'{name}: {buffer}')

但它们会出现在 state_dict() 里:

pprint(normalize.state_dict())

这说明它们会随着模型状态一起保存。

Buffer 还有一个重要特点:当我们调用 model.to(device) 时,buffer 会和 parameter 一起被移动到对应设备。比如,如果我们只是把它们保存成普通属性:

self.mean = mean
self.std = std

那么它们既不会出现在 state_dict() 里,也不会作为模型状态被统一管理。用 register_buffer() 注册后,它们会被视为模型状态的一部分,跟随模型一起移动到对应设备。

print('Buffers before moving to device:', normalize.mean.device)

device = torch.accelerator.current_accelerator(check_available=True)
normalize.to(device)

print('Buffers after moving to device:', normalize.mean.device)

判断一个张量要不要注册成 buffer,可以看三个问题:

  • 它是不是模型的一部分?
  • 它要不要保存和加载?
  • 它要不要跟随 model.to(device) 移动?

如果答案都是肯定的,但它又不是优化器要更新的参数,那么它很可能应该是 buffer。

register_buffer() 还有一个参数叫 persistent。默认情况下,buffer 是 persistent 的,也就是会保存到 state_dict() 里:

self.register_buffer('mean', mean, persistent=True)

如果设置成 persistent=False,这个 buffer 仍然会跟随设备移动,也能通过 buffers() 找到,但不会保存到 state_dict() 里。这适合一些可以重新生成的缓存,例如某些中间 mask、临时查表结果等。

class MaskCache(nn.Module):
    def __init__(self, max_len: int):
        super().__init__()
        mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer('causal_mask', mask, persistent=False)


cache = MaskCache(max_len=4)
print('Buffers:')
for name, buffer in cache.named_buffers():
    print(f'{name}: {buffer}')

print('State dict keys:', cache.state_dict())

总结一下,Parameter 和 Buffer 的区别是:

表 1:Parameter 和 Buffer 的区别
类型 是否是模型状态 是否被优化器更新 是否进入 state_dict 是否跟随 model.to(device)
Parameter 通常是
Buffer 默认是

2.5.6 子模块:Module 可以嵌套 Module

我们知道,神经网络通常不是一个单独的层,而是很多层组合起来的结构。在 nn.Module 里,一个模块可以包含另一个模块。

比如前面的 MLP:

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


model = SimpleMLP()
print(model)

这里 fc1relufc2 都是 model 的子模块。只要把一个 nn.Module 赋值给另一个 nn.Module 的属性,它也会被自动注册。

因此,model.parameters() 会递归地找到所有子模块里的参数:

for name, param in model.named_parameters():
    print(f'{name}: {param.size()}')

参数名里的点号表示模块层级。例如:

fc1.weight
fc1.bias
fc2.weight
fc2.bias

说明 weightbias 属于子模块 fc1fc2

PyTorch 也提供了几组常用方法来遍历模块结构。

children() 只返回当前模块的直接子模块:

for child in model.children():
    print(child)

named_children() 会同时返回名字:

for name, child in model.named_children():
    print(f'{name}: {child}')

modules() 会递归返回当前模块和所有子模块:

for module in model.modules():
    print(type(module).__name__)

named_modules() 会递归返回模块名和模块对象:

for name, module in model.named_modules():
    print(f'{repr(name)} -> {type(module).__name__}')

这里第一个名字是空字符串,表示模型本身。

这些方法在调试模型结构、冻结部分层、替换某些子模块时都很常用。

例如,我们可以找到所有线性层:

for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print('Linear layer:', name)

但是,如果我们把子模块放进普通 Python list 里,PyTorch 不会自动注册 list 里面的模块。

class BadStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(3, 3), nn.Linear(3, 3)]


bad_stack = BadStack()
for name, param in bad_stack.named_parameters():
    print(f'{name}: {param.size()}')

如果想保存一组子模块,应该使用 nn.ModuleList

class GoodStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Linear(3, 3),
                nn.Linear(3, 3),
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


good_stack = GoodStack()
for name, param in good_stack.named_parameters():
    print(f'{name}: {param.size()}')

对于 dict 也是:

class BadDict(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = {'layer1': nn.Linear(3, 3), 'layer2': nn.Linear(3, 3)}

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers.keys():
            x = self.layers[layer](x)
        return x


bad_dict = BadDict()
for name, param in bad_dict.named_parameters():
    print(f'{name}: {param.size()}')

如果想保存一组命名的子模块,应该使用 nn.ModuleDict

class GoodDict(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict(
            {
                'layer1': nn.Linear(3, 3),
                'layer2': nn.Linear(3, 3),
            }
        )

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers.keys():
            x = self.layers[layer](x)
        return x


good_dict = GoodDict()
for name, param in good_dict.named_parameters():
    print(f'{name}: {param.size()}')

如果模块之间是严格顺序执行,也可以使用 nn.Sequential

sequential_model = nn.Sequential(
    nn.Linear(3, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

print(sequential_model)

ModuleListModuleDict 更像是注册过的模块列表和字典,forward 怎么写由我们决定;Sequential 则直接定义了一条顺序计算链。

2.5.7 state_dict:模型状态的字典

每次我们训练完模型,我们要把训练好的模型保存下来。保存模型时,我们真正想保存的通常不是整个 Python 对象,而是模型里的状态。

在 PyTorch 中,这个状态由 state_dict() 表示:

state = model.state_dict()

for key, value in state.items():
    print(f'{key}: {value.size()}')

state_dict 是一个从名字到张量的字典。它包含所有的 parameter 和 persistent buffer。比如我们前面的 SimpleMLP,只有线性层参数,没有 buffer,所以它的 state_dict 主要是 fc1fc2 的 weight 和 bias。

如果模块里有 buffer,它也会保存进去:

state = normalize.state_dict()

for key, value in state.items():
    print(f'{key}: {value.size()}')

这也是为什么前面说,buffer 虽然不是参数,但仍然是模型状态。

通常保存模型参数可以这样写:

torch.save(model.state_dict(), 'model.pt')

加载时,先重新创建同样结构的模型,再调用 load_state_dict()

model = SimpleMLP()
state_dict = torch.load('model.pt')
flag = model.load_state_dict(state_dict)
print(flag)

这里有一个很重要的思想:

模型结构由 Python 代码定义,模型状态由 state_dict 保存。

state_dict 只保存权重、bias、buffer 等张量,不保存 forward() 的 Python 逻辑。因此加载参数前,我们需要先创建一个结构匹配的模型对象,并将其实例化。如果结构不匹配,例如网络层的名字不同、参数形状不同,load_state_dict() 就会报错或者返回 IncompatibleKeys 对象。这个对象通常会告诉我们 missing keys 或 unexpected keys,也就是当前模型需要但文件里没有的参数,或者文件里有但当前模型用不到的参数。

在实际模型训练中,state_dict 非常重要,因为它让模型保存变得更加清晰:保存的是状态,而不是整个运行环境。

2.5.8 train() 和 eval():切换模块的行为

前面 2.2 节里,我们讨论过 torch.no_grad()torch.inference_mode()。它们控制的是:是否记录计算图。但是,model.train()model.eval() 控制的是另一件事:模块处于训练模式还是评估模式。这两个概念很容易混在一起,但它们不是一回事。

我们先看一个例子:

dropout = nn.Dropout(p=0.5)
x = torch.ones(5)

dropout.train()
print('Train mode:', dropout(x))

dropout.eval()
print('Eval mode:', dropout(x))

在训练模式下,Dropout 会随机丢弃一部分元素;在评估模式下,Dropout 不再随机丢弃,而是直接返回输入。这说明 train()eval() 会影响某些模块的 forward 行为。

最常见受影响的模块是:

  • Dropout:训练时随机丢弃,评估时关闭随机丢弃;
  • BatchNorm:训练时使用当前 batch 统计量并更新 running statistics,评估时使用保存的 running statistics。

我们可以看一下 training 属性:

model = SimpleMLP()
print(f'Initial training mode: {model.training}')

model.eval()
print(f'After calling eval(): {model.training}')

model.train()
print(f'After calling train(): {model.training}')

model.train() 会把模型及其所有子模块都设置为训练模式;model.eval() 会递归地把它们设置为评估模式。它本质上等价于 model.train(False)

但是,eval() 不会关闭自动微分。也就是说,下面这段代码虽然处在 eval 模式,但如果没有 no_grad(),PyTorch 仍然会记录计算图:

model.eval()
x = torch.randn(4, 3, requires_grad=True)
y = model(x)
print(f'y.requires_grad: {y.requires_grad}')

因此,验证或推理时通常需要同时写:

model.eval()
with torch.no_grad():
    y_pred = model(x)

或者在纯推理场景中写:

model.eval()
with torch.inference_mode():
    y_pred = model(x)

所以,一个简单的区分方式是:

  • train() / eval() 控制模块的行为,例如 Dropout 和 BatchNorm 在训练和评估时的不同计算逻辑;
  • no_grad() / inference_mode() 控制 Autograd 是否记录计算图,例如在评估阶段我们通常不需要梯度。

所以,eval() 只是告诉模块,现在请使用评估阶段的行为。至于是否记录梯度,还要由 no_grad()inference_mode() 来控制。这一点很重要。

警告

请务必确保在训练阶段调用 model.train(),在评估阶段调用 model.eval()。即使你的网络中没有使用 Dropout 或 BatchNorm,养成这个习惯也能避免未来添加这些层时忘记切换模式导致错误。

2.5.9 Lazy Module:推迟确定输入维度

前面我们创建 nn.Linear 时,都需要显式写出 in_features

nn.Linear(in_features=3, out_features=2)

这很合理,因为线性层的权重形状是:

\[ W \in \mathbb{R}^{\text{out\_features} \times \text{in\_features}} \]

如果不知道输入最后一维是多少,PyTorch 就没法提前创建这个权重矩阵。

但是在实际写模型时,有些输入维度并不总是方便手算。尤其是卷积网络里,经过多层 Conv2dPooling 之后,特征图到底会变成多大,有时候要根据输入尺寸一步步推出来。

例如,一个 CNN 最后通常会把特征图展平成向量,再接一个全连接层:

x = self.features(x)
x = x.flatten(start_dim=1)
x = self.classifier(x)

这里 classifierin_features 取决于 features 输出的形状。如果每次改卷积结构或者输入图像大小,都要重新手算这个数字,就会有点麻烦。

为了解决这个问题,PyTorch 提供了一类 Lazy Module。它们的核心思想是:

先创建模块,但暂时不创建完整参数;等第一次看到真实输入时,再根据输入形状初始化参数。

最常用的是 nn.LazyLinear。它不需要我们提前指定 in_features,只需要指定 out_features

lazy_linear = nn.LazyLinear(out_features=2)
print(lazy_linear)

此时,这个模块还没有真正知道输入维度。它的参数是未初始化状态:

for name, param in lazy_linear.named_parameters():
    print(f'{name}: {type(param).__name__}')

这种参数不是普通的 Parameter,而是 UninitializedParameter。它表示这个参数属于模型,但目前还不知道完整形状。因此,如果我们尝试访问它的形状,就会报错:

try:
    print(lazy_linear.weight.shape)
except RuntimeError as err:
    print('RuntimeError:', err)

当我们第一次把输入传进去时,LazyLinear 会根据输入最后一维推断 in_features,并把参数真正初始化出来:

x = torch.randn(4, 3)
y = lazy_linear(x)

print(lazy_linear)
print('Output shape:', y.shape)

现在再看它的参数形状:

for name, param in lazy_linear.named_parameters():
    print(f'{name}: {param.size()}')

可以看到,weight 的形状已经变成了:

\[ (\text{out\_features}, \text{in\_features}) = (2, 3) \]

也就是说,LazyLinear 第一次看到形状为 (4, 3) 的输入后,自动推断出 in_features = 3,并完成了参数初始化。这就是 lazy module 的 lazy 所在:不是模块不计算,而是参数初始化被推迟到了第一次 forward。

Lazy module 在卷积网络里尤其方便。比如我们可以先写卷积特征提取部分,然后用 nn.LazyLinear 自动适配展平后的维度:

class LazyCNN(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.LazyConv2d(8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.LazyConv2d(16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = nn.LazyLinear(num_classes)

    def forward(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = x.flatten(start_dim=1)
        x = self.classifier(x)
        return x

创建模型时,我们不需要知道 classifier 的输入维度:

lazy_cnn = LazyCNN(num_classes=10)
print(lazy_cnn)

第一次 forward 之后,classifier 才会被具体初始化:

x = torch.randn(4, 1, 28, 28)
y = lazy_cnn(x)

print(lazy_cnn.classifier)

这个例子里,输入图像大小是 \(28 \times 28\)。经过两次 MaxPool2d(kernel_size=2) 之后,空间尺寸会从 \(28 \times 28\) 变成 \(7 \times 7\),通道数变成 16。所以展平后的维度是:

\[ 16 \times 7 \times 7 = 784 \]

nn.LazyLinear 正是根据第一次 forward 时的真实输入,自动得到了这个 784。

除了 LazyLinear,PyTorch 里也有 lazy 版本的卷积层,例如:

nn.LazyConv1d(out_channels, kernel_size)
nn.LazyConv2d(out_channels, kernel_size)
nn.LazyConv3d(out_channels, kernel_size)

普通卷积层需要指定 in_channels

nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

LazyConv2d 可以把 in_channels 推迟到第一次 forward 时再确定:

lazy_conv = nn.LazyConv2d(out_channels=16, kernel_size=3, padding=1)
print(lazy_conv)

x = torch.randn(4, 3, 32, 32)
y = lazy_conv(x)

print(lazy_conv)
print('Output shape:', y.shape)

第一次看到输入 (4, 3, 32, 32) 后,LazyConv2d 就知道输入通道数是 3,因此它会把权重初始化成对应形状:

for name, param in lazy_conv.named_parameters():
    print(f'{name}: {param.size()}')

Lazy module 很方便,但使用时也要注意几点。

第一,在第一次 forward 之前,lazy module 的参数还没有真实形状。因此,有些依赖参数形状的操作不能太早做。例如,你不能在参数初始化前根据 weight.shape 手动写复杂逻辑。

第二,保存模型前,最好先用一个真实 batch 或 dummy batch 跑一次 forward,让所有 lazy 参数都完成初始化。否则 state_dict() 里会包含尚未初始化的参数,后续加载和使用都会更麻烦。

model = LazyCNN(num_classes=10)

x = torch.randn(1, 1, 28, 28)
y = model(x)

for key, value in model.state_dict().items():
    print(f'{key}: {value.size()}')

第三,lazy module 主要是为了减少手动计算输入维度的负担,而不是为了改变模型结构。第一次 forward 之后,它就会变成一个已经确定形状的普通模块。后续输入的相关维度必须和第一次推断出来的维度匹配。

例如,前面的 LazyLinear 第一次看到的输入最后一维是 3,所以它之后就只能接收最后一维为 3 的输入:

x1 = torch.randn(4, 3)
y1 = lazy_linear(x1)

try:
    x2 = torch.randn(4, 5)
    y2 = lazy_linear(x2)
except RuntimeError as err:
    print('RuntimeError:', err)

所以,lazy module 最适合用在这种场景:

  • 你知道输出维度,比如分类类别数、隐藏层维度、卷积输出通道数;
  • 但输入维度不想手算,或者输入维度由前面的模块自然决定;
  • 你愿意在真正训练、保存之前,先让模型跑一次 forward 完成初始化。

总结一下,lazy module 不是新的计算方式,而是一种更方便的模块初始化方式。它把一部分形状信息从 __init__() 推迟到第一次 forward(),让模型代码少写一些硬编码的输入维度。

2.5.10 把这些概念放到一个模型里

现在我们把前面的概念放到一个稍微完整一点的模型里。

这个模型包含:

  • 两个线性层,作为可学习的子模块;
  • 一个 Dropout,用来展示训练/评估行为;
  • 一个输入归一化的 mean 和 std,作为 buffer;
  • 一个不可持久化的缓存 mask,作为 non-persistent buffer。
class DemoNet(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.register_buffer('mean', torch.zeros(input_dim))
        self.register_buffer('std', torch.ones(input_dim))
        self.register_buffer(
            'cache_mask',
            torch.ones(input_dim, dtype=torch.bool),
            persistent=False,
        )

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = (x - self.mean) / self.std
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

创建模型:

demo = DemoNet(input_dim=3, hidden_dim=8, output_dim=2)

查看参数:

print('Parameters:')
for name, param in demo.named_parameters():
    print(f'{name}: {param.size()}')

查看 buffer:

print('Buffers:')
for name, buffer in demo.named_buffers():
    print(f'{name}: {buffer.size()}')

查看 state dict:

print('State dict:')
for key, value in demo.state_dict().items():
    print(f'{key}: {value.size()}')

注意,cache_mask 是 buffer,但因为 persistent=False,所以它不会出现在 state_dict() 中。

再查看子模块:

print('Submodules:')
for name, module in demo.named_modules():
    print(repr(name), '->', type(module).__name__)

这几个输出合起来,就展示了 nn.Module 的核心管理能力:

parameters()    -> 找到可学习参数
buffers()       -> 找到非参数状态
modules()       -> 找到子模块结构
state_dict()    -> 导出可保存的模型状态
train()/eval()  -> 切换模块行为

到这里,我们就能更完整地理解 nn.Module:它不是某个单独 API,而是 PyTorch 模型系统的中心。只要一个对象继承了 nn.Module,它就进入了 PyTorch 的模型管理体系。

2.5.11 本章小结

这一节我们从手写线性变换出发,理解了为什么需要 nn.Module。

nn.Module 不只是一个带 forward() 的对象,它还负责管理模型中的参数、缓冲区和子模块。nn.Parameter 表示需要被优化器更新的模型参数;Buffer 则表示属于模型状态、需要保存或移动设备,但不应该被优化器更新的张量。

nn.Module 和 nn.functional 的关系可以理解为:前者是带状态的层或模型,后者是不保存状态的函数。很多模块的 forward() 内部都会调用对应的 functional 操作。

我们还看到,Module 可以嵌套 Module,并且 PyTorch 会递归地管理这些子模块中的参数和 buffer。通过 parameters()buffers()children()modules(),我们可以查看模型内部结构;通过 state_dict(),可以保存和加载模型状态。

最后,train()eval() 控制的是模块行为,而不是 Autograd 是否记录计算图。验证和推理时,通常既要调用 model.eval(),也要配合 torch.no_grad()torch.inference_mode()

所以,nn.Module 的核心作用是把散落的张量计算组织成一个真正的模型:它知道哪些东西要学习,哪些东西只是状态,哪些层属于自己,以及这些状态应该如何保存、加载和切换行为。

二次使用