Chapter 2.6 Optimizer in PyTorch: From Manual Updates to Parameter Groups and State Management
The previous sections established one key point: as long as the computation graph is recorded correctly, model parameters receive gradients after loss.backward() is called.
However, gradients do not modify parameters by themselves. They only describe how the loss changes locally, and moving the parameters against the gradient decreases the loss. The component that actually updates parameters according to gradients is the optimizer.
The simplest form of gradient descent can be written as:
\[ \theta \leftarrow \theta - \eta \nabla_\theta L \]
Here, \(\theta\) denotes the parameters, \(\nabla_\theta L\) denotes the parameter gradients, and \(\eta\) is the learning rate.
Without an optimizer, parameters can also be updated manually:
import torch
import torch.nn as nn
import torch.nn.functional as F
x = torch.randn(8, 3)
y = torch.randn(8, 1)
linear = nn.Linear(3, 1)
pred = linear(x)
loss = F.mse_loss(pred, y)
loss.backward()
lr = 0.1
with torch.no_grad():
    for param in linear.parameters():
        param -= lr * param.grad
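To see the update rule with concrete numbers, here is a minimal one-parameter sketch. The loss \(L = w^2\) and the starting value \(w = 3\) are illustrative assumptions: the gradient is \(2w = 6\), so one step with \(\eta = 0.1\) should give \(3 - 0.1 \times 6 = 2.4\).

```python
import torch

# Illustrative one-parameter problem: L(w) = w^2, so dL/dw = 2w
w = torch.tensor(3.0, requires_grad=True)
lr = 0.1

loss = w ** 2
loss.backward()            # w.grad = 2 * 3 = 6

with torch.no_grad():      # the update itself is not recorded by Autograd
    w -= lr * w.grad       # 3 - 0.1 * 6 = 2.4

print(w.item(), w.grad.item())
```

The printed weight is 2.4 up to float32 rounding.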
This code works, but it quickly leads to several engineering questions:
- Should gradients be cleared before every update?
- What if different parameters need different learning rates?
- If the optimizer uses momentum, where should the momentum state be stored?
- If training needs to be resumed, should optimizer state be saved together with model state?
- If the model is large, will updating parameters one by one be too slow?
These are the problems torch.optim is designed to handle.
This section starts from the most basic optimizer.step() and then explains what PyTorch optimizers manage.
from pprint import pprint
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
print('PyTorch version:', torch.__version__)
2.6.1 From Manual Updates to optimizer.step()
Start with a simple linear model:
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
When creating an optimizer, the two most important pieces of information are:
- which parameters should be updated;
- what rule should be used to update them.
The line:
optimizer = optim.SGD(model.parameters(), lr=0.1)
means: pass all parameters returned by model.parameters() to SGD, and update them with learning rate 0.1.
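model.parameters() hands the optimizer references to the model's own parameter tensors, not copies. A quick check through the public param_groups attribute makes this visible, and it is also why step() can modify the model directly:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Everything passed in ends up in a single group; 'params' is a list of tensors
group = optimizer.param_groups[0]
print(len(group['params']))                  # weight and bias
print(group['params'][0] is model.weight)    # the same object, not a copy
print(group['lr'])
```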
A minimal training step usually looks like this:
x = torch.randn(8, 3)
y = torch.randn(8, 1)
pred = model(x)
loss = F.mse_loss(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Loss:', loss.item())
These three lines form the core of almost every training step:
optimizer.zero_grad()
loss.backward()
optimizer.step()
They correspond to three actions in training:
- clear old gradients to avoid accumulation;
- compute new gradients from the current loss and write them into each parameter's .grad;
- update parameters using the new gradients and optimizer state.
Notice that backward() only writes gradients into the .grad attribute of each parameter. The operation that changes parameter values is optimizer.step().
We can directly check whether a parameter changes before and after an update:
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
before = model.weight.detach().clone()
pred = model(x)
loss = F.mse_loss(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
after = model.weight.detach().clone()
flag = before.allclose(after)
max_err = (before - after).abs().max().item()
print('Is the parameter unchanged?', flag)
print('Max absolute difference:', max_err)
This shows that the optimizer has indeed modified the parameter.
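Going one step further, for plain SGD (no momentum, no weight decay) the change should be exactly \(\theta - \eta g\). This sketch computes that prediction from .grad before calling step() and compares it with what the optimizer produced:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(3, 1)
lr = 0.1
optimizer = optim.SGD(model.parameters(), lr=lr)   # plain SGD
x, y = torch.randn(8, 3), torch.randn(8, 1)

optimizer.zero_grad()
F.mse_loss(model(x), y).backward()

# Prediction of the update rule, computed before step() modifies the weight
expected = model.weight.detach() - lr * model.weight.grad
optimizer.step()

print('step() == theta - lr * grad?', torch.allclose(model.weight.detach(), expected))
```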
2.6.2 Why zero_grad Is Needed Before Each Update
One easy detail to miss is that gradients in PyTorch are accumulated by default rather than overwritten. If backward() is called twice in a row, the gradient from the second call does not replace the first one. It is added to the existing .grad.
Use a very small example to observe this behavior:
w = torch.tensor([1.0], requires_grad=True)
loss1 = w ** 2
loss1.backward()
print('After first backward:', w.grad)
loss2 = w ** 2
loss2.backward()
print('After second backward:', w.grad)
After the first backward pass, the gradient is 2. After the second backward pass, the gradient becomes 4. This does not mean the new gradient is 4. It means the old gradient 2 and the new gradient 2 have been accumulated.
This is why training loops usually write:
optimizer.zero_grad()
loss.backward()
optimizer.step()
If gradients are not cleared, gradients from every batch will be accumulated on top of previous batches, and the parameter update will no longer correspond to the current batch's loss.
Accumulating gradients is not always a mistake, however. Sometimes we intentionally accumulate gradients from several mini-batches and update the parameters only once, for example to simulate a larger batch than fits in memory. This technique is called gradient accumulation.
For example:
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
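for i in range(4):
    x = torch.randn(8, 3)
    y = torch.randn(8, 1)
    pred = model(x)
    loss = F.mse_loss(pred, y)
    # Divide by the number of mini-batches so the accumulated gradient is their average
    loss = loss / 4
    # No zero_grad() inside the loop: each backward() adds to the existing .grad
    loss.backward()
optimizer.step()
Here we intentionally do not clear gradients after every mini-batch. Instead, gradients from 4 mini-batches are accumulated, and step() is called once.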
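Why divide each loss by 4? F.mse_loss averages over its batch, so with equal-sized mini-batches the mean over all 32 samples equals the average of the 4 per-batch means. This sketch checks that the accumulated gradient then matches the gradient of one full batch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(3, 1)
x = torch.randn(32, 3)
y = torch.randn(32, 1)

# Gradient of the full 32-sample batch in one backward pass
model.zero_grad()
F.mse_loss(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Same samples as 4 mini-batches of 8, each loss scaled by 1/4, gradients accumulated
model.zero_grad()
for xb, yb in zip(x.chunk(4), y.chunk(4)):
    (F.mse_loss(model(xb), yb) / 4).backward()
accum_grad = model.weight.grad.clone()

print('Same gradient?', torch.allclose(full_grad, accum_grad, atol=1e-6))
```

Without the division, the accumulated gradient would be about 4 times larger, which acts like a 4 times larger learning rate.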
More precisely, zero_grad() is not required by backward() itself. It is a response to PyTorch’s default gradient accumulation behavior. If the current update should use only gradients from the current batch, old gradients should be cleared before backward().
2.6.3 What Is set_to_none?
By default, optimizer.zero_grad() sets each parameter’s .grad to None:
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 3)
y = torch.randn(8, 1)
loss = F.mse_loss(model(x), y)
loss.backward()
flag1 = model.weight.grad is None
print('Is grad None before zero_grad?', flag1)
optimizer.zero_grad(set_to_none=True)
flag2 = model.weight.grad is None
print('Is grad None after zero_grad?', flag2)
Sometimes we also see:
optimizer.zero_grad(set_to_none=False)
This sets gradients to 0 instead of None:
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.zero_grad(set_to_none=False)
print(model.weight.grad)
Both approaches clear old gradients, but their meanings are slightly different:
- .grad = None: this parameter currently has no gradient;
- .grad = 0: this parameter has a gradient tensor, but all of its values are zero.
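In most training code, the default set_to_none=True is enough. It usually uses less memory, because the old gradient tensors can be released and a new one is allocated during the next backward pass. The difference matters in two situations. First, code that assumes .grad is always a tensor rather than None (for example, code that calls param.grad.norm() right after zero_grad()) will fail. Second, torch.optim optimizers treat the two cases differently: a parameter whose .grad is None is skipped by step(), while a parameter whose .grad is zero is still stepped with a zero gradient, so momentum or weight decay can still move it.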
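The skip-versus-step difference is easy to observe with momentum. In this sketch (the model, data, and the helper name param_moves_after_cleared_step are illustrative), one real step fills SGD's momentum buffer; then gradients are cleared and step() is called again without a new backward():

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def param_moves_after_cleared_step(set_to_none: bool) -> bool:
    """Hypothetical helper: does step() still move the weight after zero_grad()?"""
    torch.manual_seed(0)
    model = nn.Linear(3, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    x, y = torch.randn(8, 3), torch.randn(8, 1)

    # First step: a real gradient fills the momentum buffer
    optimizer.zero_grad()
    F.mse_loss(model(x), y).backward()
    optimizer.step()

    # Clear gradients, then step again with no new backward()
    optimizer.zero_grad(set_to_none=set_to_none)
    before = model.weight.detach().clone()
    optimizer.step()
    return not torch.equal(before, model.weight.detach())

print('set_to_none=True  -> moved?', param_moves_after_cleared_step(True))
print('set_to_none=False -> moved?', param_moves_after_cleared_step(False))
```

With None gradients the parameter is skipped; with zero gradients the momentum buffer still pushes it.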
2.6.4 Parameter Groups: Different Parameters Can Use Different Learning Rates
So far, all model parameters have been passed to the same optimizer with the same learning rate:
optimizer = optim.SGD(model.parameters(), lr=1e-3)
In real training, different parts of the model often need different hyperparameters.
For example, when fine-tuning a pretrained model, the backbone may use a smaller learning rate, while the final classification head may use a larger one. The backbone has already learned many general features, so we do not want it to change too quickly. The head is randomly initialized and needs to learn the current task faster.
This can be handled with parameter groups:
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(10, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
        )
        self.head = nn.Linear(32, 2)
    def forward(self, x: Tensor) -> Tensor:
        x = self.backbone(x)
        return self.head(x)
model = TinyModel()
optimizer = optim.SGD(
    [
        {'params': model.backbone.parameters(), 'lr': 1e-4},
        {'params': model.head.parameters(), 'lr': 1e-3},
    ],
    weight_decay=1e-2,
)
for i, group in enumerate(optimizer.param_groups):
    print(
        f'Parameter group {i}: '
        f'learning rate = {group["lr"]}, '
        f'weight decay = {group["weight_decay"]}'
    )
Here, the value passed to the optimizer is no longer a simple parameter iterator. It is a list of dictionaries. Each dictionary describes one group of parameters and the optimization hyperparameters for that group.
If a parameter group does not explicitly set a hyperparameter, it uses the default value from the optimizer constructor. In the example above, neither parameter group sets weight_decay separately, so both groups use the outer weight_decay=1e-2.
Parameter groups are also commonly used to disable weight decay for certain parameters. For example, many training scripts disable weight decay for bias parameters and normalization layer parameters:
def split_weight_decay_params(model: nn.Module):
    # Classify by module type: inside nn.Sequential, parameter names are
    # '1.weight', '1.bias', ..., so checking the name for 'norm' would miss them
    norm_types = (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d)
    decay = []
    no_decay = []
    for module in model.modules():
        # recurse=False: see each parameter together with the module that owns it
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            if name == 'bias' or isinstance(module, norm_types):
                no_decay.append(param)
            else:
                decay.append(param)
    return [
        {'params': decay, 'weight_decay': 1e-2},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
model = nn.Sequential(
    nn.Linear(10, 32),
    nn.LayerNorm(32),
    nn.Linear(32, 2),
)
optimizer = optim.SGD(split_weight_decay_params(model), lr=1e-3)
for i, group in enumerate(optimizer.param_groups):
    print(
        f'Parameter group {i}: '
        f'weight decay = {group["weight_decay"]}, '
        f'number of params = {len(group["params"])}'
    )
Therefore, an optimizer does not have to update the entire model with one set of hyperparameters. Internally, it maintains groups of parameters. Each group can have its own learning rate, weight decay, momentum, and other settings. When needed, model parameters can be split into different groups and managed by the same optimizer.
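Because param_groups is a plain list of dictionaries, groups can also be added or edited after the optimizer is created. This sketch (the backbone/head split is illustrative) starts by optimizing only the head, later adds the backbone with add_param_group(), and then halves every group's learning rate in place, which is essentially how learning-rate schedulers adjust the optimizer:

```python
import torch.nn as nn
import torch.optim as optim

backbone = nn.Linear(10, 32)
head = nn.Linear(32, 2)

# Start with only the head; momentum becomes the default for every group
optimizer = optim.SGD(head.parameters(), lr=1e-3, momentum=0.9)

# Later, e.g. when unfreezing, add the backbone with a smaller lr.
# Keys not given here (momentum, weight_decay, ...) are filled from the defaults.
optimizer.add_param_group({'params': backbone.parameters(), 'lr': 1e-4})

# Each group is an ordinary dict, so hyperparameters can be edited in place
for group in optimizer.param_groups:
    group['lr'] *= 0.5

for i, group in enumerate(optimizer.param_groups):
    print(f'group {i}: lr = {group["lr"]}, momentum = {group["momentum"]}')
```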
2.6.5 Optimizers Are Not Only Formulas; They Also Have State
With the simplest SGD and no momentum, the parameter update depends only on the current gradient \(g\):
\[ \theta \leftarrow \theta - \eta g \]
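However, many optimizers depend not only on the current gradient, but also on historical information. For example, SGD with momentum (in PyTorch's default form, without dampening or Nesterov) maintains a momentum buffer \(v_t\), starting from \(v_0 = 0\), and moves the parameters along it instead of along the raw gradient:
\[ v_t = \mu v_{t-1} + g_t, \qquad \theta \leftarrow \theta - \eta v_t \]
where \(\mu\) is the momentum coefficient and \(g_t\) is the gradient at step \(t\).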
Adam and AdamW maintain first-moment and second-moment estimates of gradients.
These historical values are not model parameters, but they affect future updates. Therefore, optimizers also have state.
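The momentum rule above can be checked against optim.SGD directly. This sketch uses an illustrative loss \(L = 2w\), whose gradient is always 2, and runs the hand-written recurrence next to the optimizer:

```python
import torch
import torch.optim as optim

lr, mu = 0.1, 0.9
w = torch.zeros(1, requires_grad=True)
optimizer = optim.SGD([w], lr=lr, momentum=mu)

w_manual, v = 0.0, 0.0
for t in range(1, 4):
    optimizer.zero_grad()
    loss = (2.0 * w).sum()      # constant gradient: dL/dw = 2
    loss.backward()
    optimizer.step()

    # Same rule by hand: v_t = mu * v_{t-1} + g_t, then theta <- theta - lr * v_t
    g = 2.0
    v = mu * v + g
    w_manual -= lr * v
    print(t, round(w.item(), 6), round(w_manual, 6))
```

Even though the gradient never changes, each step is larger than the last (0.2, 0.38, 0.542), because the buffer v carries history from earlier steps. That buffer is exactly the optimizer state discussed here.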
We can inspect an optimizer’s state_dict() directly:
model = nn.Linear(3, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
pprint(optimizer.state_dict(), sort_dicts=False)
Right after the optimizer is created, state is usually empty because no update step has been executed yet.
After one update, inspect it again:
x = torch.randn(8, 3)
y = torch.randn(8, 1)
loss = F.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
state_dict = optimizer.state_dict()
pprint(state_dict, sort_dicts=False)
An optimizer state_dict usually has two parts:
- state: the internal optimizer state of each parameter;
- param_groups: the configuration of each parameter group, such as learning rate, weight decay, and betas.
For AdamW, we usually see state entries such as step, exp_avg, and exp_avg_sq. These are historical values used by AdamW during updates. Their specific meanings will be discussed later when AdamW is introduced.
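The live state can also be read per parameter: optimizer.state is keyed by the parameter tensor itself. This sketch assumes the default amsgrad=False, under which each parameter should carry step, exp_avg, and exp_avg_sq:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

model = nn.Linear(3, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
optimizer.zero_grad()
F.mse_loss(model(torch.randn(8, 3)), torch.randn(8, 1)).backward()
optimizer.step()

# optimizer.state maps each parameter tensor to its own dict of state
for name, param in model.named_parameters():
    state = optimizer.state[param]
    print(name, sorted(state.keys()),
          'exp_avg shape:', tuple(state['exp_avg'].shape),
          'param shape:', tuple(param.shape))
```

exp_avg and exp_avg_sq have the same shape as the parameter, so AdamW's state takes roughly twice as much memory as the parameters it optimizes.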
This also explains why saving only model parameters is not enough when resuming training. If only model.state_dict() is saved, model weights can be restored, but SGD momentum and Adam/AdamW first-moment and second-moment estimates are lost.
Therefore, a more complete checkpoint usually contains:
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')
Correspondingly, both parts should be restored when loading:
model = nn.Linear(3, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
This lets training continue from the interrupted point as closely as possible, instead of restarting optimization from the same model weights with fresh optimizer state.
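One detail worth knowing: the saved optimizer state refers to parameters by their position (0, 1, ...) rather than by name, so the new optimizer must be built over the same parameters in the same order. This round-trip sketch uses an in-memory io.BytesIO buffer as a stand-in for the checkpoint file and a hypothetical build() helper:

```python
import io
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def build():
    # Hypothetical helper: rebuild model and optimizer identically, so parameter order matches
    model = nn.Linear(3, 1)
    return model, optim.AdamW(model.parameters(), lr=1e-3)

torch.manual_seed(0)
model, optimizer = build()
x, y = torch.randn(8, 3), torch.randn(8, 1)
for _ in range(3):
    optimizer.zero_grad()
    F.mse_loss(model(x), y).backward()
    optimizer.step()

buffer = io.BytesIO()  # in-memory stand-in for 'checkpoint.pth'
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, buffer)
buffer.seek(0)

new_model, new_optimizer = build()
checkpoint = torch.load(buffer)
new_model.load_state_dict(checkpoint['model'])
new_optimizer.load_state_dict(checkpoint['optimizer'])

# State is keyed by parameter index, not by parameter name
saved_keys = list(checkpoint['optimizer']['state'].keys())
print('Saved state keys:', saved_keys)
restored = torch.equal(
    optimizer.state[model.weight]['exp_avg'],
    new_optimizer.state[new_model.weight]['exp_avg'],
)
print('exp_avg restored?', restored)
```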
2.6.6 foreach and fused: Different Implementations of the Same Optimizer
Sometimes optimizer constructors include parameters like this:
optimizer = optim.AdamW(
model.parameters(),
lr=1e-3,
foreach=True,
)
or:
optimizer = optim.AdamW(
model.parameters(),
lr=1e-3,
fused=True,
)
These arguments do not change the mathematics of AdamW. They only choose how the parameter update is executed.
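A simple way to convince yourself of this is to train two identical copies of a model, one with foreach=False and one with foreach=True, and compare the results. This sketch runs on CPU, where foreach is accepted but brings little speed benefit; results are compared with a small tolerance in case the two code paths round differently:

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
model_loop = nn.Linear(10, 2)
model_foreach = copy.deepcopy(model_loop)   # identical starting weights

opt_loop = optim.AdamW(model_loop.parameters(), lr=1e-3, foreach=False)
opt_foreach = optim.AdamW(model_foreach.parameters(), lr=1e-3, foreach=True)

x, y = torch.randn(16, 10), torch.randn(16, 2)
for _ in range(5):
    for model, opt in ((model_loop, opt_loop), (model_foreach, opt_foreach)):
        opt.zero_grad()
        F.mse_loss(model(x), y).backward()
        opt.step()

same = all(
    torch.allclose(p, q, atol=1e-6)
    for p, q in zip(model_loop.parameters(), model_foreach.parameters())
)
print('Same parameters after 5 steps?', same)
```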
In PyTorch, the same optimizer can roughly have three implementation routes:
- for-loop: the traditional approach, which updates parameter tensors one by one;
- foreach: packs a group of tensors into lists and calls batched (multi-tensor) operations on them;
- fused: fuses multiple update operations into fewer kernels.
The easiest one to understand is the for-loop approach. It processes parameters one by one:
for param in params:
    update(param)
This approach is simple and general. However, if the model has many parameter tensors, it creates many small operations, which can be inefficient on GPUs.
The idea behind foreach is: instead of updating tensors one by one, pass many tensors as a list to the lower-level implementation and process them together. It is usually faster than an ordinary for-loop, especially when there are many parameter tensors on GPU. However, foreach is not free, because it often needs to keep intermediate tensor lists, so peak memory usage can be higher.
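The difference can be sketched for a plain SGD step. The foreach route below calls torch._foreach_add_, an internal, underscore-prefixed PyTorch operation that the multi-tensor SGD implementation uses; it is shown only to illustrate the idea and is not a stable public API for user code:

```python
import torch

torch.manual_seed(0)
lr = 0.1
params = [torch.randn(4), torch.randn(2, 3), torch.randn(5)]
grads = [torch.randn_like(p) for p in params]

# for-loop route: one small in-place op per tensor
loop_params = [p.clone() for p in params]
for p, g in zip(loop_params, grads):
    p.add_(g, alpha=-lr)

# foreach route: a single call over the whole list (internal op)
foreach_params = [p.clone() for p in params]
torch._foreach_add_(foreach_params, grads, alpha=-lr)

print(all(torch.allclose(a, b) for a, b in zip(loop_params, foreach_params)))
```

Both routes compute \(\theta - \eta g\) for every tensor; on GPU the foreach call can replace many small kernel launches with far fewer.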
fused goes further. It aims to fuse multiple operations inside an optimizer update, reducing repeated kernel launches and intermediate reads and writes. Intuitively, foreach processes many tensors at once, while fused combines multiple operations within one update.
Therefore, in well-supported CUDA scenarios, fused=True may be faster. But it has stricter requirements on device, dtype, and optimizer implementation support.
We can first understand them as three execution strategies:
- for-loop: the most basic and most compatible;
- foreach: usually faster, but may use more memory;
- fused: more aggressive and may be fastest, but support is more limited.
In practice, if there is no special requirement, it is usually fine to let PyTorch choose the default behavior. Only when optimizing large-model training performance, memory usage, or compatibility issues do we need to manually specify foreach or fused.
For the support status of foreach and fused across different PyTorch optimizers, the official documentation provides more detailed notes and compatibility lists. See torch.optim - Algorithms.
The code below only demonstrates how these arguments are passed. Whether fused=True is supported depends on the machine, PyTorch version, and device.
model = nn.Linear(10, 2)
optimizer = optim.AdamW(
model.parameters(),
lr=1e-3,
foreach=False,
fused=False,
)
print(optimizer)
2.6.7 optimizer.step() Does Not Record the Computation Graph by Default
When manually updating parameters earlier, we wrote:
with torch.no_grad():
    for param in linear.parameters():
        param -= lr * param.grad
This is because ordinary training usually does not need Autograd to record the parameter update itself.
In other words, we generally care about how loss produces gradients with respect to parameters. We do not usually care about differentiating through the optimizer.step() update process itself.
PyTorch optimizers follow this default logic. By default, optimizer.step() updates parameters in a context that does not record gradients.
This is reasonable for ordinary training. If every parameter update were recorded into the computation graph, memory usage would grow quickly and training would become more complex. However, some advanced scenarios do need to differentiate through the optimization process. Examples include meta-learning, differentiable optimization, learning learning rates, or treating several gradient update steps as part of a larger computation graph.
In such cases, the optimizer update process must also participate in Autograd. The corresponding PyTorch optimizer argument is called differentiable. When it is set to True, the optimizer step is tracked, allowing gradients to be computed through the updated parameters.
For example:
model = nn.Linear(3, 1)
optimizer = optim.SGD(
model.parameters(),
lr=0.1,
differentiable=True,
)
flag = optimizer.defaults['differentiable']
print('Is optimizer step differentiable?', flag)
However, differentiable=True is not needed in regular training. It makes the optimizer step part of the tracked computation, usually increases memory usage, and may require more careful code.
Therefore, for most training scenarios, the default differentiable=False is enough. Only consider differentiable=True in special cases where gradients must flow through the parameter update process.
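To see what "gradients flow through the update" means, here is a hand-written functional SGD step, not optim.SGD(differentiable=True), in the spirit of learning a learning rate. The loss \(L = w^2\) and the values are illustrative: with \(w = 1\) and \(\eta = 0.1\), the updated weight is \(w' = 1 - 0.1 \times 2 = 0.8\), and the outer loss \(w'^2\) has \(\partial (w'^2) / \partial \eta = 2w' \cdot (-g) = 2 \times 0.8 \times (-2) = -3.2\).

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
lr = torch.tensor(0.1, requires_grad=True)   # the learning rate is itself a leaf tensor

loss = (w ** 2).sum()
# create_graph=True keeps the gradient computation itself differentiable
(g,) = torch.autograd.grad(loss, w, create_graph=True)

w_new = w - lr * g          # update recorded in the graph (no torch.no_grad())
outer = (w_new ** 2).sum()  # loss evaluated after the update
outer.backward()

print('d outer / d lr =', lr.grad.item())
```

A negative gradient with respect to lr says that a slightly larger learning rate would have lowered the post-update loss here.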
2.6.8 A Complete Optimization Step
Now we can put the previous pieces together and write a complete optimization step.
model = nn.Sequential(
nn.Linear(10, 32),
nn.ReLU(),
nn.Linear(32, 1),
)
optimizer = optim.AdamW(
[
{'params': model[0].parameters(), 'lr': 1e-3},
{'params': model[2].parameters(), 'lr': 1e-2},
],
weight_decay=1e-2,
)
x = torch.randn(16, 10)
y = torch.randn(16, 1)
model.train()
pred = model(x)
loss = F.mse_loss(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Loss:', loss.item())
Several things happen behind this code:
- model(x): runs the forward pass, computes the prediction pred, and builds the computation graph;
- F.mse_loss(pred, y): computes the scalar loss from the prediction and the target;
- optimizer.zero_grad(): clears gradients left over from any previous step, so the following backward() call does not add to them;
- loss.backward(): computes gradients from the current loss and accumulates them into each model parameter's .grad;
- optimizer.step(): updates parameter values using each parameter's .grad and the optimizer's internal state.
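The same sequence can be wrapped in a small function and called repeatedly. In this sketch, train_step is a hypothetical helper name (not a PyTorch API), and the model is fitted to one fixed random batch only to show that the loss goes down as the steps repeat:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def train_step(model, optimizer, x, y) -> float:
    """Hypothetical helper: one optimization step, returning the loss value."""
    model.train()
    pred = model(x)                 # forward pass, builds the graph
    loss = F.mse_loss(pred, y)
    optimizer.zero_grad()           # drop gradients from the previous step
    loss.backward()                 # accumulate fresh gradients into .grad
    optimizer.step()                # update parameters and optimizer state
    return loss.item()

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-2)
x, y = torch.randn(16, 10), torch.randn(16, 1)

losses = [train_step(model, optimizer, x, y) for _ in range(50)]
print(f'first loss = {losses[0]:.4f}, last loss = {losses[-1]:.4f}')
```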
The optimizer stores not only hyperparameters, but sometimes also historical state. Parameter groups determine how different parameters are updated. foreach and fused determine how the update process is executed efficiently. differentiable=True determines whether the update process itself enters the computation graph.
2.6.9 Summary
This section started from manual gradient descent and explained the role of PyTorch optimizers. backward() computes gradients and accumulates them into the parameters' .grad fields; optimizer.step() actually modifies parameters according to those gradients.
Because PyTorch accumulates gradients by default, regular training needs to call optimizer.zero_grad() before each backward pass. If gradients are intentionally not cleared, gradient accumulation can be implemented.
An optimizer does not have to receive one uniform parameter set. It can also receive multiple parameter groups. Different parameter groups can use different learning rates, weight decay values, and other hyperparameters, which is common in fine-tuning and large-model training.
Optimizers also have their own state. Optimizers such as AdamW store historical gradient statistics, so resuming training usually requires saving both model.state_dict() and optimizer.state_dict().
Finally, we distinguished several optimizer implementation strategies. foreach and fused do not change the mathematical meaning of the optimization algorithm; they change how the update process is executed. For ordinary training, the default settings are usually enough. Further control is mainly needed when performance, memory, or differentiable optimization matters.
The next section will put data loading, model definition, loss computation, and optimizer updates into a complete training loop, showing how they work together.