Chapter 2.2 PyTorch 维度变换操作函数
2.2.1 重塑类 (Reshaping)
核心函数: view, reshape, flatten
2.2.1.1 torch.view / torch.reshape
- 学术定义: 在元素数量保持一致的前提下,重新指定张量的形状(Shape)。
view要求底层内存连续(Contiguous),reshape会在必要时自行处理拷贝。 - 通俗解释: 就像同一团橡皮泥换个外形,材料没少,只是从“长条”变成了“方块”。
- 简单例子 (理解语法): 把一个包含 12 个元素的一维向量整理成一个矩阵。
x = torch.arange(12) # shape: (12)
y = x.view(3, 4) # shape: (3, 4)- 进阶例子 (LLM Multi-head Attention): 场景: Transformer 里经常需要把隐藏层维度切分成多个注意力头(Heads)。 假设
batch_size=32,seq_len=128,hidden_dim=768,也就是 12 个头,每个头 64 维。
# Input: [Batch, Seq_Len, Hidden_Dim]
x = torch.randn(32, 128, 768)
# Output: [Batch, Seq_Len, Num_Heads, Head_Dim]
# 使用 view 将最后一维拆开
x_heads = x.view(32, 128, 12, 64)2.2.1.2 torch.flatten
- 学术定义: 把一段连续维度合并成一个维度。
- 通俗解释: 像把多层结构压成一张平面,层数不见了,但内容还在。
- 简单例子 (理解语法):
x = torch.randn(2, 3, 4)
# 把后面两维拍扁: (2, 3*4) -> (2, 12)
y = x.flatten(start_dim = 1)- 进阶例子 (CNN/RL 策略网络): 场景: 强化学习里处理图像观测时,卷积得到的 Feature Map 通常要先摊平成向量,再送入全连接层。
# Input: [Batch, Channels, Height, Width] -> [32, 64, 7, 7]
features = torch.randn(32, 64, 7, 7)
# Flatten relevant dims for Linear Layer input
# Output: [Batch, Features] -> [32, 64*7*7] -> [32, 3136]
flat_features = features.flatten(start_dim = 1)2.2.2 交换类 (Swapping)
核心函数: permute, transpose
2.2.2.1 torch.permute
- 学术定义: 根据给定的维度索引顺序,对所有维度重新排序。
- 通俗解释: 相当于给每个维度重新排座位,哪一维放前面、哪一维放后面都由你指定。
- 简单例子 (理解语法):
x = torch.randn(2, 3, 4)
# 变成 (4, 2, 3) -> 原来的第2维放最前,第0维放中间,第1维放最后
y = x.permute(2, 0, 1)- 进阶例子 (Transformer 维度调整): 场景: 在 Multi-head Attention 中,为了按 head 并行做矩阵乘法,常常需要把
Num_Heads放到Seq_Len前面。
# 接上面的 view 例子: [Batch, Seq_Len, Num_Heads, Head_Dim]
x_heads = torch.randn(32, 128, 12, 64)
# 目标: [Batch, Num_Heads, Seq_Len, Head_Dim]
# 交换维度1和维度2
query = x_heads.permute(0, 2, 1, 3)2.2.2.2 torch.transpose
- 学术定义: 只对指定的两个维度进行互换。
- 通俗解释: 它不像
permute那样能重新排列全部维度,但在只换一对维度时更直接。 - 进阶例子 (矩阵乘法准备): 场景: 计算 Attention Score 时,Key 张量通常要交换最后两个维度,才能和 Query 对齐做乘法。
# Key: [Batch, Heads, Seq_Len, Head_Dim]
K = torch.randn(32, 12, 128, 64)
# Transpose last two dims for matmul: [Batch, Heads, Head_Dim, Seq_Len]
K_t = K.transpose(-1, -2)2.2.3 增减类 (Adding/Removing)
核心函数: unsqueeze, squeeze
- 简单例子:
unsqueeze插入一个大小为 1 的维度,squeeze则把这类单元素维度去掉。 - 进阶例子 (广播机制 Broadcasting): 场景: 大模型里使用 Attention Mask 时,Mask 往往是 2D 形状,但实际要和 4D 的 Attention Score 相加。
# Scores: [Batch, Heads, Seq_Len, Seq_Len]
scores = torch.randn(32, 12, 128, 128)
# Mask: [Batch, Seq_Len] (比如 padding mask)
mask = torch.ones(32, 128)
# 为了让 Mask 能加到 Scores 上,需要扩展维度: [Batch, 1, 1, Seq_Len]
# 在第1维和第2维插入维度
mask_expanded = mask.unsqueeze(1).unsqueeze(2)
# 此时 mask_expanded 可以通过广播机制与 scores 相加2.2.4 拼接类 (Combining)
核心函数: cat (concatenate), stack
2.2.4.1 torch.cat
- 学术定义: 沿已有维度把多个张量接在一起。
- 通俗解释: 像把几段绳子首尾相接,长度增加,但方向还是原来的方向。
- 进阶例子 (KV Cache 推理加速): 场景: LLM 自回归生成时,会把新 token 对应的 Key/Value 接到已有缓存后面。
# Cache: [Batch, Heads, Past_Len, Head_Dim]
past_key = torch.randn(1, 12, 50, 64)
# New Token Key: [Batch, Heads, 1, Head_Dim]
new_key = torch.randn(1, 12, 1, 64)
# 在 Sequence 维度 (dim=2) 拼接 -> [1, 12, 51, 64]
current_key = torch.cat([past_key, new_key], dim = 2)2.2.4.2 torch.stack
- 学术定义: 新建一个维度,并沿着这个新维度把张量组合起来。
- 通俗解释: 不是把原来的方向接长,而是把多个对象叠成一摞,额外多出一层。
- 进阶例子 (RL Experience Replay): 场景: 强化学习中,从 Replay Buffer 取出多个时间步的状态,并把它们组织成一个 Batch。
2.2.5 总结对比表
| 函数 | 操作本质 | 典型学术场景 (LLM/RL) |
|---|---|---|
view / reshape |
形状重排 (元素总量不变) | 把 Hidden_Dim 切成 Num_Heads * Head_Dim |
flatten |
维度合并 (多维压成一维) | CNN 特征进入 MLP 决策层之前的整理 |
permute |
整体换序 (任意调整维度位置) | 将 (B, L, H, D) 调整为 (B, H, L, D) 来并行计算 Attention |
transpose |
成对交换 (只换两个维度) | 为 Attention Score 计算准备矩阵转置 |
unsqueeze |
插入维度 (加入大小为 1 的轴) | 构造可广播到多头 Attention 矩阵的 Mask |
cat |
沿旧轴连接 (拉长已有维度) | 推理阶段更新 KV Cache;融合多模态特征 |
stack |
沿新轴组合 (新增一层维度) | 把多个时间步的 State 组织成 Batch |