AI Infra学习之旅-Transformer为什么不只是一个Attention公式
AI Infra学习之旅-Transformer为什么不只是一个Attention公式
这是“Transformer 原理深讲系列”的第 3 篇。
上一篇我们已经把位置编码、Q/K/V 和一次单头 Attention 的计算过程讲清楚了。
但如果你继续往 Transformer 的真实结构里走,很快就会发现:
Transformer 远不止一个 Attention 公式。
这一篇就集中讲清楚三个最容易被“背概念”但最不该只停留在概念层面的模块:
Multi-Head Attention、Mask、FFN。
一、为什么单头 Attention 还不够
上一讲里,我们已经知道,一次单头 Attention 的核心形式是:
它做的事情可以概括为:
- 用 Query 和 Key 计算“谁值得关注”
- 把相关位置的 Value 按权重聚合起来
- 为每个 token 生成一个新的上下文化表示
这已经很强了,因为它允许:
- 每个位置直接看全局
- 长距离依赖不再必须沿时间链条传递
- 上下文信息可以动态加权融合
但问题在于:
语言关系并不是单一的。
同一个 token,在同一个句子里,往往同时需要处理很多种关系:
- 局部邻近关系
- 主谓宾关系
- 修饰关系
- 指代关系
- 长程依赖
- 主题与语义聚类关系
如果只有一个头,那么这一切都必须挤进同一套相关性空间、同一张注意力图里。
这会产生一个很自然的表达瓶颈:
一个头必须同时承担太多不同类型的关系建模任务。
所以,单头 attention 的问题并不是“不能看全局”,而是:
它只能用一个视角看全局。
二、Multi-Head Attention 的本质:不是重复做很多次,而是多视角建模
Transformer 的解决方法是:
不要只做一次 Attention,而是并行做多次,每次都在不同的表示子空间中完成。
如果有 (h) 个头,那么第 (i) 个头都有自己独立的参数:
于是对同一个输入 (H),每个头都会产生不同的投影:
然后每个头单独做一次 attention:
最后把所有头拼接起来:
这里 (W^O) 是输出投影矩阵,用来把多个头的结果重新融合回主表示空间。
所以,Multi-Head Attention 真正“多”的地方,不是算得多,而是:
- 每个头看到的是不同的投影空间
- 每个头使用的是不同的匹配标准
- 每个头学到的是不同关系模式下的上下文化结果
这正是多头机制的关键。
三、为什么语言需要“多头”,而不是“一个更大的头”
这是一个很值得深究的问题。
你可能会想:
既然单头不够,为什么不直接把单头维度做大,而要拆成多个头?
答案在于:
一个更大的头,仍然只有一个统一的相关性视角;而多个头,可以真正实现关系分工。
例如,在实际训练中,不同头往往会自发偏向不同模式:
- 某些头更关注局部邻接
- 某些头更关注句法边界
- 某些头更容易捕捉指代
- 某些头更偏向全局主题聚合
虽然这不是人工硬编码的,但多头结构给了模型这种能力空间。
从几何角度看,单头像是在一个坐标系中看所有关系;
多头则像是在多个不同坐标系中同时观察同一组 token。
所以多头的真正价值,不是“更宽”,而是:
同一层内的多子空间关系分解能力。
四、Multi-Head Attention 在计算上到底做了什么
设输入是:
如果有 (h) 个头,那么通常会把每个头的维度设成:
例如:
- (d_{\text{model}} = 512)
- 头数 (h = 8)
- 那么每个头维度就是 64
第 (i) 个头做:
最终所有头拼接:
再通过输出矩阵:
这里最后的 (W^O) 不是装饰,而是在做一件很关键的事:
把不同头的独立视角重新整合成统一输出。
所以整个过程不是:
- 多个头各自结束就完了
而是:
- 多个头各自观察
- 再通过 (W^O) 融合成下一层真正使用的表示
五、为什么说“一个头看语法,一个头看指代”这类说法有道理,但不能说死
这类说法在科普里很常见,它有一定道理,但需要更精确理解。
有道理的地方
不同头参数不同,所以它们确实可能学出不同关系模式。
在很多可视化研究中,人们也确实观察到:
- 有些头更关注相邻词
- 有些头更关注标点或句法边界
- 有些头对长距离指代敏感
不能说死的地方
但这不意味着某个头有一个永恒固定的人类语法标签。因为:
- 头的功能不是人工预设的
- 同一个头可能混合承担多种关系
- 不同模型、不同层、不同训练阶段分工都可能不同
- 有些头甚至可能逐渐退化为简单模式
所以更准确的说法应该是:
Multi-Head Attention 允许模型在多个子空间中学习不同类型的关系模式,而训练后这些模式常常表现出一定分工。
六、现在进入第二个核心主题:Mask 到底是什么
如果说多头是在回答:
“为什么不能只用一个视角看上下文?”
那么 Mask 回答的是另一个问题:
“为什么不是所有位置都允许互相看?”
这一步非常关键,因为它决定了 Transformer 到底是在做:
- 双向理解
还是 - 自回归生成
我们先回忆一下注意力分数矩阵:
如果不加任何约束,那么每个位置都可以看见所有位置。
也就是说,信息流是全连接的。
这对某些任务是好事,例如句子理解。
但对生成任务来说,会直接出问题。
七、为什么 Decoder 必须有 causal mask
对于 Decoder-only 语言模型,训练目标是:
也就是说,第 (t) 个 token 的预测只能依赖前缀,不能看未来。
如果不加 mask,那么第 (t) 个位置在 Self-Attention 里可以直接看到:
- 第 (t+1) 个 token
- 第 (t+2) 个 token
- ……
这意味着模型在训练时会“偷看答案”。
那它学到的就不是自回归语言建模,而是某种带未来信息的条件建模。
所以 causal mask 的本质不是训练技巧,而是:
强制 Self-Attention 的信息流遵守因果方向。
八、causal mask 的矩阵形式
假设序列长度是 4。
不加 mask 时,score matrix 是一个完整 (4 \times 4) 矩阵:
如果加入 causal mask,那么可见性就必须变成:
换句话说:
- 第 1 个位置只能看自己
- 第 2 个位置只能看前两个位置
- 第 3 个位置只能看前三个位置
- 第 4 个位置可以看前四个位置
所以 causal mask 常常是一个下三角结构。
九、Mask 是怎么真正加进 Attention 里的
mask 的实现方式不是把未来位置“删掉”,而是在 softmax 前给非法位置加一个很大的负数:
其中 (M) 是 mask 矩阵。
对不允许看的位置,(M) 的对应元素通常是:
或实现中用一个极小值,如 (-10^9)。
然后做 softmax:
由于:
所以这些非法位置在 softmax 后的权重就会变成 0。
这说明:
mask 真正做的事情,是在注意力分配前就把非法位置剔除出概率归一化。
这比 softmax 后再置零更合理,因为这样剩余可见位置的权重仍然保持总和为 1。
十、padding mask 和 causal mask 不是一回事
很多人一听到 mask,就只想到 causal mask。
其实在 Transformer 里常见的 mask 至少有两类:
1. padding mask
用于忽略补齐(PAD)位置。
因为 batch 里不同样本长度不同,通常要 pad 到相同长度。
这些 PAD token 不应该参与注意力。
2. causal mask
用于阻止当前位置看未来位置。
这是 Decoder-only 自回归生成成立的必要条件。
所以:
- padding mask 解决的是“哪些位置根本不是有效 token”
- causal mask 解决的是“哪些位置虽然有效,但在因果方向上不允许看”
二者经常一起使用,但角色完全不同。
十一、为什么说 Mask 在改写信息流拓扑
如果从图结构角度看:
- 不加 mask 时,attention 图是全连接图
- 加 causal mask 后,图会变成一个有方向的下三角可达图
换句话说:
Mask 不只是数值层面的遮罩,它实际上改变了 Transformer 中信息传播的合法路径。
这一点很重要,因为它说明:
- Encoder 之所以适合理解,是因为它允许双向信息流
- Decoder 之所以适合生成,是因为 causal mask 把信息流限制成过去 (\rightarrow) 未来
所以 mask 本质上是在 Transformer 中显式编码“允许的信息流结构”。
十二、现在进入第三个核心主题:FFN 到底是什么
Attention 经常是 Transformer 里最被关注的部分,但实际上 Transformer block 不是只有 attention。
每个 block 通常还有一个非常重要的模块:
FFN(Feed Forward Network)
它的经典形式是:
这里:
- (x \in \mathbb{R}^{d_{\text{model}}})
- (W_1 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}})
- (W_2 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}})
- (\sigma) 是非线性激活,例如 ReLU、GELU、SwiGLU 等
FFN 看起来像一个非常普通的两层 MLP。
但在 Transformer 里,它承担的是一个 attention 无法替代的角色。
十三、为什么 Attention 后面还必须接 FFN
这是 FFN 存在的根本原因。
Attention 的本质是:
也就是说,它更像是一个:
- 全局相关性计算器
- 动态加权聚合器
- 信息交换与路由机制
它擅长的是:
- 决定谁该看谁
- 把全局信息按权重混合进来
但它不擅长的是:
对聚合后的单个位置表示做足够强的逐位置非线性加工。
所以 Transformer 不能只有 attention。
否则它会更像一个“上下文混合系统”,而不是一个强表达能力的深层网络。
FFN 的作用就是:
在 attention 已经完成信息交换后,对每个位置自己的表示再做一次高维非线性变换。
十四、FFN 为什么通常是“先升维再降维”
FFN 通常采用:
其中 (d_{\text{ff}}) 往往比 (d_{\text{model}}) 大很多,例如 4 倍。
为什么这么设计?
1. 提升表示能力
先把向量投影到更高维空间,相当于给模型更多中间特征通道,用来表达更复杂的组合模式。
2. 在高维空间中做非线性展开
更高维意味着:
- 可以容纳更多特征方向
- 可以形成更细致的激活模式
- 可以让语义、句法、主题等因素在不同通道上被区分和重组
3. 再压回模型维度
把加工后的高维特征重新整合成统一主空间表示,供下一层继续使用。
所以 FFN 的核心逻辑是:
低维输入 → 高维展开 → 非线性加工 → 压回主维度
十五、为什么 FFN 是逐位置独立做的
这点也很关键。
设 attention 输出后的序列表示是:
FFN 会对每个位置单独作用:
也就是说,它不再让位置之间互相看,而是:
对每个 token 自己的向量表示做同一套非线性变换。
为什么这么设计?
因为 attention 已经完成了跨位置的信息交互。
此时每个位置的向量里,已经包含了所需的上下文信息。
接下来更合理的事,就是让每个位置:
- 自己消化上下文
- 自己做更深层特征重组
- 自己变成更适合下一层使用的表示
所以:
- attention 负责“对外交流”
- FFN 负责“对内加工”
这两者职责不同,但缺一不可。
十六、从几何和实际意义上理解 FFN
如果把 Attention 理解成“建立关系并进行全局加权”,
那么 FFN 更像:
对单个位置的高维特征再加工。
它的第一层 (W_1) 可以看成一组 learned feature detectors:
- 某些方向对语法模式敏感
- 某些方向对语义模式敏感
- 某些方向对主题结构敏感
非线性激活相当于一个门:
- 重要模式被激活
- 无关模式被压制
第二层 (W_2) 再把激活后的高维特征重新组合,形成新的输出表示。
所以 FFN 不只是“再接一层 MLP”,而是在承担:
逐位置表示增强器 的角色。
十七、把这一篇真正串起来看:Transformer 为什么不只是一个 Attention 公式
现在你应该能真正理解,Transformer 的 block 远不只是:
因为如果只有这个公式,会缺少很多关键能力:
缺少多头
就只能用一个视角看上下文。
缺少 mask
就无法区分“理解模型”和“生成模型”,Decoder 也会偷看未来。
缺少 FFN
就只能做信息混合,缺少逐位置非线性加工能力。
所以真正的 Transformer 是:
- 多头 attention:多视角建模
- mask:控制信息流拓扑
- FFN:增强逐位置表示能力
这三者共同构成了 Transformer Block 内部最关键的表达骨架。
十八、本篇真正要记住的三条主线
第一条:为什么要多头
- 单头 attention 只有一个相关性视角
- 语言关系多种多样
- 多头允许模型在多个子空间中并行学习不同类型的关系模式
第二条:为什么需要 mask
- Self-Attention 原本允许任意位置看任意位置
- 生成模型必须遵守因果方向
- causal mask 用来阻止当前位置偷看未来
- padding mask 则用来屏蔽无效填充位置
第三条:FFN 为什么不能省
- attention 负责信息交换
- FFN 负责逐位置非线性加工
- 两者职责完全不同
- FFN 通过“先升维再降维”提升表示能力
十九、用一句话压缩本篇
Transformer 之所以不只是一个 Attention 公式,是因为它不仅需要通过 Multi-Head Attention 在多个子空间中建模不同关系,还必须通过 Mask 控制合法的信息流方向,并通过 FFN 对每个位置的聚合结果做进一步的高维非线性加工。
下一篇预告
下一篇主题是:
一层 Transformer Block 到底是什么:残差、LayerNorm 与 Encoder/Decoder 全貌
会重点讲清楚:
- 为什么深层网络一定需要残差连接
- LayerNorm 到底在归一化什么
- Pre-LN 和 Post-LN 有什么差别
- 一层 Transformer block 从输入到输出到底经历几步
- 原始 Transformer 为什么是 Encoder-Decoder
- Encoder-only、Decoder-only、Encoder-Decoder 分别适合什么任务
