AI Infra学习之旅-手算一个完整TransformerBlock
AI Infra学习之旅-手算一个完整TransformerBlock
这是“Transformer 原理深讲系列”的第 7 篇。
上一篇我们已经手算了一次最小的单头 Self-Attention,从输入向量一路算到 attention 输出。
但如果只停在那里,你对 Transformer 的理解还差最后一块关键拼图:
真实的一层 Transformer block 并不只有 attention,它还包含残差连接、LayerNorm,以及 attention 之后的 FFN。
所以这一篇我们继续沿着“手算”的方式,把一个极小的 Transformer block 真正走完整。
一、为什么还要再手算一个完整 block
上一篇我们已经算出了:
也看到了 attention 的本质是:
- 先用 Q/K 计算相关性
- 再 softmax 得到注意力分配
- 最后对 V 做加权求和
但如果你回头看真实 Transformer 的结构,就会发现:
一个 block 并不是:
输入 (\rightarrow) Attention (\rightarrow) 输出
而是至少包含:
- Attention 子层
- 残差连接
- LayerNorm
- FFN 子层
- 再一次残差连接
- 再一次 LayerNorm(或 Pre-LN 形式下先 LN 再进子层)
也就是说,attention 虽然是“心脏”,但它并不是整个 block 的全部。
所以这一篇的目标就是:
把 attention 结果真正接进残差、LayerNorm 和 FFN,完整算出一层 Transformer Block 的输出。
二、这次我们仍然使用极小可手算例子
为了保证能一步一步写清楚,这次继续做简化。
我们保留的东西
- 输入向量
- 单头 Self-Attention
- 残差连接
- LayerNorm
- FFN
- 完整一层 Transformer Block 的顺序
我们暂时不展开的东西
- 多头 attention
- causal mask
- batch 维度
- 多层堆叠
- 输出 logits / 词表概率
- Pre-LN 与 Post-LN 的所有变体比较(我们这次固定一种)
三、我们这次采用哪一种 block 结构
为了手算更清晰,这一篇采用 Post-LN 风格 的 block 来演示:
第一步:Attention 子层
第二步:FFN 子层
这不是说现代大模型一定都这样写。
真实工程中很多模型更偏 Pre-LN。
但对于手算演示来说,这种形式更适合把:
- 子层输出
- 残差叠加
- 再归一化
清楚地分成两步看。
四、我们直接复用上一讲 attention 的输入
上一讲我们设输入矩阵为:
序列长度为:
模型维度:
并且通过一个极小例子,已经算出单头 Self-Attention 输出近似为:
为了这一篇聚焦 block,本篇就不再重复 Q/K/V 的中间计算,而是直接从这个 attention 输出继续往后算。
五、第一步:做 Attention 残差连接
根据 block 结构,我们先做:
把输入矩阵和 attention 输出逐元素相加:
这一步的意义非常重要:
- attention 给出的是“上下文修正量”
- 残差让我们不是丢掉原输入,而是把修正量叠加到原表示上
所以现在的表示已经不是“纯输入”或“纯 attention”,而是:
原表示 + 上下文修正后的表示
六、第二步:对 Attention 残差结果做 LayerNorm
现在对 (R^{(1)}) 的每一行分别做 LayerNorm。
为了手算清晰,我们做如下简化:
- (\gamma = [1,1])
- (\beta = [0,0])
- 忽略 (\epsilon) 的极小影响
于是 LayerNorm 退化成:
其中:
因为这里每一行只有 2 维,所以会算得特别简洁。
第 1 行归一化
第 1 行是:
均值:
方差:
标准差:
归一化后:
第 2 行归一化
第 2 行是:
均值:
两边偏差分别是:
- (1.504-1.628=-0.124)
- (1.752-1.628=0.124)
标准差就是:
归一化后:
第 3 行归一化
第 3 行是:
均值:
偏差:
- (2.576-2.146=0.43)
- (1.716-2.146=-0.43)
标准差:
归一化后:
得到 Attention 子层输出
所以:
这就是 attention 子层完整走完后的输出。
现在你可以真正看到一件事:
- attention 先改变表示
- 残差把原表示和修正量合并
- LayerNorm 再把结果拉回到稳定尺度
这三步一起,才构成 attention 子层的完整意义。
七、现在进入 FFN 子层
接下来要算:
所以我们先定义一个极小的 FFN。
八、定义一个可手算的 FFN
为了方便手算,我们设:
也就是 FFN 先把 2 维向量升到 3 维,再压回 2 维。
设第一层权重:
注意这里我们把输入行向量右乘 (W_1),所以 (W_1) 形状是 (2\times3)。
再设第二层权重:
它的形状是 (3\times2)。
激活函数我们用最简单的 ReLU:
所以 FFN 写成:
九、先算第 1 行的 FFN
第 1 行输入是:
先乘 (W_1):
解释一下:
- 第 1 维:(1\times1 + (-1)\times0 = 1)
- 第 2 维:(1\times0 + (-1)\times1 = -1)
- 第 3 维:(1\times1 + (-1)\times(-1)=2)
经过 ReLU:
再乘 (W_2):
因为:
- 第 1 维:(1\times1 + 0\times0 + 2\times1 = 3)
- 第 2 维:(1\times0 + 0\times1 + 2\times1 = 2)
所以:
十、算第 2 行的 FFN
第 2 行输入是:
先乘 (W_1):
经过 ReLU:
再乘 (W_2):
所以:
十一、算第 3 行的 FFN
第 3 行输入是:
和第 1 行一样,所以结果也是:
十二、于是 FFN 输出矩阵是
现在 attention 子层的输出已经经过了 FFN 子层加工准备。
十三、做 FFN 残差连接
根据 block 结构,先做残差:
所以:
这一步的含义和前面的 attention 残差一致:
- FFN 不是完全替换原表示
- 而是在原表示上叠加一个“逐位置非线性修正量”
十四、最后一步:再做一次 LayerNorm
现在对 (R^{(2)}) 的每一行做 LayerNorm。
继续使用:
- (\gamma=[1,1])
- (\beta=[0,0])
第 1 行归一化
均值:
偏差:
- (4-2.5=1.5)
- (1-2.5=-1.5)
标准差:
归一化后:
第 2 行归一化
均值:
偏差:
- (-1-0.5=-1.5)
- (2-0.5=1.5)
标准差:
归一化后:
第 3 行归一化
和第 1 行一样:
十五、得到完整 block 的最终输出
所以这一层完整 Transformer block 的最终输出是:
这就是我们从:
- 输入 (X)
- attention
- 残差
- LayerNorm
- FFN
- 残差
- LayerNorm
一路完整算出来的一层 Transformer Block 输出。
十六、现在回头看:这一层 Transformer Block 到底发生了什么
虽然这个例子非常小,但它已经完整体现了一层 Transformer Block 的本质节奏:
第一步:通过 attention 做跨位置的信息交换
每个 token 都不是只保留自己,而是开始吸收别的位置的信息。
第二步:通过第一条残差保留原表示通路
attention 只做修正,不是彻底重写。
第三步:通过 LayerNorm 把表示重新拉回稳定尺度
防止数值分布乱掉。
第四步:通过 FFN 做逐位置非线性加工
attention 负责“交流”,FFN 负责“消化”。
第五步:再用残差和 LayerNorm 稳定加工结果
让整层 block 变成一个可堆叠的稳定单元。
这就是一层 Transformer block 最本质的工作模式:
先做全局交互,再做局部加工;每一步都通过残差和归一化维持可训练性。
十七、你可能会发现:为什么最终结果又回到了 ([1,-1]) 和 ([-1,1])
这是因为我们这次为了手算方便:
- 把 LayerNorm 简化成了标准化
- 且每行只有 2 个维度
- 维度太小,会导致很多行自然被归一化成对称形式
所以这个现象不是说“真实模型一层算完永远回到原样”,而是这个极小例子下的自然结果。
在真实模型里:
- 维度通常是几百到几千
- (\gamma,\beta) 是可学习参数
- LN 之后不会这么简单地塌到 (\pm1)
- attention 和 FFN 的权重也复杂得多
所以你要看重的不是“最后这几个数值形式”,而是:
整个 block 的数值流动顺序。
十八、从线性代数视角把这一层压缩一遍
如果把这一篇的过程压缩成最核心的数学链条,就是:
输入
attention 子层
第一条残差与归一化
FFN 子层
第二条残差与归一化
这就是一个最小 block 的完整 forward。
十九、为什么这一篇对真正理解 Transformer 特别重要
因为很多人即使知道:
- attention 是什么
- FFN 是什么
- 残差和 LN 是什么
也依然没有真正把它们在时间顺序和数值顺序上串起来。
而这一篇手算以后,你应该能真正建立一个很具体的感觉:
- attention 子层不是独立悬浮的
- FFN 不是附赠 MLP
- 残差和 LN 不是边角料
- 一层 Transformer Block 是一条非常清晰的数值处理链
这会让你回头看真正的 Transformer 实现时,理解深度完全不一样。
二十、本篇真正要记住的四条主线
第一条:一层 Transformer Block 的核心节奏
- attention 做跨位置交互
- FFN 做逐位置加工
第二条:残差的作用
- 不是替换表示,而是叠加修正量
- 让信息和梯度更容易传递
第三条:LayerNorm 的作用
- 让每一步加工后的表示回到更稳定的数值分布
- 保证深层可训练性
第四条:完整 block 的 forward
这是一层 Transformer Block 最核心的数学骨架。
二十一、用一句话压缩本篇
一个完整 Transformer block 的本质,是先通过 attention 对序列表示做一次跨位置上下文修正,再通过 FFN 对每个位置做一次逐位置非线性加工,并用残差与 LayerNorm 将这两步稳定地组织成一个可堆叠的深层更新单元。
二十二、下一篇预告
下一篇主题是:
用 PyTorch 从零实现一个最小 Transformer:从 Self-Attention 到可运行模型
会重点讲清楚:
- 如何用 PyTorch 把 Q/K/V 投影、Attention、残差、LayerNorm、FFN 全部写成代码
- causal mask 在代码里怎么对应推理的因果约束
- 一个可运行的最小 Decoder-only Transformer 完整实现
- 朴素 generate() 的写法,以及它暴露的推理瓶颈
