Chapter 2.5 nn.Module in PyTorch: Organizing Models, Parameters, and State
The previous sections explained how PyTorch records computation graphs, how gradients are propagated backward, and how training data is usually organized into mini-batches through Dataset and DataLoader.
One question remains: how should the model itself be organized?
The simplest linear model can be written directly as tensor operations:
```python
import torch

x = torch.randn(4, 3)
weight = torch.randn(2, 3, requires_grad=True)
bias = torch.randn(2, requires_grad=True)
y = torch.addmm(bias, x, weight.T)
```
This code works. weight and bias participate in the computation graph, and once a scalar loss computed from y is backpropagated with loss.backward(), their gradients are available in weight.grad and bias.grad.
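A minimal sketch of that claim, reusing the shapes from the snippet above. The loss y.sum() is an arbitrary choice made only so that backward() has a scalar to start from.

```python
import torch

x = torch.randn(4, 3)
weight = torch.randn(2, 3, requires_grad=True)
bias = torch.randn(2, requires_grad=True)

# y = x @ weight.T + bias, computed in one fused call
y = torch.addmm(bias, x, weight.T)

# Arbitrary scalar loss, used only to trigger backpropagation
loss = y.sum()
loss.backward()

# Gradients land in .grad of every leaf tensor created with requires_grad=True
print(weight.grad.shape)  # torch.Size([2, 3])
print(bias.grad)          # tensor([4., 4.]): each bias entry is added to 4 rows
```

Nothing here tells PyTorch that weight and bias form one model; that bookkeeping is what the rest of this section hands over to nn.Module.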
Once the model becomes slightly more complex, several engineering problems appear:
- Which tensors are parameters that need optimization?
- How should many network layers be organized into one model?
- How should the whole model be moved to GPU?
- How should model state be saved and loaded?
- How should training mode and evaluation mode be distinguished?
If all tensors stay as loose variables, these responsibilities must be managed manually. nn.Module collects them into one object. It is not only an object that can run forward computation, but also PyTorch’s basic container for organizing computation, parameters, buffers, and submodules.
This section starts from a minimal linear layer and then explains several core concepts behind nn.Module.
```python
from pprint import pprint

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor

print('PyTorch version:', torch.__version__)
```
2.5.1 Why nn.Module Is Needed
Return to the linear transformation:
\[ y = xW^\top + b \]
If we write it directly with tensors, the code looks like this:
```python
in_features = 3
out_features = 2

weight = torch.randn(out_features, in_features, requires_grad=True)
bias = torch.randn(out_features, requires_grad=True)

x = torch.randn(4, in_features)
y = torch.addmm(bias, x, weight.T)
print(y.shape)
```
In this code, weight and bias are model parameters. But PyTorch does not know that they belong to the same model, because they are only ordinary variables. This creates a problem: if the number of parameters grows later, we need to collect them manually:
```python
parameters = [weight, bias]
pprint(parameters)
```
For a linear model with only two parameters, this is still acceptable. If the model has dozens of layers and hundreds of parameter tensors, manually maintaining the list becomes error-prone.
What we really want is:
- put parameters into a container;
- let the container know which tensors are parameters;
- let the optimizer retrieve these parameters automatically;
- let the container be saved, loaded, and moved to GPU.
This is the problem nn.Module solves.
First write a minimal linear layer by hand:
```python
class SimpleLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x: Tensor) -> Tensor:
        return torch.addmm(self.bias, x, self.weight.T)
```
Now weight and bias are no longer loose variables. They belong to the SimpleLinear module:
```python
linear = SimpleLinear(3, 2)
for name, param in linear.named_parameters():
    print(f'{name}: {param.shape}')
```
The important change is this: when an nn.Parameter is assigned to an nn.Module attribute, PyTorch automatically registers it as a parameter of that module. To retrieve model parameters, we only need to call:
```python
params = list(linear.named_parameters())
pprint(params)
```
With parameters(), or named_parameters() when the names are also needed, we no longer need to maintain a parameter list manually: the Module recursively finds all parameters inside the model. This is the first meaning of nn.Module: it organizes parameters.
2.5.2 forward(): Module Organizes Computation
An nn.Module usually implements a forward() method, which describes how inputs become outputs. For example, the previous SimpleLinear defines:
```python
class SimpleLinear(nn.Module):
    ...  # __init__ as defined above

    def forward(self, x: Tensor) -> Tensor:
        return torch.addmm(self.bias, x, self.weight.T)
```
When using the module, we usually write:
```python
x = torch.randn(4, 3)
y = linear(x)
print(y.shape)
```
Notice that we call:
```python
linear(x)
```
instead of:
```python
linear.forward(x)
```
This is because nn.Module's __call__ method calls forward() internally and also handles extra logic such as hooks and checks. When writing models, call module(input) instead of calling module.forward(input) directly.
From this angle, Module is not only a parameter container. It also describes a reusable computation:
Module = parameters + forward computation
There is one point that often causes confusion: if nn.Module contains computation, what is nn.functional?
2.5.3 The Relation Between nn.Module and nn.functional
In PyTorch, two styles are common.
The first is the module style:
```python
linear = nn.Linear(3, 2)
y1 = linear(x)
print(y1.shape)
```
The second is the functional style:
```python
weight = torch.randn(2, 3)
bias = torch.randn(2)
y2 = F.linear(x, weight, bias)
print(y2.shape)
```
Both perform a linear transformation, but their semantics are different.
nn.Linear is an nn.Module. It owns its own weight and bias, and these parameters are registered automatically:
```python
for name, param in linear.named_parameters():
    print(f'{name}: {param}')
```
F.linear is a function. It does not store parameters or register any state. You must pass weight and bias explicitly:
```python
F.linear(input, weight, bias)
```
The relation can be summarized as:
```text
nn.Module = layer or model with state
nn.functional = stateless function
```
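A quick check of this summary: passing a module's own parameters to the stateless function should reproduce the module's output. This is a minimal sketch; the shapes are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

linear = nn.Linear(3, 2)  # stateful: owns weight and bias
x = torch.randn(4, 3)

out_module = linear(x)
# Stateless: the same computation, but the parameters are supplied by the caller
out_functional = F.linear(x, linear.weight, linear.bias)

print(torch.allclose(out_module, out_functional))  # True
```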
Many nn.Module implementations call the corresponding nn.functional operation inside forward(). For example, nn.Linear can be understood roughly as:
```python
def forward(self, input: Tensor) -> Tensor:
    return F.linear(input, self.weight, self.bias)
```
This is also why custom modules often mix the two styles:
```python
class SimpleMLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
```
Here self.fc1 and self.fc2 have learnable parameters, so they are suitable as nn.Module objects. relu has no state to store, so using F.relu is natural.
It is also correct to write:
```python
self.relu = nn.ReLU()
```
The difference is mostly about state and structure. If an operation has no state, the functional style is lighter. If you want it to appear in the module structure, or if it has different training and evaluation behavior, nn.Module is clearer.
For example, Dropout has no learnable parameters, but its behavior differs between training and evaluation. It is usually written as:
```python
self.dropout = nn.Dropout(p=0.5)
```
Then it can switch between train() and eval() together with the whole model.
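If dropout is written in the functional style anyway, the mode must be forwarded by hand. A sketch, with FunctionalDropoutNet as a hypothetical name: F.dropout defaults to training=True, so omitting training=self.training would keep dropping elements even after eval().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class FunctionalDropoutNet(nn.Module):
    """Hypothetical example: dropout written in the functional style."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        # F.dropout cannot see the module's mode, so pass self.training explicitly
        return F.dropout(x, p=self.p, training=self.training)


net = FunctionalDropoutNet()
x = torch.ones(8)

net.eval()
print(torch.equal(net(x), x))  # True: no dropout in evaluation mode
```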
2.5.4 Parameter: Tensors Updated by the Optimizer
Now look more closely at Parameter.
An ordinary tensor does not automatically become a model parameter after being assigned to a Module, even if it has requires_grad=True:
```python
class BadLinear(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain tensor: it requires gradients but is NOT registered as a parameter
        self.weight = torch.randn(2, 3, requires_grad=True)


bad = BadLinear()
for name, param in bad.named_parameters():
    print(f'{name}: {param}')  # prints nothing
```
Here self.weight is indeed a tensor that requires gradients, but it has not been registered as a Module parameter. Therefore parameters() will not return it, and the optimizer will not update it automatically.
To make a tensor a model parameter, use nn.Parameter:
```python
class GoodLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2, 3))


good = GoodLinear()
for name, param in good.named_parameters():
    print(f'{name}: {param}')
```
nn.Parameter can be understood as a special kind of Tensor. It is special not because of mathematical computation, but because it is automatically registered as a model parameter when assigned to an nn.Module attribute.
The meaning of Parameter is:
This tensor is part of the model and is usually updated by the optimizer.
Weights in linear layers, convolution kernels, and word embedding tables are typical parameters.
Besides direct assignment, parameters can also be registered explicitly with register_parameter():
```python
class ExplicitLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        weight = nn.Parameter(torch.randn(out_features, in_features))
        bias = nn.Parameter(torch.randn(out_features))
        self.register_parameter('weight', weight)
        self.register_parameter('bias', bias)

    def forward(self, x: Tensor) -> Tensor:
        return F.linear(x, self.weight, self.bias)
```
Most of the time, direct assignment is enough:
```python
self.weight = nn.Parameter(...)
```
register_parameter() is more common when parameter names are generated dynamically, or when you need explicit control over whether a name registers a parameter. For example, some modules allow bias to be optional:
```python
class OptionalBiasLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, x: Tensor) -> Tensor:
        return F.linear(x, self.weight, self.bias)
```
Even without bias, the module structure remains explicit: it has a position named bias, but currently no parameter is registered there.
2.5.5 Buffer: Model State That Is Not Learnable
Not every tensor inside a model should be updated by the optimizer.
For example, BatchNorm has running mean and running variance. They are updated during training and used for normalization during evaluation, but they are not learned through gradient descent. A precomputed sinusoidal table for positional encoding is another example: it may need to be saved, loaded, and moved with the model, but it should not be updated by the optimizer. Such tensors belong in buffers.
A buffer can be understood as:
A tensor that belongs to model state and should move with the model, but is not a learnable parameter.
Consider a simple normalization module that uses fixed mean and standard deviation:
```python
class Normalize(nn.Module):
    def __init__(self, mean: Tensor, std: Tensor):
        super().__init__()
        self.register_buffer('mean', mean)
        self.register_buffer('std', std)

    def forward(self, x: Tensor) -> Tensor:
        return (x - self.mean) / self.std
```
Here mean and std should not appear in parameters():
```python
mean = torch.tensor([0.5, 0.5, 0.5])
std = torch.tensor([0.2, 0.2, 0.2])
normalize = Normalize(mean, std)

print('Parameters:')
for name, param in normalize.named_parameters():
    print(f'{name}: {param}')

print('Buffers:')
for name, buffer in normalize.named_buffers():
    print(f'{name}: {buffer}')
```
But they appear in state_dict():
```python
pprint(normalize.state_dict())
```
This means they are saved together with the model state.
Buffers have another important property: when model.to(device) is called, buffers move to the target device together with parameters. If we only store them as ordinary attributes:
```python
self.mean = mean
self.std = std
```
then they will not appear in state_dict(), and model.to(device) will leave them where they are, so they are not managed as model state. After registration with register_buffer(), they become part of model state and move with the model.
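```python
print('Buffers before moving to device:', normalize.mean.device)

# With check_available=True, current_accelerator() returns None when no accelerator
# is usable (torch.accelerator requires a recent PyTorch release), so fall back to CPU
accelerator = torch.accelerator.current_accelerator(check_available=True)
device = accelerator if accelerator is not None else torch.device('cpu')

normalize.to(device)
print('Buffers after moving to device:', normalize.mean.device)
```
To decide whether a tensor should be registered as a buffer, check three questions: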
- Is it part of the model?
- Should it be saved and loaded?
- Should it move with model.to(device)?
If all answers are yes, but the tensor should not be updated by the optimizer, it is probably a buffer.
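The second and third questions can be checked directly. In this sketch, PlainVsBuffer is a hypothetical module that holds one tensor as a plain attribute and one as a buffer; a dtype conversion stands in for a device move so the example also runs on CPU-only machines.

```python
import torch
import torch.nn as nn


class PlainVsBuffer(nn.Module):
    """Hypothetical module holding the same kind of tensor two ways."""

    def __init__(self):
        super().__init__()
        self.plain = torch.zeros(3)                    # ordinary attribute
        self.register_buffer('stats', torch.zeros(3))  # registered buffer


m = PlainVsBuffer()
print(list(m.state_dict().keys()))  # ['stats']: only the buffer is saved

# .to() converts floating-point parameters and buffers, nothing else
m.to(torch.float64)
print(m.stats.dtype)  # torch.float64
print(m.plain.dtype)  # torch.float32: the plain attribute was left behind
```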
register_buffer() also has a persistent argument. By default, buffers are persistent, which means they are saved into state_dict():
```python
self.register_buffer('mean', mean, persistent=True)
```
If persistent=False, the buffer still moves with the device and can still be found through buffers(), but it will not be saved into state_dict(). This is suitable for caches that can be regenerated, such as intermediate masks or temporary lookup tables.
```python
class MaskCache(nn.Module):
    def __init__(self, max_len: int):
        super().__init__()
        mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer('causal_mask', mask, persistent=False)


cache = MaskCache(max_len=4)

print('Buffers:')
for name, buffer in cache.named_buffers():
    print(f'{name}: {buffer}')

# A non-persistent buffer is skipped by state_dict(), so this list is empty
print('State dict keys:', list(cache.state_dict().keys()))
```
The difference between Parameter and Buffer is:
| Type | Is model state | Updated by optimizer | Included in state_dict | Moves with model.to(device) |
|---|---|---|---|---|
| Parameter | Yes | Usually yes | Yes | Yes |
| Buffer | Yes | No | Yes by default | Yes |
2.5.6 Submodules: Module Can Contain Module
A neural network is usually not one isolated layer. It is a structure composed of many layers. In nn.Module, one module can contain another module.
For example, here is a variant of the previous MLP with fixed sizes and ReLU written as a submodule:
```python
class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


model = SimpleMLP()
print(model)
```
Here fc1, relu, and fc2 are all submodules of model. When an nn.Module is assigned to an attribute of another nn.Module, it is registered automatically.
Therefore, model.parameters() recursively finds parameters inside all submodules:
```python
for name, param in model.named_parameters():
    print(f'{name}: {param.size()}')
```
Dots in parameter names indicate module hierarchy. For example:
```text
fc1.weight
fc1.bias
fc2.weight
fc2.bias
```
This means weight and bias belong to the submodule fc1 or fc2.
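The dotted names can also be resolved back to objects. A short sketch using get_submodule() and get_parameter(); TwoLayer is a hypothetical module that mirrors the fc1/fc2 naming above.

```python
import torch.nn as nn


class TwoLayer(nn.Module):
    """Hypothetical two-layer module, mirroring the fc1/fc2 naming above."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 8)
        self.fc2 = nn.Linear(8, 2)


model = TwoLayer()
# Each dot in a name walks one level down the module hierarchy
print(model.get_submodule('fc1'))               # Linear(in_features=3, out_features=8, bias=True)
print(model.get_parameter('fc1.weight').shape)  # torch.Size([8, 3])
print(model.get_parameter('fc2.bias') is model.fc2.bias)  # True
```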
PyTorch also provides several common methods for traversing module structure.
children() returns only the direct child modules of the current module:
```python
for child in model.children():
    print(child)
```
named_children() returns both names and modules:
```python
for name, child in model.named_children():
    print(f'{name}: {child}')
```
modules() recursively returns the current module and all submodules:
```python
for module in model.modules():
    print(type(module).__name__)
```
named_modules() recursively returns module names and module objects:
```python
for name, module in model.named_modules():
    print(f'{repr(name)} -> {type(module).__name__}')
```
The first name is an empty string, which refers to the model itself.
These methods are useful for debugging model structure, freezing part of a network, or replacing certain submodules.
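As one illustration of freezing part of a network, a minimal sketch: the model is built as an nn.Sequential with named children (an assumption made so that names match the fc1/fc2 style above), the first layer's parameters get requires_grad_(False), and only the remaining trainable parameters go to a hypothetical SGD optimizer.

```python
from collections import OrderedDict

import torch.nn as nn
import torch.optim as optim

# Named children so that parameter names follow the fc1/fc2 style used above
model = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(3, 8)),
    ('relu', nn.ReLU()),
    ('fc2', nn.Linear(8, 2)),
]))

# Freeze fc1: its parameters stop receiving gradients
for param in model.fc1.parameters():
    param.requires_grad_(False)

# Hand only the still-trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable, lr=0.1)

print([name for name, p in model.named_parameters() if p.requires_grad])
# ['fc2.weight', 'fc2.bias']
```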
For example, we can find all linear layers:
```python
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print('Linear layer:', name)
```
However, if submodules are placed inside an ordinary Python list, PyTorch will not automatically register the modules inside the list.
```python
class BadStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(3, 3), nn.Linear(3, 3)]


bad_stack = BadStack()
for name, param in bad_stack.named_parameters():
    print(f'{name}: {param.size()}')
```
To store a group of submodules, use nn.ModuleList:
```python
class GoodStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Linear(3, 3),
                nn.Linear(3, 3),
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


good_stack = GoodStack()
for name, param in good_stack.named_parameters():
    print(f'{name}: {param.size()}')
```
The same applies to dict:
```python
class BadDict(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = {'layer1': nn.Linear(3, 3), 'layer2': nn.Linear(3, 3)}

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers.keys():
            x = self.layers[layer](x)
        return x


bad_dict = BadDict()
for name, param in bad_dict.named_parameters():
    print(f'{name}: {param.size()}')
```
To store a group of named submodules, use nn.ModuleDict:
```python
class GoodDict(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict(
            {
                'layer1': nn.Linear(3, 3),
                'layer2': nn.Linear(3, 3),
            }
        )

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers.keys():
            x = self.layers[layer](x)
        return x


good_dict = GoodDict()
for name, param in good_dict.named_parameters():
    print(f'{name}: {param.size()}')
```
If modules are executed in a strict sequence, nn.Sequential can also be used:
```python
sequential_model = nn.Sequential(
    nn.Linear(3, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)
print(sequential_model)
```
ModuleList and ModuleDict are more like registered lists and dictionaries of modules; how forward runs is still defined by us. Sequential directly defines a sequential computation chain.
2.5.7 state_dict: A Dictionary of Model State
After training a model, we usually need to save it. What we normally want to save is not the entire Python object, but the state inside the model.
In PyTorch, this state is represented by state_dict():
```python
state = model.state_dict()
for key, value in state.items():
    print(f'{key}: {value.size()}')
```
state_dict is a dictionary from names to tensors. It contains all parameters and persistent buffers. For the previous SimpleMLP, there are only linear-layer parameters and no buffers, so its state_dict contains exactly the weight and bias tensors of fc1 and fc2.
If a module has buffers, they are saved as well:
```python
state = normalize.state_dict()
for key, value in state.items():
    print(f'{key}: {value.size()}')
```
This is why buffers are model state even though they are not parameters.
Model parameters are usually saved like this:
```python
torch.save(model.state_dict(), 'model.pt')
```
When loading, first create a model with the same structure, then call load_state_dict():
```python
model = SimpleMLP()
state_dict = torch.load('model.pt')
flag = model.load_state_dict(state_dict)
print(flag)
```
The important idea is:
Model structure is defined by Python code, and model state is saved by state_dict.
state_dict only saves tensors such as weights, biases, and buffers. It does not save the Python logic inside forward(). Therefore, before loading parameters, we need to create a model object with matching structure. If the structure does not match, for example if layer names or parameter shapes differ, load_state_dict() reports it. Missing keys are entries required by the current model but absent from the file; unexpected keys are entries present in the file but unused by the current model. By default (strict=True), any missing or unexpected key raises a RuntimeError. With strict=False, these key mismatches are tolerated and listed in the missing_keys and unexpected_keys fields of the returned IncompatibleKeys object, while a shape mismatch still raises an error.
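A small sketch of both paths, assuming two hypothetical nn.Sequential models that share only their first layer:

```python
import torch.nn as nn

source = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 2))
target = nn.Sequential(nn.Linear(3, 8))  # hypothetical smaller model

# strict=True (the default) raises because of the extra keys in the file
try:
    target.load_state_dict(source.state_dict())
except RuntimeError as err:
    print('RuntimeError:', type(err).__name__)

# strict=False loads the matching entries and reports the rest
result = target.load_state_dict(source.state_dict(), strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['2.weight', '2.bias']
```

strict=False is handy for partial loading, such as reusing a pretrained backbone, but the returned keys should still be inspected so that a typo in a layer name does not silently leave weights at their random initialization.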
In actual model training, state_dict is important because it makes model saving explicit: save state, not the entire runtime environment.
2.5.8 train() and eval(): Switch Module Behavior
Section 2.2 discussed torch.no_grad() and torch.inference_mode(). They control whether the computation graph is recorded. But model.train() and model.eval() control a different thing: whether modules are in training mode or evaluation mode. These two ideas are easy to mix up, but they are not the same.
Start with an example:
```python
dropout = nn.Dropout(p=0.5)
x = torch.ones(5)

dropout.train()
print('Train mode:', dropout(x))

dropout.eval()
print('Eval mode:', dropout(x))
```
In training mode, Dropout randomly zeroes some elements and scales the surviving ones by \(1/(1-p)\), which is 2.0 here, so the expected value of each element is unchanged. In evaluation mode, Dropout no longer drops elements and directly returns the input. This shows that train() and eval() affect the forward behavior of some modules.
The most common affected modules are:
- Dropout: randomly drops elements during training and disables random dropping during evaluation;
- BatchNorm: uses current batch statistics and updates running statistics during training, and uses saved running statistics during evaluation.
We can inspect the training attribute:
```python
model = SimpleMLP()
print(f'Initial training mode: {model.training}')

model.eval()
print(f'After calling eval(): {model.training}')

model.train()
print(f'After calling train(): {model.training}')
```
model.train() sets the model and all submodules to training mode. model.eval() recursively sets them to evaluation mode. It is essentially equivalent to model.train(False).
However, eval() does not turn off automatic differentiation. The following code is in eval mode, but without no_grad(), PyTorch still records the computation graph:
```python
model.eval()
x = torch.randn(4, 3, requires_grad=True)
y = model(x)
print(f'y.requires_grad: {y.requires_grad}')
```
Therefore, validation or inference usually uses both:
```python
model.eval()
with torch.no_grad():
    y_pred = model(x)
```
or, for pure inference:
```python
model.eval()
with torch.inference_mode():
    y_pred = model(x)
```
A simple distinction is:
- train()/eval() control module behavior, such as the different computation logic of Dropout and BatchNorm during training and evaluation;
- no_grad()/inference_mode() control whether Autograd records the computation graph. During evaluation, gradients are usually unnecessary.
So eval() only tells modules to use evaluation behavior. Whether gradients are recorded is still controlled by no_grad() or inference_mode(). This distinction matters.
Make sure to call model.train() during training and model.eval() during evaluation. Even if the current network does not use Dropout or BatchNorm, keeping this habit avoids errors after such layers are added later.
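One way to make this habit hard to forget is a small helper that combines both switches and then restores the previous mode. evaluate below is a hypothetical name, and the sketch is one possible design, not a PyTorch API.

```python
import torch
import torch.nn as nn
from torch import Tensor


def evaluate(model: nn.Module, x: Tensor) -> Tensor:
    """Hypothetical helper: eval behavior + no graph, then restore the old mode."""
    was_training = model.training
    model.eval()  # module behavior: dropout off, BatchNorm uses running stats
    try:
        with torch.no_grad():  # Autograd: record no graph
            return model(x)
    finally:
        model.train(was_training)  # put the model back in its previous mode


model = nn.Sequential(nn.Linear(3, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
x = torch.randn(4, 3)

out = evaluate(model, x)
print(out.requires_grad)  # False: no graph was recorded
print(model.training)     # True: training mode was restored
print(torch.equal(evaluate(model, x), out))  # True: dropout is off during evaluation
```

Restoring the previous mode in finally means a validation pass inside a training loop cannot leave the model stuck in eval mode, even if the forward pass raises.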
2.5.9 Lazy Module: Defer Input Dimension Inference
When creating nn.Linear, we previously had to specify in_features explicitly:
```python
nn.Linear(in_features=3, out_features=2)
```
This is reasonable because the weight shape of a linear layer is:
\[ W \in \mathbb{R}^{\text{out\_features} \times \text{in\_features}} \]
If PyTorch does not know the last dimension of the input, it cannot create this weight matrix in advance.
In real model code, however, input dimensions are not always convenient to calculate by hand. This is especially common in convolutional networks. After several Conv2d and Pooling layers, the final feature map size may need to be derived step by step from the input size.
For example, a CNN usually flattens the feature map and then applies a fully connected layer:
```python
x = self.features(x)
x = x.flatten(start_dim=1)
x = self.classifier(x)
```
Here the in_features of classifier depends on the output shape of features. If every change to the convolutional structure or input image size requires recalculating this number, the code becomes inconvenient to maintain.
PyTorch provides Lazy Module for this problem. The core idea is:
Create the module first, but do not create complete parameters yet; after the first real input is seen, initialize parameters according to the input shape.
The most common one is nn.LazyLinear. It does not require in_features in advance, only out_features:
```python
lazy_linear = nn.LazyLinear(out_features=2)
print(lazy_linear)
```
At this point, the module does not yet know the input dimension. Its parameters are uninitialized:
```python
for name, param in lazy_linear.named_parameters():
    print(f'{name}: {type(param).__name__}')
```
This kind of parameter is not an ordinary Parameter, but an UninitializedParameter. It indicates that the parameter belongs to the model but its full shape is not known yet. Therefore, trying to access its shape raises an error:
```python
try:
    print(lazy_linear.weight.shape)
except RuntimeError as err:
    print('RuntimeError:', err)
```
When the first input is passed in, LazyLinear infers in_features from the last input dimension and initializes the parameters:
```python
x = torch.randn(4, 3)
y = lazy_linear(x)
print(lazy_linear)
print('Output shape:', y.shape)
```
Now inspect its parameter shapes:
```python
for name, param in lazy_linear.named_parameters():
    print(f'{name}: {param.size()}')
```
The shape of weight has become:
\[ (\text{out\_features}, \text{in\_features}) = (2, 3) \]
That is, after LazyLinear first sees an input with shape (4, 3), it infers in_features = 3 and completes parameter initialization. This is what is “lazy” about a lazy module: the computation itself is not deferred; only parameter initialization is delayed until the first forward pass.
Lazy modules are especially useful in convolutional networks. We can write the convolutional feature extractor first and let nn.LazyLinear adapt automatically to the flattened dimension:
```python
class LazyCNN(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.LazyConv2d(8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.LazyConv2d(16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = nn.LazyLinear(num_classes)

    def forward(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = x.flatten(start_dim=1)
        x = self.classifier(x)
        return x
```
When creating the model, we do not need to know the input dimension of classifier:
```python
lazy_cnn = LazyCNN(num_classes=10)
print(lazy_cnn)
```
After the first forward pass, classifier is initialized concretely:
```python
x = torch.randn(4, 1, 28, 28)
y = lazy_cnn(x)
print(lazy_cnn.classifier)
```
In this example, the input image size is \(28 \times 28\). After two MaxPool2d(kernel_size=2) operations, the spatial size changes from \(28 \times 28\) to \(7 \times 7\), and the number of channels becomes 16. Therefore, the flattened dimension is:
\[ 16 \times 7 \times 7 = 784 \]
nn.LazyLinear obtains this 784 automatically from the real input during the first forward pass.
Besides LazyLinear, PyTorch also provides lazy versions of convolution layers, such as:
```python
nn.LazyConv1d(out_channels, kernel_size)
nn.LazyConv2d(out_channels, kernel_size)
nn.LazyConv3d(out_channels, kernel_size)
```
Ordinary convolution layers require in_channels:
```python
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
```
LazyConv2d can defer in_channels until the first forward pass:
```python
lazy_conv = nn.LazyConv2d(out_channels=16, kernel_size=3, padding=1)
print(lazy_conv)

x = torch.randn(4, 3, 32, 32)
y = lazy_conv(x)
print(lazy_conv)
print('Output shape:', y.shape)
```
After seeing an input with shape (4, 3, 32, 32), LazyConv2d knows that the input channel count is 3, so it initializes the weight with the corresponding shape:
```python
for name, param in lazy_conv.named_parameters():
    print(f'{name}: {param.size()}')
```
Lazy modules are convenient, but several points matter in practice.
First, before the first forward pass, parameters in a lazy module do not have real shapes. Operations that depend on parameter shape cannot be done too early. For example, you cannot write complex logic based on weight.shape before initialization.
Second, before saving a model, it is better to run one real batch or dummy batch through the model so all lazy parameters are initialized. Otherwise, state_dict() may contain uninitialized parameters, which makes later loading and use more cumbersome.
```python
model = LazyCNN(num_classes=10)
x = torch.randn(1, 1, 28, 28)
y = model(x)  # dry run: initializes every lazy parameter

for key, value in model.state_dict().items():
    print(f'{key}: {value.size()}')
```
Third, a lazy module reduces the burden of manually calculating input dimensions; it does not change the model structure. After the first forward pass, it becomes an ordinary module with fixed shapes. Later inputs must match the dimensions inferred during the first pass.
For example, the previous LazyLinear first saw an input whose last dimension was 3, so later it can only accept inputs whose last dimension is 3:
```python
x1 = torch.randn(4, 3)
y1 = lazy_linear(x1)

try:
    x2 = torch.randn(4, 5)
    y2 = lazy_linear(x2)
except RuntimeError as err:
    print('RuntimeError:', err)
```
Lazy modules are most suitable when:
- you know the output dimension, such as the number of classes, hidden dimension, or convolution output channels;
- you do not want to calculate the input dimension by hand, or the input dimension is naturally determined by previous modules;
- you are willing to run one forward pass before real training and saving so initialization is completed.
In short, a lazy module is not a new computation method but a more convenient initialization method. It moves part of the shape information from __init__() to the first forward(), reducing hard-coded input dimensions in model code.
2.5.10 Put These Concepts into One Model
Now combine the previous concepts in a slightly more complete model.
This model contains:
- two linear layers as learnable submodules;
- one Dropout layer to show training/evaluation behavior;
- input normalization statistics mean and std as buffers;
- a cache mask as a non-persistent buffer.
```python
class DemoNet(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.register_buffer('mean', torch.zeros(input_dim))
        self.register_buffer('std', torch.ones(input_dim))
        self.register_buffer(
            'cache_mask',
            torch.ones(input_dim, dtype=torch.bool),
            persistent=False,
        )
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = (x - self.mean) / self.std
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
```
Create the model:
```python
demo = DemoNet(input_dim=3, hidden_dim=8, output_dim=2)
```
Inspect parameters:
```python
print('Parameters:')
for name, param in demo.named_parameters():
    print(f'{name}: {param.size()}')
```
Inspect buffers:
```python
print('Buffers:')
for name, buffer in demo.named_buffers():
    print(f'{name}: {buffer.size()}')
```
Inspect state dict:
```python
print('State dict:')
for key, value in demo.state_dict().items():
    print(f'{key}: {value.size()}')
```
Notice that cache_mask is a buffer, but because persistent=False, it does not appear in state_dict().
Then inspect submodules:
```python
print('Submodules:')
for name, module in demo.named_modules():
    print(repr(name), '->', type(module).__name__)
```
These outputs together show the core management ability of nn.Module:
```text
parameters()    -> find learnable parameters
buffers()       -> find non-parameter state
modules()       -> find submodule structure
state_dict()    -> export saveable model state
train()/eval()  -> switch module behavior
```
At this point, nn.Module can be understood more fully: it is not one isolated API, but the center of PyTorch’s model system. Once an object inherits from nn.Module, it enters PyTorch’s model management system.
2.5.11 Summary
This section started from a hand-written linear transformation and explained why nn.Module is needed.
nn.Module is not only an object with forward(). It also manages parameters, buffers, and submodules inside a model. nn.Parameter represents model parameters that need to be updated by the optimizer. Buffer represents model state that should be saved or moved across devices, but should not be updated by the optimizer.
The relation between nn.Module and nn.functional can be understood as follows: the former is a stateful layer or model, while the latter is a stateless function. Many modules call the corresponding functional operation inside forward().
Module can also contain Module, and PyTorch recursively manages parameters and buffers inside submodules. With parameters(), buffers(), children(), and modules(), we can inspect the internal model structure. With state_dict(), we can save and load model state.
Finally, train() and eval() control module behavior, not whether Autograd records the computation graph. During validation and inference, model.eval() is usually used together with torch.no_grad() or torch.inference_mode().
The core role of nn.Module is to organize loose tensor computation into a real model: it knows which tensors should be learned, which tensors are only state, which layers belong to it, and how these states should be saved, loaded, and switched between training and evaluation behavior.