Chapter 2.2 PyTorch Dimension Transformation Operations
2.2.1 Reshaping
Core functions: view, reshape, flatten
2.2.1.1 torch.view / torch.reshape
- Academic definition: Reassign the shape of a tensor while keeping the total number of elements unchanged. view requires a memory layout compatible with the new shape (a contiguous tensor always satisfies this) and never copies data, while reshape returns a view when possible and creates a copy when that is necessary.
- Intuitive explanation: Think of the same piece of clay being molded into another form: nothing is added or removed, but the outline changes.
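The difference between view and reshape shows up on a non-contiguous tensor. The following is a minimal sketch, assuming a small 2x3 tensor whose transpose is used to produce a non-contiguous layout; the variable names are illustrative only.

```python
import torch

x = torch.arange(6).reshape(2, 3)   # contiguous, shape (2, 3)
xt = x.t()                          # transposed view, shape (3, 2), non-contiguous

# view cannot reinterpret this non-contiguous layout as a flat vector
try:
    xt.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True

# reshape falls back to copying the data when a view is impossible
flat = xt.reshape(6)

# On a contiguous tensor, view shares storage with the original
v = x.view(3, 2)
shares_storage = v.data_ptr() == x.data_ptr()
```

Calling xt.contiguous().view(6) would also work, at the cost of the same copy that reshape makes implicitly.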
- Simple example (syntax): Arrange a 12-element vector into a matrix.
import torch
x = torch.arange(12) # shape: (12,)
y = x.view(3, 4) # shape: (3, 4)
- Advanced example (LLM Multi-Head Attention): Scenario: In a Transformer, the hidden dimension is commonly divided across multiple attention heads. Suppose batch_size=32, seq_len=128, and hidden_dim=768, which corresponds to 12 heads with 64 dimensions each.
# Input: [Batch, Seq_Len, Hidden_Dim]
x = torch.randn(32, 128, 768)
# Output: [Batch, Seq_Len, Num_Heads, Head_Dim]
# Use view to split the last dimension
x_heads = x.view(32, 128, 12, 64)
2.2.1.2 torch.flatten
- Academic definition: Merge a selected range of dimensions into a single dimension.
- Intuitive explanation: Compress several layers into one flat surface while keeping the underlying content.
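To illustrate what "a selected range of dimensions" means, the sketch below passes both start_dim and end_dim. The 4D shape is an arbitrary assumption chosen so that the merged and untouched dimensions are easy to tell apart.

```python
import torch

x = torch.randn(2, 3, 4, 5)

# Merge only dims 1 and 2; dims 0 and 3 are left untouched
y = x.flatten(start_dim=1, end_dim=2)
print(y.shape)  # torch.Size([2, 12, 5])

# With no arguments, every dimension is merged into one
z = x.flatten()
print(z.shape)  # torch.Size([120])
```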
- Simple example (syntax):
x = torch.randn(2, 3, 4)
# Flatten the last two dims: (2, 3, 4) -> (2, 12)
y = x.flatten(start_dim=1)
- Advanced example (CNN/RL policy network): Scenario: In reinforcement learning with image observations, feature maps from a CNN are usually flattened before entering a fully connected layer.
# Input: [Batch, Channels, Height, Width] -> [32, 64, 7, 7]
features = torch.randn(32, 64, 7, 7)
# Flatten relevant dims for Linear Layer input
# Output: [Batch, Features] -> [32, 64*7*7] -> [32, 3136]
flat_features = features.flatten(start_dim=1)
2.2.2 Swapping
Core functions: permute, transpose
2.2.2.1 torch.permute
- Academic definition: Reorder every dimension according to the index sequence you provide.
- Intuitive explanation: It is like assigning new seats to all dimensions: each axis moves to the position you specify.
- Simple example (syntax):
x = torch.randn(2, 3, 4)
# Change to (4, 2, 3): original dim 2 first, dim 0 second, dim 1 last
y = x.permute(2, 0, 1)
- Advanced example (Transformer dimension adjustment): Scenario: In multi-head attention, Num_Heads is often placed before Seq_Len so matrix multiplication can run in parallel over heads.
# Continuing from the view example: [Batch, Seq_Len, Num_Heads, Head_Dim]
x_heads = torch.randn(32, 128, 12, 64)
# Target: [Batch, Num_Heads, Seq_Len, Head_Dim]
# Swap dimension 1 and dimension 2
query = x_heads.permute(0, 2, 1, 3)
2.2.2.2 torch.transpose
- Academic definition: Exchange exactly two selected dimensions.
- Intuitive explanation: It is less general than permute, but cleaner when only one pair of axes needs to change places.
- Advanced example (preparing matrix multiplication): Scenario: When computing attention scores, the Key tensor usually needs its last two dimensions swapped to align with Query.
# Key: [Batch, Heads, Seq_Len, Head_Dim]
K = torch.randn(32, 12, 128, 64)
# Transpose last two dims for matmul: [Batch, Heads, Head_Dim, Seq_Len]
K_t = K.transpose(-1, -2)
2.2.3 Adding and Removing Dimensions
Core functions: unsqueeze, squeeze
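A minimal sketch of the two core functions, assuming an arbitrary (3, 4) tensor; the variable names are illustrative only.

```python
import torch

x = torch.randn(3, 4)

a = x.unsqueeze(0)    # insert a size-1 dim at position 0 -> (1, 3, 4)
b = x.unsqueeze(-1)   # insert a size-1 dim at the end    -> (3, 4, 1)

c = a.squeeze(0)      # remove the size-1 dim at position 0 -> (3, 4)
d = torch.randn(1, 3, 1).squeeze()  # no argument: drop every size-1 dim -> (3,)

# squeeze on a dim whose size is not 1 leaves the shape unchanged
e = x.squeeze(0)      # still (3, 4)
```

Passing an explicit dim to squeeze is usually safer than the no-argument form, which would also drop a batch dimension that happens to have size 1.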
- Simple example: unsqueeze inserts a size-1 dimension, while squeeze removes dimensions of size 1.
- Advanced example (broadcasting): Scenario: When applying an attention mask in a large model, the mask is often 2D even though it must be added to a 4D attention score tensor.
# Scores: [Batch, Heads, Seq_Len, Seq_Len]
scores = torch.randn(32, 12, 128, 128)
# Mask: [Batch, Seq_Len] (for example, a padding mask)
mask = torch.ones(32, 128)
# Expand dimensions so the mask can be added to scores: [Batch, 1, 1, Seq_Len]
# Insert singleton dimensions at dim 1 and dim 2
mask_expanded = mask.unsqueeze(1).unsqueeze(2)
# Now mask_expanded can be added to scores through broadcasting
2.2.4 Combining
Core functions: cat (concatenate), stack
2.2.4.1 torch.cat
- Academic definition: Join tensors along a dimension that already exists.
- Intuitive explanation: Like connecting several segments end to end: the line becomes longer along the same direction.
- Advanced example (KV Cache inference acceleration): Scenario: During autoregressive LLM generation, newly produced Keys/Values are appended to the existing cache.
# Cache: [Batch, Heads, Past_Len, Head_Dim]
past_key = torch.randn(1, 12, 50, 64)
# New Token Key: [Batch, Heads, 1, Head_Dim]
new_key = torch.randn(1, 12, 1, 64)
# Concatenate along the Sequence dimension (dim=2) -> [1, 12, 51, 64]
current_key = torch.cat([past_key, new_key], dim=2)
2.2.4.2 torch.stack
- Academic definition: Create a new dimension and combine tensors along that new axis.
- Intuitive explanation: Instead of extending an old direction, place several tensors into a new layer.
- Advanced example (RL Experience Replay): Scenario: In reinforcement learning, states sampled from several steps in a replay buffer can be organized into one batch.
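The replay-buffer scenario might look like the sketch below. The buffer here is a plain Python list of hypothetical state tensors with an assumed shape of [4, 84, 84]; it does not reflect the API of any particular RL library.

```python
import random
import torch

# Hypothetical replay buffer: 100 stored states, each [Channels, Height, Width]
buffer = [torch.randn(4, 84, 84) for _ in range(100)]

# Sample 32 states and stack them along a new leading batch dimension
sampled = random.sample(buffer, 32)
batch = torch.stack(sampled, dim=0)   # [32, 4, 84, 84]

# For contrast, cat joins along an existing axis instead of creating one
joined = torch.cat(sampled, dim=0)    # [128, 84, 84]
```

stack requires every tensor to have exactly the same shape, whereas cat only requires agreement on the dimensions other than the one being joined.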
2.2.5 Summary Comparison Table
| Function | Operation Essence | Typical Academic Scenario (LLM/RL) |
|---|---|---|
| view / reshape | Shape rearrangement (element count unchanged) | Split Hidden_Dim into Num_Heads * Head_Dim |
| flatten | Dimension merging (many dimensions to one) | Prepare CNN features before an MLP decision layer |
| permute | Global reordering (move axes arbitrarily) | Convert (B, L, H, D) to (B, H, L, D) for parallel attention |
| transpose | Two-axis swap (exchange one pair only) | Prepare matrix transposition for attention score computation |
| unsqueeze | Axis insertion (add a size-1 dimension) | Build masks that can broadcast to multi-head attention matrices |
| cat | Join along an old axis (extend an existing dimension) | KV Cache updates during inference; multimodal feature fusion |
| stack | Join along a new axis (create an extra dimension) | Organize multiple time-step states into a batch |
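Several rows of the table can be chained into one attention-style sketch. This is an illustration under stated assumptions, not a complete attention layer: the shapes are deliberately small, the same tensor is reused as query, key, and value, the padding mask is additive (0 for kept positions, -inf for padded ones), and scores are scaled by the square root of the head dimension.

```python
import math
import torch

batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

# view: split Hidden_Dim into Num_Heads * Head_Dim
heads = hidden.view(batch, seq_len, num_heads, head_dim)

# permute: (B, L, H, D) -> (B, H, L, D)
q = heads.permute(0, 2, 1, 3)
k = q  # assumption: reuse the same tensor as key (and value) for brevity

# transpose: swap the last two dims of the key, then matmul -> (B, H, L, L)
scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)

# unsqueeze: (B, L) additive padding mask -> (B, 1, 1, L) for broadcasting
pad_mask = torch.zeros(batch, seq_len)
pad_mask[:, -1] = float("-inf")   # assumption: the last position is padding
scores = scores + pad_mask.unsqueeze(1).unsqueeze(2)
weights = scores.softmax(dim=-1)  # (B, H, L, L)

# permute back, then reshape (permute left the tensor non-contiguous,
# so view would fail here without a prior .contiguous())
context = (weights @ q).permute(0, 2, 1, 3).reshape(batch, seq_len, -1)
```

The final step uses reshape rather than view because the permuted tensor is no longer contiguous, which ties back to the view / reshape row of the table.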