AI Infra学习之旅-Transformer为什么不只是一个Attention公式

这是“Transformer 原理深讲系列”的第 3 篇。
上一篇我们已经把位置编码、Q/K/V 和一次单头 Attention 的计算过程讲清楚了。
但如果你继续往 Transformer 的真实结构里走,很快就会发现:
Transformer 远不止一个 Attention 公式。
这一篇就集中讲清楚三个最容易被“背概念”但最不该只停留在概念层面的模块:
Multi-Head Attention、Mask、FFN。


一、为什么单头 Attention 还不够

上一讲里,我们已经知道,一次单头 Attention 的核心形式是:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

它做的事情可以概括为:

  1. 用 Query 和 Key 计算“谁值得关注”
  2. 把相关位置的 Value 按权重聚合起来
  3. 为每个 token 生成一个新的上下文化表示

这已经很强了,因为它允许:

  • 每个位置直接看全局
  • 长距离依赖不再必须沿时间链条传递
  • 上下文信息可以动态加权融合

但问题在于:

语言关系并不是单一的。

同一个 token,在同一个句子里,往往同时需要处理很多种关系:

  • 局部邻近关系
  • 主谓宾关系
  • 修饰关系
  • 指代关系
  • 长程依赖
  • 主题与语义聚类关系

如果只有一个头,那么这一切都必须挤进同一套相关性空间、同一张注意力图里。

这会产生一个很自然的表达瓶颈:

一个头必须同时承担太多不同类型的关系建模任务。

所以,单头 attention 的问题并不是“不能看全局”,而是:

它只能用一个视角看全局。


二、Multi-Head Attention 的本质:不是重复做很多次,而是多视角建模

Transformer 的解决方法是:

不要只做一次 Attention,而是并行做多次,每次都在不同的表示子空间中完成。

如果有 (h) 个头,那么第 (i) 个头都有自己独立的参数:

WiQ,WiK,WiVW_i^Q,\quad W_i^K,\quad W_i^V

于是对同一个输入 (H),每个头都会产生不同的投影:

Qi=HWiQ,Ki=HWiK,Vi=HWiVQ_i = HW_i^Q,\quad K_i = HW_i^K,\quad V_i = HW_i^V

然后每个头单独做一次 attention:

headi=Attention(Qi,Ki,Vi)\text{head}_i = \text{Attention}(Q_i,K_i,V_i)

最后把所有头拼接起来:

MultiHead(H)=Concat(head1,,headh)WO\text{MultiHead}(H)=\text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O

这里 (W^O) 是输出投影矩阵,用来把多个头的结果重新融合回主表示空间。

所以,Multi-Head Attention 真正“多”的地方,不是算得多,而是:

  • 每个头看到的是不同的投影空间
  • 每个头使用的是不同的匹配标准
  • 每个头学到的是不同关系模式下的上下文化结果

这正是多头机制的关键。


三、为什么语言需要“多头”,而不是“一个更大的头”

这是一个很值得深究的问题。

你可能会想:
既然单头不够,为什么不直接把单头维度做大,而要拆成多个头?

答案在于:

一个更大的头,仍然只有一个统一的相关性视角;而多个头,可以真正实现关系分工。

例如,在实际训练中,不同头往往会自发偏向不同模式:

  • 某些头更关注局部邻接
  • 某些头更关注句法边界
  • 某些头更容易捕捉指代
  • 某些头更偏向全局主题聚合

虽然这不是人工硬编码的,但多头结构给了模型这种能力空间。

从几何角度看,单头像是在一个坐标系中看所有关系;
多头则像是在多个不同坐标系中同时观察同一组 token。

所以多头的真正价值,不是“更宽”,而是:

同一层内的多子空间关系分解能力。


四、Multi-Head Attention 在计算上到底做了什么

设输入是:

HRn×dmodelH \in \mathbb{R}^{n \times d_{\text{model}}}

如果有 (h) 个头,那么通常会把每个头的维度设成:

dk=dv=dmodelhd_k = d_v = \frac{d_{\text{model}}}{h}

例如:

  • (d_{\text{model}} = 512)
  • 头数 (h = 8)
  • 那么每个头维度就是 64

第 (i) 个头做:

Qi=HWiQ,Ki=HWiK,Vi=HWiVQ_i = HW_i^Q,\quad K_i = HW_i^K,\quad V_i = HW_i^V

headi=softmax(QiKiTdk)Vi\text{head}_i=\text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_i

最终所有头拼接:

Z=Concat(head1,,headh)Z=\text{Concat}(\text{head}_1,\dots,\text{head}_h)

再通过输出矩阵:

MultiHead(H)=ZWO\text{MultiHead}(H)=ZW^O

这里最后的 (W^O) 不是装饰,而是在做一件很关键的事:

把不同头的独立视角重新整合成统一输出。

所以整个过程不是:

  • 多个头各自结束就完了

而是:

  • 多个头各自观察
  • 再通过 (W^O) 融合成下一层真正使用的表示

五、为什么说“一个头看语法,一个头看指代”这类说法有道理,但不能说死

这类说法在科普里很常见,它有一定道理,但需要更精确理解。

有道理的地方

不同头参数不同,所以它们确实可能学出不同关系模式。
在很多可视化研究中,人们也确实观察到:

  • 有些头更关注相邻词
  • 有些头更关注标点或句法边界
  • 有些头对长距离指代敏感

不能说死的地方

但这不意味着某个头有一个永恒固定的人类语法标签。因为:

  1. 头的功能不是人工预设的
  2. 同一个头可能混合承担多种关系
  3. 不同模型、不同层、不同训练阶段分工都可能不同
  4. 有些头甚至可能逐渐退化为简单模式

所以更准确的说法应该是:

Multi-Head Attention 允许模型在多个子空间中学习不同类型的关系模式,而训练后这些模式常常表现出一定分工。


六、现在进入第二个核心主题:Mask 到底是什么

如果说多头是在回答:

“为什么不能只用一个视角看上下文?”

那么 Mask 回答的是另一个问题:

“为什么不是所有位置都允许互相看?”

这一步非常关键,因为它决定了 Transformer 到底是在做:

  • 双向理解
    还是
  • 自回归生成

我们先回忆一下注意力分数矩阵:

S=QKTdkS = \frac{QK^T}{\sqrt{d_k}}

如果不加任何约束,那么每个位置都可以看见所有位置。
也就是说,信息流是全连接的。

这对某些任务是好事,例如句子理解。
但对生成任务来说,会直接出问题。


七、为什么 Decoder 必须有 causal mask

对于 Decoder-only 语言模型,训练目标是:

P(xtx1,,xt1)P(x_t \mid x_1,\dots,x_{t-1})

也就是说,第 (t) 个 token 的预测只能依赖前缀,不能看未来。

如果不加 mask,那么第 (t) 个位置在 Self-Attention 里可以直接看到:

  • 第 (t+1) 个 token
  • 第 (t+2) 个 token
  • ……

这意味着模型在训练时会“偷看答案”。
那它学到的就不是自回归语言建模,而是某种带未来信息的条件建模。

所以 causal mask 的本质不是训练技巧,而是:

强制 Self-Attention 的信息流遵守因果方向。


八、causal mask 的矩阵形式

假设序列长度是 4。
不加 mask 时,score matrix 是一个完整 (4 \times 4) 矩阵:

S=[s11s12s13s14s21s22s23s24s31s32s33s34s41s42s43s44]S= \begin{bmatrix} s_{11} & s_{12} & s_{13} & s_{14}\\ s_{21} & s_{22} & s_{23} & s_{24}\\ s_{31} & s_{32} & s_{33} & s_{34}\\ s_{41} & s_{42} & s_{43} & s_{44} \end{bmatrix}

如果加入 causal mask,那么可见性就必须变成:

[1000110011101111]\begin{bmatrix} 1 & 0 & 0 & 0\\ 1 & 1 & 0 & 0\\ 1 & 1 & 1 & 0\\ 1 & 1 & 1 & 1 \end{bmatrix}

换句话说:

  • 第 1 个位置只能看自己
  • 第 2 个位置只能看前两个位置
  • 第 3 个位置只能看前三个位置
  • 第 4 个位置可以看前四个位置

所以 causal mask 常常是一个下三角结构


九、Mask 是怎么真正加进 Attention 里的

mask 的实现方式不是把未来位置“删掉”,而是在 softmax 前给非法位置加一个很大的负数:

S^=S+M\hat{S}=S+M

其中 (M) 是 mask 矩阵。
对不允许看的位置,(M) 的对应元素通常是:

-\infty

或实现中用一个极小值,如 (-10^9)。

然后做 softmax:

A=softmax(S^)A = \text{softmax}(\hat{S})

由于:

e=0e^{-\infty}=0

所以这些非法位置在 softmax 后的权重就会变成 0。

这说明:

mask 真正做的事情,是在注意力分配前就把非法位置剔除出概率归一化。

这比 softmax 后再置零更合理,因为这样剩余可见位置的权重仍然保持总和为 1。


十、padding mask 和 causal mask 不是一回事

很多人一听到 mask,就只想到 causal mask。
其实在 Transformer 里常见的 mask 至少有两类:

1. padding mask

用于忽略补齐(PAD)位置。
因为 batch 里不同样本长度不同,通常要 pad 到相同长度。
这些 PAD token 不应该参与注意力。

2. causal mask

用于阻止当前位置看未来位置。
这是 Decoder-only 自回归生成成立的必要条件。

所以:

  • padding mask 解决的是“哪些位置根本不是有效 token”
  • causal mask 解决的是“哪些位置虽然有效,但在因果方向上不允许看”

二者经常一起使用,但角色完全不同。


十一、为什么说 Mask 在改写信息流拓扑

如果从图结构角度看:

  • 不加 mask 时,attention 图是全连接图
  • 加 causal mask 后,图会变成一个有方向的下三角可达图

换句话说:

Mask 不只是数值层面的遮罩,它实际上改变了 Transformer 中信息传播的合法路径。

这一点很重要,因为它说明:

  • Encoder 之所以适合理解,是因为它允许双向信息流
  • Decoder 之所以适合生成,是因为 causal mask 把信息流限制成过去 (\rightarrow) 未来

所以 mask 本质上是在 Transformer 中显式编码“允许的信息流结构”。


十二、现在进入第三个核心主题:FFN 到底是什么

Attention 经常是 Transformer 里最被关注的部分,但实际上 Transformer block 不是只有 attention。

每个 block 通常还有一个非常重要的模块:

FFN(Feed Forward Network)

它的经典形式是:

FFN(x)=W2σ(W1x+b1)+b2\text{FFN}(x)=W_2 \sigma(W_1x+b_1)+b_2

这里:

  • (x \in \mathbb{R}^{d_{\text{model}}})
  • (W_1 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}})
  • (W_2 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}})
  • (\sigma) 是非线性激活,例如 ReLU、GELU、SwiGLU 等

FFN 看起来像一个非常普通的两层 MLP。
但在 Transformer 里,它承担的是一个 attention 无法替代的角色。


十三、为什么 Attention 后面还必须接 FFN

这是 FFN 存在的根本原因。

Attention 的本质是:

oi=jaijvjo_i = \sum_j a_{ij}v_j

也就是说,它更像是一个:

  • 全局相关性计算器
  • 动态加权聚合器
  • 信息交换与路由机制

它擅长的是:

  • 决定谁该看谁
  • 把全局信息按权重混合进来

但它不擅长的是:

对聚合后的单个位置表示做足够强的逐位置非线性加工。

所以 Transformer 不能只有 attention。
否则它会更像一个“上下文混合系统”,而不是一个强表达能力的深层网络。

FFN 的作用就是:

在 attention 已经完成信息交换后,对每个位置自己的表示再做一次高维非线性变换。


十四、FFN 为什么通常是“先升维再降维”

FFN 通常采用:

dmodeldffdmodeld_{\text{model}} \rightarrow d_{\text{ff}} \rightarrow d_{\text{model}}

其中 (d_{\text{ff}}) 往往比 (d_{\text{model}}) 大很多,例如 4 倍。

为什么这么设计?

1. 提升表示能力

先把向量投影到更高维空间,相当于给模型更多中间特征通道,用来表达更复杂的组合模式。

2. 在高维空间中做非线性展开

更高维意味着:

  • 可以容纳更多特征方向
  • 可以形成更细致的激活模式
  • 可以让语义、句法、主题等因素在不同通道上被区分和重组

3. 再压回模型维度

把加工后的高维特征重新整合成统一主空间表示,供下一层继续使用。

所以 FFN 的核心逻辑是:

低维输入 → 高维展开 → 非线性加工 → 压回主维度


十五、为什么 FFN 是逐位置独立做的

这点也很关键。

设 attention 输出后的序列表示是:

HRn×dmodelH' \in \mathbb{R}^{n \times d_{\text{model}}}

FFN 会对每个位置单独作用:

FFN(hi)=W2σ(W1hi+b1)+b2\text{FFN}(h'_i)=W_2 \sigma(W_1h'_i+b_1)+b_2

也就是说,它不再让位置之间互相看,而是:

对每个 token 自己的向量表示做同一套非线性变换。

为什么这么设计?

因为 attention 已经完成了跨位置的信息交互。
此时每个位置的向量里,已经包含了所需的上下文信息。
接下来更合理的事,就是让每个位置:

  • 自己消化上下文
  • 自己做更深层特征重组
  • 自己变成更适合下一层使用的表示

所以:

  • attention 负责“对外交流”
  • FFN 负责“对内加工”

这两者职责不同,但缺一不可。


十六、从几何和实际意义上理解 FFN

如果把 Attention 理解成“建立关系并进行全局加权”,
那么 FFN 更像:

对单个位置的高维特征再加工。

它的第一层 (W_1) 可以看成一组 learned feature detectors:

  • 某些方向对语法模式敏感
  • 某些方向对语义模式敏感
  • 某些方向对主题结构敏感

非线性激活相当于一个门:

  • 重要模式被激活
  • 无关模式被压制

第二层 (W_2) 再把激活后的高维特征重新组合,形成新的输出表示。

所以 FFN 不只是“再接一层 MLP”,而是在承担:

逐位置表示增强器 的角色。


十七、把这一篇真正串起来看:Transformer 为什么不只是一个 Attention 公式

现在你应该能真正理解,Transformer 的 block 远不只是:

softmax(QKTdk)V\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

因为如果只有这个公式,会缺少很多关键能力:

缺少多头

就只能用一个视角看上下文。

缺少 mask

就无法区分“理解模型”和“生成模型”,Decoder 也会偷看未来。

缺少 FFN

就只能做信息混合,缺少逐位置非线性加工能力。

所以真正的 Transformer 是:

  • 多头 attention:多视角建模
  • mask:控制信息流拓扑
  • FFN:增强逐位置表示能力

这三者共同构成了 Transformer Block 内部最关键的表达骨架。


十八、本篇真正要记住的三条主线

第一条:为什么要多头

  • 单头 attention 只有一个相关性视角
  • 语言关系多种多样
  • 多头允许模型在多个子空间中并行学习不同类型的关系模式

第二条:为什么需要 mask

  • Self-Attention 原本允许任意位置看任意位置
  • 生成模型必须遵守因果方向
  • causal mask 用来阻止当前位置偷看未来
  • padding mask 则用来屏蔽无效填充位置

第三条:FFN 为什么不能省

  • attention 负责信息交换
  • FFN 负责逐位置非线性加工
  • 两者职责完全不同
  • FFN 通过“先升维再降维”提升表示能力

十九、用一句话压缩本篇

Transformer 之所以不只是一个 Attention 公式,是因为它不仅需要通过 Multi-Head Attention 在多个子空间中建模不同关系,还必须通过 Mask 控制合法的信息流方向,并通过 FFN 对每个位置的聚合结果做进一步的高维非线性加工。


下一篇预告

下一篇主题是:

一层 Transformer Block 到底是什么:残差、LayerNorm 与 Encoder/Decoder 全貌

会重点讲清楚:

  • 为什么深层网络一定需要残差连接
  • LayerNorm 到底在归一化什么
  • Pre-LN 和 Post-LN 有什么差别
  • 一层 Transformer block 从输入到输出到底经历几步
  • 原始 Transformer 为什么是 Encoder-Decoder
  • Encoder-only、Decoder-only、Encoder-Decoder 分别适合什么任务