AI Infra学习之旅-手算一个完整TransformerBlock

这是“Transformer 原理深讲系列”的第 7 篇。
上一篇我们已经手算了一次最小的单头 Self-Attention,从输入向量一路算到 attention 输出。
但如果只停在那里,你对 Transformer 的理解还差最后一块关键拼图:
真实的一层 Transformer block 并不只有 attention,它还包含残差连接、LayerNorm,以及 attention 之后的 FFN。
所以这一篇我们继续沿着“手算”的方式,把一个极小的 Transformer block 真正走完整。


一、为什么还要再手算一个完整 block

上一篇我们已经算出了:

O=Attention(H)O = \text{Attention}(H)

也看到了 attention 的本质是:

  • 先用 Q/K 计算相关性
  • 再 softmax 得到注意力分配
  • 最后对 V 做加权求和

但如果你回头看真实 Transformer 的结构,就会发现:

一个 block 并不是:

输入 (\rightarrow) Attention (\rightarrow) 输出

而是至少包含:

  1. Attention 子层
  2. 残差连接
  3. LayerNorm
  4. FFN 子层
  5. 再一次残差连接
  6. 再一次 LayerNorm(或 Pre-LN 形式下先 LN 再进子层)

也就是说,attention 虽然是“心脏”,但它并不是整个 block 的全部。

所以这一篇的目标就是:

把 attention 结果真正接进残差、LayerNorm 和 FFN,完整算出一层 Transformer Block 的输出。


二、这次我们仍然使用极小可手算例子

为了保证能一步一步写清楚,这次继续做简化。

我们保留的东西

  • 输入向量
  • 单头 Self-Attention
  • 残差连接
  • LayerNorm
  • FFN
  • 完整一层 Transformer Block 的顺序

我们暂时不展开的东西

  • 多头 attention
  • causal mask
  • batch 维度
  • 多层堆叠
  • 输出 logits / 词表概率
  • Pre-LN 与 Post-LN 的所有变体比较(我们这次固定一种)

三、我们这次采用哪一种 block 结构

为了手算更清晰,这一篇采用 Post-LN 风格 的 block 来演示:

第一步:Attention 子层

H(1)=LN(X+Attention(X))H^{(1)} = \text{LN}(X + \text{Attention}(X))

第二步:FFN 子层

H(2)=LN(H(1)+FFN(H(1)))H^{(2)} = \text{LN}(H^{(1)} + \text{FFN}(H^{(1)}))

这不是说现代大模型一定都这样写。
真实工程中很多模型更偏 Pre-LN。
但对于手算演示来说,这种形式更适合把:

  • 子层输出
  • 残差叠加
  • 再归一化

清楚地分成两步看。


四、我们直接复用上一讲 attention 的输入

上一讲我们设输入矩阵为:

X=[100111]X= \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix}

序列长度为:

n=3n=3

模型维度:

dmodel=2d_{\text{model}}=2

并且通过一个极小例子,已经算出单头 Self-Attention 输出近似为:

A(X)[1.4010.5991.5040.7521.5760.716]A(X) \approx \begin{bmatrix} 1.401 & 0.599 \\ 1.504 & 0.752 \\ 1.576 & 0.716 \end{bmatrix}

为了这一篇聚焦 block,本篇就不再重复 Q/K/V 的中间计算,而是直接从这个 attention 输出继续往后算。


五、第一步:做 Attention 残差连接

根据 block 结构,我们先做:

R(1)=X+A(X)R^{(1)} = X + A(X)

把输入矩阵和 attention 输出逐元素相加:

R(1)=[100111]+[1.4010.5991.5040.7521.5760.716]=[2.4010.5991.5041.7522.5761.716]R^{(1)}= \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix} + \begin{bmatrix} 1.401 & 0.599 \\ 1.504 & 0.752 \\ 1.576 & 0.716 \end{bmatrix} = \begin{bmatrix} 2.401 & 0.599 \\ 1.504 & 1.752 \\ 2.576 & 1.716 \end{bmatrix}

这一步的意义非常重要:

  • attention 给出的是“上下文修正量”
  • 残差让我们不是丢掉原输入,而是把修正量叠加到原表示上

所以现在的表示已经不是“纯输入”或“纯 attention”,而是:

原表示 + 上下文修正后的表示


六、第二步:对 Attention 残差结果做 LayerNorm

现在对 (R^{(1)}) 的每一行分别做 LayerNorm。

为了手算清晰,我们做如下简化:

  • (\gamma = [1,1])
  • (\beta = [0,0])
  • 忽略 (\epsilon) 的极小影响

于是 LayerNorm 退化成:

LN(x)=xμσ\text{LN}(x)=\frac{x-\mu}{\sigma}

其中:

μ=1dixi,σ=1di(xiμ)2\mu = \frac{1}{d}\sum_i x_i,\quad \sigma = \sqrt{\frac{1}{d}\sum_i (x_i-\mu)^2}

因为这里每一行只有 2 维,所以会算得特别简洁。


第 1 行归一化

第 1 行是:

[2.401, 0.599][2.401,\ 0.599]

均值:

μ=2.401+0.5992=1.5\mu = \frac{2.401+0.599}{2}=1.5

方差:

(2.4011.5)2+(0.5991.5)22=0.9012+(0.901)22=0.8118\frac{(2.401-1.5)^2 + (0.599-1.5)^2}{2} = \frac{0.901^2 + (-0.901)^2}{2} = 0.8118

标准差:

σ0.901\sigma \approx 0.901

归一化后:

[2.4011.50.901,0.5991.50.901]=[1, 1]\left[ \frac{2.401-1.5}{0.901}, \frac{0.599-1.5}{0.901} \right] = [1,\ -1]


第 2 行归一化

第 2 行是:

[1.504, 1.752][1.504,\ 1.752]

均值:

μ=1.504+1.7522=1.628\mu = \frac{1.504+1.752}{2}=1.628

两边偏差分别是:

  • (1.504-1.628=-0.124)
  • (1.752-1.628=0.124)

标准差就是:

σ=0.124\sigma=0.124

归一化后:

[1, 1][-1,\ 1]


第 3 行归一化

第 3 行是:

[2.576, 1.716][2.576,\ 1.716]

均值:

μ=2.576+1.7162=2.146\mu = \frac{2.576+1.716}{2}=2.146

偏差:

  • (2.576-2.146=0.43)
  • (1.716-2.146=-0.43)

标准差:

σ=0.43\sigma = 0.43

归一化后:

[1, 1][1,\ -1]


得到 Attention 子层输出

所以:

H(1)=LN(R(1))=[111111]H^{(1)} = \text{LN}(R^{(1)}) = \begin{bmatrix} 1 & -1 \\ -1 & 1 \\ 1 & -1 \end{bmatrix}

这就是 attention 子层完整走完后的输出。

现在你可以真正看到一件事:

  • attention 先改变表示
  • 残差把原表示和修正量合并
  • LayerNorm 再把结果拉回到稳定尺度

这三步一起,才构成 attention 子层的完整意义。


七、现在进入 FFN 子层

接下来要算:

H(2)=LN(H(1)+FFN(H(1)))H^{(2)} = \text{LN}(H^{(1)} + \text{FFN}(H^{(1)}))

所以我们先定义一个极小的 FFN。


八、定义一个可手算的 FFN

为了方便手算,我们设:

dmodel=2,dff=3d_{\text{model}}=2,\quad d_{\text{ff}}=3

也就是 FFN 先把 2 维向量升到 3 维,再压回 2 维。

设第一层权重:

W1=[101011]W_1= \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & -1 \end{bmatrix}

注意这里我们把输入行向量右乘 (W_1),所以 (W_1) 形状是 (2\times3)。

再设第二层权重:

W2=[100111]W_2= \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix}

它的形状是 (3\times2)。

激活函数我们用最简单的 ReLU:

ReLU(z)=max(0,z)\text{ReLU}(z)=\max(0,z)

所以 FFN 写成:

FFN(x)=ReLU(xW1)W2\text{FFN}(x)=\text{ReLU}(xW_1)W_2


九、先算第 1 行的 FFN

第 1 行输入是:

[1, 1][1,\ -1]

先乘 (W_1):

[1, 1][101011]=[1, 1, 2][1,\ -1] \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & -1 \end{bmatrix} = [1,\ -1,\ 2]

解释一下:

  • 第 1 维:(1\times1 + (-1)\times0 = 1)
  • 第 2 维:(1\times0 + (-1)\times1 = -1)
  • 第 3 维:(1\times1 + (-1)\times(-1)=2)

经过 ReLU:

[1, 0, 2][1,\ 0,\ 2]

再乘 (W_2):

[1, 0, 2][100111]=[3, 2][1,\ 0,\ 2] \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix} = [3,\ 2]

因为:

  • 第 1 维:(1\times1 + 0\times0 + 2\times1 = 3)
  • 第 2 维:(1\times0 + 0\times1 + 2\times1 = 2)

所以:

FFN([1,1])=[3,2]\text{FFN}([1,-1]) = [3,2]


十、算第 2 行的 FFN

第 2 行输入是:

[1, 1][-1,\ 1]

先乘 (W_1):

[1, 1][101011]=[1, 1, 2][-1,\ 1] \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & -1 \end{bmatrix} = [-1,\ 1,\ -2]

经过 ReLU:

[0, 1, 0][0,\ 1,\ 0]

再乘 (W_2):

[0, 1, 0][100111]=[0, 1][0,\ 1,\ 0] \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix} = [0,\ 1]

所以:

FFN([1,1])=[0,1]\text{FFN}([-1,1])=[0,1]


十一、算第 3 行的 FFN

第 3 行输入是:

[1, 1][1,\ -1]

和第 1 行一样,所以结果也是:

[3, 2][3,\ 2]


十二、于是 FFN 输出矩阵是

FFN(H(1))=[320132]\text{FFN}(H^{(1)})= \begin{bmatrix} 3 & 2 \\ 0 & 1 \\ 3 & 2 \end{bmatrix}

现在 attention 子层的输出已经经过了 FFN 子层加工准备。


十三、做 FFN 残差连接

根据 block 结构,先做残差:

R(2)=H(1)+FFN(H(1))R^{(2)} = H^{(1)} + \text{FFN}(H^{(1)})

所以:

R(2)=[111111]+[320132]=[411241]R^{(2)}= \begin{bmatrix} 1 & -1 \\ -1 & 1 \\ 1 & -1 \end{bmatrix} + \begin{bmatrix} 3 & 2 \\ 0 & 1 \\ 3 & 2 \end{bmatrix} = \begin{bmatrix} 4 & 1 \\ -1 & 2 \\ 4 & 1 \end{bmatrix}

这一步的含义和前面的 attention 残差一致:

  • FFN 不是完全替换原表示
  • 而是在原表示上叠加一个“逐位置非线性修正量”

十四、最后一步:再做一次 LayerNorm

现在对 (R^{(2)}) 的每一行做 LayerNorm。

继续使用:

  • (\gamma=[1,1])
  • (\beta=[0,0])

第 1 行归一化

[4, 1][4,\ 1]

均值:

μ=4+12=2.5\mu = \frac{4+1}{2}=2.5

偏差:

  • (4-2.5=1.5)
  • (1-2.5=-1.5)

标准差:

σ=1.5\sigma=1.5

归一化后:

[1, 1][1,\ -1]


第 2 行归一化

[1, 2][-1,\ 2]

均值:

μ=1+22=0.5\mu = \frac{-1+2}{2}=0.5

偏差:

  • (-1-0.5=-1.5)
  • (2-0.5=1.5)

标准差:

σ=1.5\sigma=1.5

归一化后:

[1, 1][-1,\ 1]


第 3 行归一化

和第 1 行一样:

[1, 1][1,\ -1]


十五、得到完整 block 的最终输出

所以这一层完整 Transformer block 的最终输出是:

H(2)=[111111]H^{(2)}= \begin{bmatrix} 1 & -1 \\ -1 & 1 \\ 1 & -1 \end{bmatrix}

这就是我们从:

  • 输入 (X)
  • attention
  • 残差
  • LayerNorm
  • FFN
  • 残差
  • LayerNorm

一路完整算出来的一层 Transformer Block 输出。


十六、现在回头看:这一层 Transformer Block 到底发生了什么

虽然这个例子非常小,但它已经完整体现了一层 Transformer Block 的本质节奏:

第一步:通过 attention 做跨位置的信息交换

每个 token 都不是只保留自己,而是开始吸收别的位置的信息。

第二步:通过第一条残差保留原表示通路

attention 只做修正,不是彻底重写。

第三步:通过 LayerNorm 把表示重新拉回稳定尺度

防止数值分布乱掉。

第四步:通过 FFN 做逐位置非线性加工

attention 负责“交流”,FFN 负责“消化”。

第五步:再用残差和 LayerNorm 稳定加工结果

让整层 block 变成一个可堆叠的稳定单元。

这就是一层 Transformer block 最本质的工作模式:

先做全局交互,再做局部加工;每一步都通过残差和归一化维持可训练性。


十七、你可能会发现:为什么最终结果又回到了 ([1,-1]) 和 ([-1,1])

这是因为我们这次为了手算方便:

  • 把 LayerNorm 简化成了标准化
  • 且每行只有 2 个维度
  • 维度太小,会导致很多行自然被归一化成对称形式

所以这个现象不是说“真实模型一层算完永远回到原样”,而是这个极小例子下的自然结果。

在真实模型里:

  • 维度通常是几百到几千
  • (\gamma,\beta) 是可学习参数
  • LN 之后不会这么简单地塌到 (\pm1)
  • attention 和 FFN 的权重也复杂得多

所以你要看重的不是“最后这几个数值形式”,而是:

整个 block 的数值流动顺序。


十八、从线性代数视角把这一层压缩一遍

如果把这一篇的过程压缩成最核心的数学链条,就是:

输入

XX

attention 子层

A(X)=softmax(QKTdk)VA(X)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

第一条残差与归一化

H(1)=LN(X+A(X))H^{(1)}=\text{LN}(X+A(X))

FFN 子层

F(H(1))=ReLU(H(1)W1)W2F(H^{(1)})=\text{ReLU}(H^{(1)}W_1)W_2

第二条残差与归一化

H(2)=LN(H(1)+F(H(1)))H^{(2)}=\text{LN}(H^{(1)}+F(H^{(1)}))

这就是一个最小 block 的完整 forward。


十九、为什么这一篇对真正理解 Transformer 特别重要

因为很多人即使知道:

  • attention 是什么
  • FFN 是什么
  • 残差和 LN 是什么

也依然没有真正把它们在时间顺序和数值顺序上串起来。

而这一篇手算以后,你应该能真正建立一个很具体的感觉:

  • attention 子层不是独立悬浮的
  • FFN 不是附赠 MLP
  • 残差和 LN 不是边角料
  • 一层 Transformer Block 是一条非常清晰的数值处理链

这会让你回头看真正的 Transformer 实现时,理解深度完全不一样。


二十、本篇真正要记住的四条主线

第一条:一层 Transformer Block 的核心节奏

  • attention 做跨位置交互
  • FFN 做逐位置加工

第二条:残差的作用

  • 不是替换表示,而是叠加修正量
  • 让信息和梯度更容易传递

第三条:LayerNorm 的作用

  • 让每一步加工后的表示回到更稳定的数值分布
  • 保证深层可训练性

第四条:完整 block 的 forward

H(1)=LN(X+Attention(X))H^{(1)}=\text{LN}(X+\text{Attention}(X))

H(2)=LN(H(1)+FFN(H(1)))H^{(2)}=\text{LN}(H^{(1)}+\text{FFN}(H^{(1)}))

这是一层 Transformer Block 最核心的数学骨架。


二十一、用一句话压缩本篇

一个完整 Transformer block 的本质,是先通过 attention 对序列表示做一次跨位置上下文修正,再通过 FFN 对每个位置做一次逐位置非线性加工,并用残差与 LayerNorm 将这两步稳定地组织成一个可堆叠的深层更新单元。


二十二、下一篇预告

下一篇主题是:

用 PyTorch 从零实现一个最小 Transformer:从 Self-Attention 到可运行模型

会重点讲清楚:

  • 如何用 PyTorch 把 Q/K/V 投影、Attention、残差、LayerNorm、FFN 全部写成代码
  • causal mask 在代码里怎么对应推理的因果约束
  • 一个可运行的最小 Decoder-only Transformer 完整实现
  • 朴素 generate() 的写法,以及它暴露的推理瓶颈