Chapter 2.2 PyTorch 维度变换操作函数

作者

Brench

发布于

2026-05-11

修改于

2026-05-11

2.2.1 重塑类 (Reshaping)

核心函数: view, reshape, flatten

2.2.1.1 torch.view / torch.reshape

  • 学术定义: 在元素数量保持一致的前提下,重新指定张量的形状(Shape)。view 要求底层内存连续(Contiguous),reshape 会在必要时自行处理拷贝。
  • 通俗解释: 就像同一团橡皮泥换个外形,材料没少,只是从“长条”变成了“方块”。
  • 简单例子 (理解语法): 把一个包含 12 个元素的一维向量整理成一个矩阵。
x = torch.arange(12) # shape: (12)
y = x.view(3, 4)     # shape: (3, 4)
  • 进阶例子 (LLM Multi-head Attention): 场景: Transformer 里经常需要把隐藏层维度切分成多个注意力头(Heads)。 假设 batch_size=32, seq_len=128, hidden_dim=768,也就是 12 个头,每个头 64 维。
# Input: [Batch, Seq_Len, Hidden_Dim]
x = torch.randn(32, 128, 768)

# Output: [Batch, Seq_Len, Num_Heads, Head_Dim]
# 使用 view 将最后一维拆开
x_heads = x.view(32, 128, 12, 64)

2.2.1.2 torch.flatten

  • 学术定义: 把一段连续维度合并成一个维度。
  • 通俗解释: 像把多层结构压成一张平面,层数不见了,但内容还在。
  • 简单例子 (理解语法):
x = torch.randn(2, 3, 4)
# 把后面两维拍扁: (2, 3*4) -> (2, 12)
y = x.flatten(start_dim = 1)
  • 进阶例子 (CNN/RL 策略网络): 场景: 强化学习里处理图像观测时,卷积得到的 Feature Map 通常要先摊平成向量,再送入全连接层。
# Input: [Batch, Channels, Height, Width] -> [32, 64, 7, 7]
features = torch.randn(32, 64, 7, 7)

# Flatten relevant dims for Linear Layer input
# Output: [Batch, Features] -> [32, 64*7*7] -> [32, 3136]
flat_features = features.flatten(start_dim = 1)

2.2.2 交换类 (Swapping)

核心函数: permute, transpose

2.2.2.1 torch.permute

  • 学术定义: 根据给定的维度索引顺序,对所有维度重新排序。
  • 通俗解释: 相当于给每个维度重新排座位,哪一维放前面、哪一维放后面都由你指定。
  • 简单例子 (理解语法):
x = torch.randn(2, 3, 4)
# 变成 (4, 2, 3) -> 原来的第2维放最前,第0维放中间,第1维放最后
y = x.permute(2, 0, 1)
  • 进阶例子 (Transformer 维度调整): 场景: 在 Multi-head Attention 中,为了按 head 并行做矩阵乘法,常常需要把 Num_Heads 放到 Seq_Len 前面。
# 接上面的 view 例子: [Batch, Seq_Len, Num_Heads, Head_Dim]
x_heads = torch.randn(32, 128, 12, 64)

# 目标: [Batch, Num_Heads, Seq_Len, Head_Dim]
# 交换维度1和维度2
query = x_heads.permute(0, 2, 1, 3)

2.2.2.2 torch.transpose

  • 学术定义: 只对指定的两个维度进行互换。
  • 通俗解释: 它不像 permute 那样能重新排列全部维度,但在只换一对维度时更直接。
  • 进阶例子 (矩阵乘法准备): 场景: 计算 Attention Score 时,Key 张量通常要交换最后两个维度,才能和 Query 对齐做乘法。
# Key: [Batch, Heads, Seq_Len, Head_Dim]
K = torch.randn(32, 12, 128, 64)

# Transpose last two dims for matmul: [Batch, Heads, Head_Dim, Seq_Len]
K_t = K.transpose(-1, -2)

2.2.3 增减类 (Adding/Removing)

核心函数: unsqueeze, squeeze

  • 简单例子: unsqueeze 插入一个大小为 1 的维度,squeeze 则把这类单元素维度去掉。
  • 进阶例子 (广播机制 Broadcasting): 场景: 大模型里使用 Attention Mask 时,Mask 往往是 2D 形状,但实际要和 4D 的 Attention Score 相加。
# Scores: [Batch, Heads, Seq_Len, Seq_Len]
scores = torch.randn(32, 12, 128, 128)

# Mask: [Batch, Seq_Len] (比如 padding mask)
mask = torch.ones(32, 128)

# 为了让 Mask 能加到 Scores 上,需要扩展维度: [Batch, 1, 1, Seq_Len]
# 在第1维和第2维插入维度
mask_expanded = mask.unsqueeze(1).unsqueeze(2)

# 此时 mask_expanded 可以通过广播机制与 scores 相加

2.2.4 拼接类 (Combining)

核心函数: cat (concatenate), stack

2.2.4.1 torch.cat

  • 学术定义: 沿已有维度把多个张量接在一起。
  • 通俗解释: 像把几段绳子首尾相接,长度增加,但方向还是原来的方向。
  • 进阶例子 (KV Cache 推理加速): 场景: LLM 自回归生成时,会把新 token 对应的 Key/Value 接到已有缓存后面。
# Cache: [Batch, Heads, Past_Len, Head_Dim]
past_key = torch.randn(1, 12, 50, 64)

# New Token Key: [Batch, Heads, 1, Head_Dim]
new_key = torch.randn(1, 12, 1, 64)

# 在 Sequence 维度 (dim=2) 拼接 -> [1, 12, 51, 64]
current_key = torch.cat([past_key, new_key], dim = 2)

2.2.4.2 torch.stack

  • 学术定义: 新建一个维度,并沿着这个新维度把张量组合起来。
  • 通俗解释: 不是把原来的方向接长,而是把多个对象叠成一摞,额外多出一层。
  • 进阶例子 (RL Experience Replay): 场景: 强化学习中,从 Replay Buffer 取出多个时间步的状态,并把它们组织成一个 Batch。

2.2.5 总结对比表

函数 操作本质 典型学术场景 (LLM/RL)
view / reshape 形状重排 (元素总量不变) Hidden_Dim 切成 Num_Heads * Head_Dim
flatten 维度合并 (多维压成一维) CNN 特征进入 MLP 决策层之前的整理
permute 整体换序 (任意调整维度位置) (B, L, H, D) 调整为 (B, H, L, D) 来并行计算 Attention
transpose 成对交换 (只换两个维度) 为 Attention Score 计算准备矩阵转置
unsqueeze 插入维度 (加入大小为 1 的轴) 构造可广播到多头 Attention 矩阵的 Mask
cat 沿旧轴连接 (拉长已有维度) 推理阶段更新 KV Cache;融合多模态特征
stack 沿新轴组合 (新增一层维度) 把多个时间步的 State 组织成 Batch

二次使用