很多人学 Transformer,会先记住一串关键词:Self-Attention、Multi-Head、FFN、LayerNorm、残差连接。这样学并不算错,但一旦问题从“结构是什么”变成“为什么推理会慢”“为什么要有 KV Cache”“FlashAttention 和 PagedAttention 分别在解决什么”,就容易断层。

站在 AI Infra 的视角,Transformer 不是一套论文术语,而是一条要真正跑在 GPU 上的数据流。只有把这条数据流看清楚,后面的推理优化、显存管理和系统设计才有落点。

这篇文章不打算把 Transformer 讲成百科全书,而是抓住一条主线:从一个 decoder-only Transformer block 出发,顺着前向计算一路走到推理瓶颈。

先把整体图景立起来

今天大模型推理里最常见的是 decoder-only Transformer。把它压缩成一条最核心的执行链,大致可以写成这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
token ids
-> embedding
-> 多层 Transformer block
-> norm
-> self-attention
-> residual
-> norm
-> MLP
-> residual
-> final norm
-> lm head
-> logits
-> next token

如果只看单层 block,它通常可以抽象成两步:

1
2
h1 = x + Attention(Norm(x))
h2 = h1 + MLP(Norm(h1))

这两行式子很重要,因为后面几乎所有问题都能往这里落:

  • Attention 在做什么。
  • MLP 在做什么。
  • 为什么要先 Norm 再做变换。
  • 为什么要保留残差连接。
  • 推理时计算和访存主要卡在哪一步。

一个 Transformer block 到底在算什么

Attention 负责“跨 token 混合信息”

Attention 的核心不是“公式很高级”,而是它让当前位置可以直接读取上下文里和自己最相关的信息。

写成公式是:

1
2
3
4
5
Q = XWq
K = XWk
V = XWv

Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V

直觉上可以这样理解:

  • Q 表示当前 token 想找什么信息。
  • K 表示每个历史位置手里有什么“索引”。
  • V 表示真正要被取出的内容。

所以 Attention 做的事情其实很朴素:先算相关性,再按权重把历史信息汇总回来。

这也是 Transformer 和 RNN 的一个根本区别。RNN 要一格一格传递状态,Attention 可以直接在任意两个位置之间建立依赖。对长上下文来说,这件事非常关键。

MLP 负责“在单个 token 内部继续加工表示”

很多人会下意识觉得 Attention 是主角,MLP 只是配角。实际不是这样。

Attention 更像是在做“信息路由”,决定应该从上下文里看什么;MLP 则更像是在做“通道内加工”,把当前 token 的表示继续变换、拉伸、压缩、重组。

可以粗略记成一句话:

  • Attention 解决“看哪里”。
  • MLP 解决“怎么加工”。

在现代 LLM 里,MLP 往往不是最朴素的两层全连接,而是带门控的结构,比如 SwiGLU。这样做的目的,是让信息通过时更有选择性,表达能力也更强。

Norm 和残差不是陪衬,而是稳定系统的骨架

如果把 Attention 和 MLP 看成“变换器”,那 Norm 和残差就是“稳定器”。

Norm 的作用,是把每一层输入控制在一个更可管理的数值范围内。否则层数一深,激活分布漂移,训练和推理都会越来越不稳定。现在很多 LLM 更偏向 RMSNorm,一个现实原因就是它计算更简单,工程实现也更友好。

残差连接的意义则更直接:每一层不是把原表示推翻重来,而是在原表示上做增量修正。这样既能保留主干信息,也让深层网络更容易优化。

如果面试里要一句话说清楚,可以这么讲:

Attention 和 MLP 负责改变表示,Norm 和残差负责让这种改变可控、可叠加、可持续。

为什么同一个 Transformer,训练和推理看起来像两种系统

这是 AI Infra 里最关键的分界线之一。

训练时,模型面对的是完整已知序列。虽然目标仍然是 next-token prediction,但所有位置都已经在输入里了,所以很多计算可以并行做。

推理时则完全不同。模型只能基于当前已有前缀,一个 token 一个 token 往前生成。未来位置还不存在,因此很多本来可以一起展开的计算,必须拆成多次小步执行。

这就直接带来了两个阶段:

Prefill:先把整段 prompt 编进模型

当用户输入一长段 prompt 时,系统会先把这些 token 整体跑一遍。这个阶段序列长、矩阵大、计算密集,通常更偏向算力瓶颈。

Decode:之后每一步只生成一个新 token

进入生成阶段后,每次新 token 都要走完整个模型,但这时单次新增计算并不大,问题反而变成了:为了生成这一个 token,需要不断回头读取整段历史上下文的信息。

所以一个很实用的结论是:

Prefill 更像大矩阵计算问题,Decode 更像高频访存问题。

这也是为什么很多推理系统会把 TTFT 和 TPOT 分开看,因为它们对应的性能画像本来就不一样。

KV Cache 为什么是推理的关键

如果没有 KV Cache,模型每生成一个新 token,都要把历史序列从头算一遍 Attention。这在长上下文下几乎不可接受。

KV Cache 的思路很直接:历史 token 已经算出来的 KV 不要丢,按层缓存起来。下一个 token 到来时,只计算它自己的新 Q/K/V,然后用当前 Q 去读取历史缓存里的 K/V

也就是说,缓存的不是整个 block 的全部中间结果,而是 Attention 阶段最值得复用的那部分状态。

为什么只存 K/V,不存 Q

因为 Q 只为“当前这一步的查询”服务,它做完这一步就没有复用价值了。真正会在后续每一步都被重复访问的是历史 K/V

所以 KV Cache 的本质不是一个“顺手优化”,而是把原本无法承受的重复计算,转换成了更可管理的显存占用和访存开销。

为什么 Decode 往往会变成 memory-bound

很多人第一次接触推理优化时,会默认以为大模型慢是因为“算得太多”。这句话只说对一半。

在 Decode 阶段,每一步新增的计算量并不一定夸张,但每一层都要去读很长的历史 K/V。上下文一长,这件事对显存带宽和访存模式的压力就会迅速上来。

因此 Decode 常见的真实瓶颈是:

  • 不是 ALU 不够忙;
  • 而是 GPU 在等数据;
  • 不是矩阵乘本身算不动;
  • 而是历史 KV 的读取代价越来越高。

这也是为什么很多 LLM 推理优化,到最后都会落到缓存布局、访存路径和调度策略上。

FlashAttention 和 PagedAttention,不是一回事

这两个名字经常被一起提,但它们解决的问题并不相同。

FlashAttention:减少 Attention 中间结果的显存读写

FlashAttention 的核心不是“换了一个公式”,而是重新组织计算顺序,尽量把中间结果留在更快的片上存储里,减少大规模的显存往返。

它主要解决的是:标准 Attention 在长序列下会产生大量中间张量,显存访问成本很高。

所以 FlashAttention 更偏向算子层优化,重点是让 Attention 这一步本身更高效。

PagedAttention:把 KV Cache 管理得更像虚拟内存

PagedAttention 的重点不在于单次 Attention 的数学计算,而在于怎么组织和管理 KV Cache

如果把每个请求的 KV 都按连续大块显存分配,动态 batch 和不同长度请求混在一起时,很容易出现碎片化和空间浪费。PagedAttention 的思路,是把 KV Cache 拆成页,用更灵活的方式映射和复用。

所以它更偏向系统层优化,重点是让多请求场景下的缓存管理更稳定、更省显存。

可以把两者区别压缩成一句话:

FlashAttention 优化的是“怎么高效算 Attention”,PagedAttention 优化的是“怎么高效管 KV Cache”。

如果顺着一次前向计算往下看,瓶颈会落在哪里

站在推理视角,一个 block 的执行过程大致是这样的:

  1. 输入 hidden state 先做 Norm。
  2. 投影出当前步的 Q/K/V
  3. 新产生的 K/V 写入 KV Cache。
  4. 当前 Q 去读取历史 K/V,完成注意力计算。
  5. Attention 输出做线性投影并加上残差。
  6. 再做一次 Norm。
  7. 进入 MLP。
  8. MLP 输出再加上残差,送往下一层。

如果是 Prefill 阶段,Attention 和 MLP 往往都比较“像矩阵乘问题”,吞吐更受算子效率影响。

如果是 Decode 阶段,最敏感的地方通常变成了:

  • KV Cache 写得是否规整;
  • 历史 K/V 读得是否高效;
  • 不同请求能不能顺畅拼 batch;
  • 显存是不是被碎片和冗余缓存拖垮。

所以你会看到一个很典型的现象:很多人把 Transformer 结构背得很熟,但一谈推理优化就只会复述“用 FlashAttention、用 KV Cache”。真正有用的理解,是能把优化名词和执行路径一一对应起来。

站在 AI Infra 视角,真正要掌握到什么程度

如果你的目标是推理工程、模型部署或系统优化岗位,对 Transformer 的理解至少应该过四层:

第一层:能把 block 讲顺

知道输入如何经过 Norm、Attention、MLP 和残差,能解释每个模块在做什么。

第二层:能把训练和推理分开讲

明白为什么训练能并行,推理却要逐 token 展开;知道 Prefill 和 Decode 的差异。

第三层:能把瓶颈落到硬件行为

知道为什么有时是 compute-bound,有时是 memory-bound;知道 KV Cache 对显存和带宽意味着什么。

第四层:能把优化方法和问题本身对应起来

知道 FlashAttention 在解决算子层访存问题,PagedAttention 在解决缓存组织和显存管理问题,而不是把两个词当成“性能优化黑话”。

这四层一旦打通,很多系统问题会自然连起来。你再去看 continuous batching、prefix caching、speculative decoding,理解会顺很多,因为你已经知道推理系统到底在绕着哪些硬约束做取舍。

写在最后

Transformer 真正难的地方,不在于公式多,而在于它一头连着模型表达能力,一头连着真实推理系统。只盯着结构,很容易停在“我知道它由 Attention 和 MLP 组成”;只盯着优化,又容易把 FlashAttention、PagedAttention、KV Cache 这些词学成零散技巧。

更好的理解方式,是始终沿着同一条线往下走:

从一个 block 的前向计算开始,看到 Attention 如何读取上下文,看到 MLP 如何加工表示,看到推理为什么拆成 Prefill 和 Decode,看到 KV Cache 为什么成为显存核心,再看到 FlashAttention 和 PagedAttention 分别在解决哪类问题。

把这条线走通之后,Transformer 对你来说就不再是一张结构图,而是一套真正能落到推理系统上的执行逻辑。