手撕 MultiHeadAttention
手撕 Multi-Head Attention
1. 概述
Multi-Head Attention 是 Transformer 模型的核心组件,其作用是让模型同时关注序列中不同位置的信息,并通过“多头”机制捕捉不同子空间的特征依赖,大幅提升模型的表达能力。
2. 核心原理
2.1 自注意力(Self-Attention)基础
自注意力的核心公式为:
其中:
- (Query)、(Key)、(Value)是输入的线性投影;
- 是“缩放点积”,用于计算位置间的注意力分数;
- 将分数归一化为权重,最终通过权重聚合 得到输出。
2.2 多头注意力的设计思路
多头注意力将 拆分为多个“头”(Head),每个头独立计算注意力,最后拼接结果:
其中 , 是头数,(确保维度可均分)。
3. 代码逐行解析
3.1 依赖导入(补充)
代码需先导入 PyTorch 相关模块:
1 | |
3.2 __init__ 方法:参数与层初始化
1 | |
参数与变量意义:
d_model:模型隐藏层维度(如 Transformer-base 为 512),输入/输出的特征维度;num_heads:注意力头数(如 Transformer-base 为 8),需保证d_model % num_heads == 0;self.d_k:每个头的维度,即d_model // num_heads;self.W_q/self.W_k/self.W_v:将输入投影为 、、 的线性层;self.W_o:融合多头结果并投影回d_model的输出线性层。
3.3 forward 方法:前向传播全流程
1 | |
步骤1:生成 Q、K、V 并分头
假设输入 x 形状为 (batch_size, seq_len, d_model),则:
- 线性投影:
self.W_q(x)将x投影为同维度的 (、 同理); - 分头:
view(..., num_heads, d_k)将d_model拆分为num_heads * d_k,形状变为(batch_size, seq_len, num_heads, d_k); - 维度置换:
transpose(1, 2)将num_heads维度提前,形状变为(batch_size, num_heads, seq_len, d_k),方便后续每个头独立计算。
步骤2:计算缩放点积注意力分数
- 点积计算:
Q @ K.transpose(-2, -1)中, 的最后两维置换为(d_k, seq_len),因此 (形状(batch_size, num_heads, seq_len, d_k))与 转置相乘后,形状为(batch_size, num_heads, seq_len, seq_len),表示每个位置对其他位置的注意力分数; - 缩放:除以
√d_k是为了防止点积值过大导致 softmax 进入梯度饱和区(稳定方差为 1)。
步骤3:应用 Mask(可选)
- 作用:屏蔽无效位置(如 Padding Mask 屏蔽填充 token,Look-ahead Mask 屏蔽未来 token);
- 操作:将
mask为True的位置填充为-1e9,使 softmax 后这些位置的权重接近 0; - 形状要求:
mask需能广播到attn_score的形状(batch_size, num_heads, seq_len, seq_len)(如 Padding Mask 通常为(batch_size, 1, 1, seq_len))。
步骤4:归一化权重并聚合 V
- Softmax 归一化:
F.softmax(attn_score, dim=-1)在最后一维归一化,得到注意力权重attn_weights(形状同上,每行权重和为 1); - 聚合 V:
attn_weights @ V将权重与 (形状(batch_size, num_heads, seq_len, d_k))相乘,得到聚合输出attn_output,形状为(batch_size, num_heads, seq_len, d_k)。
步骤5:多头拼接与输出投影
- 维度还原:
transpose(1, 2)将num_heads维度换回,形状变为(batch_size, seq_len, num_heads, d_k);contiguous()确保内存连续(transpose会导致内存不连续,view需连续内存);view(..., -1)将num_heads * d_k拼接回d_model,形状变为(batch_size, seq_len, d_model); - 输出投影:
self.W_o(attn_output)通过线性层融合多头信息,最终输出形状为(batch_size, seq_len, d_model)。
4. 数据形状变化总结表
| 步骤 | 操作 | 形状变化 |
|---|---|---|
| 输入 | x |
(batch_size, seq_len, d_model) |
| 线性投影 | self.W_q(x) |
(batch_size, seq_len, d_model) |
| 分头 | view(..., num_heads, d_k) |
(batch_size, seq_len, num_heads, d_k) |
| 维度置换 | transpose(1, 2) |
(batch_size, num_heads, seq_len, d_k) |
| 点积计算 | Q @ K.transpose(-2, -1) |
(batch_size, num_heads, seq_len, seq_len) |
| 缩放 | / sqrt(d_k) |
(batch_size, num_heads, seq_len, seq_len) |
| Mask(可选) | masked_fill |
(batch_size, num_heads, seq_len, seq_len) |
| Softmax | F.softmax |
(batch_size, num_heads, seq_len, seq_len) |
| 聚合 V | attn_weights @ V |
(batch_size, num_heads, seq_len, d_k) |
| 维度还原 | transpose(1, 2).view |
(batch_size, seq_len, d_model) |
| 输出投影 | self.W_o |
(batch_size, seq_len, d_model) |
5. 总结
Multi-Head Attention 的核心流程可概括为:
- 线性投影生成 ;
- 分头并调整维度;
- 缩放点积计算注意力分数;
- 应用 Mask(可选);
- Softmax 归一化并聚合 ;
- 多头拼接并输出投影。
通过“分头计算-拼接融合”,模型能同时捕捉序列中不同子空间的依赖关系,是 Transformer 强大表达能力的关键来源。
手撕 MultiHeadAttention
https://huan-yin.github.io/2026/04/09/手撕MultiHeadAttention/