手撕 MultiHeadAttention

手撕 Multi-Head Attention

1. 概述

Multi-Head Attention 是 Transformer 模型的核心组件,其作用是让模型同时关注序列中不同位置的信息,并通过“多头”机制捕捉不同子空间的特征依赖,大幅提升模型的表达能力。

2. 核心原理

2.1 自注意力(Self-Attention)基础

自注意力的核心公式为:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V

其中:

  • QQ(Query)、KK(Key)、VV(Value)是输入的线性投影;
  • QKTdk\frac{Q K^T}{\sqrt{d_k}} 是“缩放点积”,用于计算位置间的注意力分数;
  • softmax\text{softmax} 将分数归一化为权重,最终通过权重聚合 VV 得到输出。

2.2 多头注意力的设计思路

多头注意力将 Q,K,VQ, K, V 拆分为多个“头”(Head),每个头独立计算注意力,最后拼接结果:

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W_O

其中 headi=Attention(QWQi,KWKi,VWVi)\text{head}_i = \text{Attention}(Q W_Q^i, K W_K^i, V W_V^i)hh 是头数,dk=dmodelhd_k = \frac{d_{\text{model}}}{h}(确保维度可均分)。

3. 代码逐行解析

3.1 依赖导入(补充)

代码需先导入 PyTorch 相关模块:

1
2
3
import torch
import torch.nn as nn
import torch.nn.functional as F

3.2 __init__ 方法:参数与层初始化

1
2
3
4
5
6
7
8
9
10
11
def __init__(self, d_model: int, num_heads: int):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_k = d_model // num_heads # 每个头的维度
self.num_heads = num_heads # 注意力头数

# 线性投影层:生成 Q、K、V 和最终输出
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)

参数与变量意义

  • d_model:模型隐藏层维度(如 Transformer-base 为 512),输入/输出的特征维度;
  • num_heads:注意力头数(如 Transformer-base 为 8),需保证 d_model % num_heads == 0
  • self.d_k:每个头的维度,即 d_model // num_heads
  • self.W_q/self.W_k/self.W_v:将输入投影为 QQKKVV 的线性层;
  • self.W_o:融合多头结果并投影回 d_model 的输出线性层。

3.3 forward 方法:前向传播全流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
batch_size, seq_len, _ = x.size() # 获取输入维度

# 步骤1:生成 Q、K、V 并分头
Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

# 步骤2:计算缩放点积注意力分数
attn_score = Q @ K.transpose(-2, -1) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))

# 步骤3:应用 Mask(可选)
if mask is not None:
attn_score = attn_score.masked_fill(mask, -1e9)

# 步骤4:归一化权重并聚合 V
attn_weights = F.softmax(attn_score, dim=-1)
attn_output = attn_weights @ V

# 步骤5:多头拼接与输出投影
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
output = self.W_o(attn_output)

return output

步骤1:生成 Q、K、V 并分头

假设输入 x 形状为 (batch_size, seq_len, d_model),则:

  • 线性投影self.W_q(x)x 投影为同维度的 QQKKVV 同理);
  • 分头view(..., num_heads, d_k)d_model 拆分为 num_heads * d_k,形状变为 (batch_size, seq_len, num_heads, d_k)
  • 维度置换transpose(1, 2)num_heads 维度提前,形状变为 (batch_size, num_heads, seq_len, d_k),方便后续每个头独立计算。

步骤2:计算缩放点积注意力分数

  • 点积计算Q @ K.transpose(-2, -1) 中,KK 的最后两维置换为 (d_k, seq_len),因此 QQ(形状 (batch_size, num_heads, seq_len, d_k))与 KK 转置相乘后,形状为 (batch_size, num_heads, seq_len, seq_len),表示每个位置对其他位置的注意力分数;
  • 缩放:除以 √d_k 是为了防止点积值过大导致 softmax 进入梯度饱和区(稳定方差为 1)。

步骤3:应用 Mask(可选)

  • 作用:屏蔽无效位置(如 Padding Mask 屏蔽填充 token,Look-ahead Mask 屏蔽未来 token);
  • 操作:将 maskTrue 的位置填充为 -1e9,使 softmax 后这些位置的权重接近 0;
  • 形状要求mask 需能广播到 attn_score 的形状 (batch_size, num_heads, seq_len, seq_len)(如 Padding Mask 通常为 (batch_size, 1, 1, seq_len))。

步骤4:归一化权重并聚合 V

  • Softmax 归一化F.softmax(attn_score, dim=-1) 在最后一维归一化,得到注意力权重 attn_weights(形状同上,每行权重和为 1);
  • 聚合 Vattn_weights @ V 将权重与 VV(形状 (batch_size, num_heads, seq_len, d_k))相乘,得到聚合输出 attn_output,形状为 (batch_size, num_heads, seq_len, d_k)

步骤5:多头拼接与输出投影

  • 维度还原transpose(1, 2)num_heads 维度换回,形状变为 (batch_size, seq_len, num_heads, d_k)contiguous() 确保内存连续(transpose 会导致内存不连续,view 需连续内存);view(..., -1)num_heads * d_k 拼接回 d_model,形状变为 (batch_size, seq_len, d_model)
  • 输出投影self.W_o(attn_output) 通过线性层融合多头信息,最终输出形状为 (batch_size, seq_len, d_model)

4. 数据形状变化总结表

步骤 操作 形状变化
输入 x (batch_size, seq_len, d_model)
线性投影 self.W_q(x) (batch_size, seq_len, d_model)
分头 view(..., num_heads, d_k) (batch_size, seq_len, num_heads, d_k)
维度置换 transpose(1, 2) (batch_size, num_heads, seq_len, d_k)
点积计算 Q @ K.transpose(-2, -1) (batch_size, num_heads, seq_len, seq_len)
缩放 / sqrt(d_k) (batch_size, num_heads, seq_len, seq_len)
Mask(可选) masked_fill (batch_size, num_heads, seq_len, seq_len)
Softmax F.softmax (batch_size, num_heads, seq_len, seq_len)
聚合 V attn_weights @ V (batch_size, num_heads, seq_len, d_k)
维度还原 transpose(1, 2).view (batch_size, seq_len, d_model)
输出投影 self.W_o (batch_size, seq_len, d_model)

5. 总结

Multi-Head Attention 的核心流程可概括为:

  1. 线性投影生成 Q,K,VQ, K, V
  2. 分头并调整维度;
  3. 缩放点积计算注意力分数;
  4. 应用 Mask(可选);
  5. Softmax 归一化并聚合 VV
  6. 多头拼接并输出投影。

通过“分头计算-拼接融合”,模型能同时捕捉序列中不同子空间的依赖关系,是 Transformer 强大表达能力的关键来源。


手撕 MultiHeadAttention
https://huan-yin.github.io/2026/04/09/手撕MultiHeadAttention/
作者
李相越
发布于
2026年4月9日
许可协议