AI-Infra学习之旅-从 Transformer Block 到 KV Cache:站在推理视角理解 Transformer
很多人学 Transformer,会先记住一串关键词:Self-Attention、Multi-Head、FFN、LayerNorm、残差连接。这样学并不算错,但一旦问题从“结构是什么”变成“为什么推理会慢”“为什么要有 KV Cache”“FlashAttention 和 PagedAttention 分别在解决什么”,就容易断层。
站在 AI Infra 的视角,Transformer 不是一套论文术语,而是一条要真正跑在 GPU 上的数据流。只有把这条数据流看清楚,后面的推理优化、显存管理和系统设计才有落点。
这篇文章不打算把 Transformer 讲成百科全书,而是抓住一条主线:从一个 decoder-only Transformer block 出发,顺着前向计算一路走到推理瓶颈。
先把整体图景立起来
今天大模型推理里最常见的是 decoder-only Transformer。把它压缩成一条最核心的执行链,大致可以写成这样:
1 | token ids |
如果只看单层 block,它通常可以抽象成两步:
1 | h1 = x + Attention(Norm(x)) |
这两行式子很重要,因为后面几乎所有问题都能往这里落:
- Attention 在做什么。
- MLP 在做什么。
- 为什么要先 Norm 再做变换。
- 为什么要保留残差连接。
- 推理时计算和访存主要卡在哪一步。
一个 Transformer block 到底在算什么
Attention 负责“跨 token 混合信息”
Attention 的核心不是“公式很高级”,而是它让当前位置可以直接读取上下文里和自己最相关的信息。
写成公式是:
1 | Q = XWq |
直觉上可以这样理解:
Q表示当前 token 想找什么信息。K表示每个历史位置手里有什么“索引”。V表示真正要被取出的内容。
所以 Attention 做的事情其实很朴素:先算相关性,再按权重把历史信息汇总回来。
这也是 Transformer 和 RNN 的一个根本区别。RNN 要一格一格传递状态,Attention 可以直接在任意两个位置之间建立依赖。对长上下文来说,这件事非常关键。
MLP 负责“在单个 token 内部继续加工表示”
很多人会下意识觉得 Attention 是主角,MLP 只是配角。实际不是这样。
Attention 更像是在做“信息路由”,决定应该从上下文里看什么;MLP 则更像是在做“通道内加工”,把当前 token 的表示继续变换、拉伸、压缩、重组。
可以粗略记成一句话:
- Attention 解决“看哪里”。
- MLP 解决“怎么加工”。
在现代 LLM 里,MLP 往往不是最朴素的两层全连接,而是带门控的结构,比如 SwiGLU。这样做的目的,是让信息通过时更有选择性,表达能力也更强。
Norm 和残差不是陪衬,而是稳定系统的骨架
如果把 Attention 和 MLP 看成“变换器”,那 Norm 和残差就是“稳定器”。
Norm 的作用,是把每一层输入控制在一个更可管理的数值范围内。否则层数一深,激活分布漂移,训练和推理都会越来越不稳定。现在很多 LLM 更偏向 RMSNorm,一个现实原因就是它计算更简单,工程实现也更友好。
残差连接的意义则更直接:每一层不是把原表示推翻重来,而是在原表示上做增量修正。这样既能保留主干信息,也让深层网络更容易优化。
如果面试里要一句话说清楚,可以这么讲:
Attention 和 MLP 负责改变表示,Norm 和残差负责让这种改变可控、可叠加、可持续。
为什么同一个 Transformer,训练和推理看起来像两种系统
这是 AI Infra 里最关键的分界线之一。
训练时,模型面对的是完整已知序列。虽然目标仍然是 next-token prediction,但所有位置都已经在输入里了,所以很多计算可以并行做。
推理时则完全不同。模型只能基于当前已有前缀,一个 token 一个 token 往前生成。未来位置还不存在,因此很多本来可以一起展开的计算,必须拆成多次小步执行。
这就直接带来了两个阶段:
Prefill:先把整段 prompt 编进模型
当用户输入一长段 prompt 时,系统会先把这些 token 整体跑一遍。这个阶段序列长、矩阵大、计算密集,通常更偏向算力瓶颈。
Decode:之后每一步只生成一个新 token
进入生成阶段后,每次新 token 都要走完整个模型,但这时单次新增计算并不大,问题反而变成了:为了生成这一个 token,需要不断回头读取整段历史上下文的信息。
所以一个很实用的结论是:
Prefill 更像大矩阵计算问题,Decode 更像高频访存问题。
这也是为什么很多推理系统会把 TTFT 和 TPOT 分开看,因为它们对应的性能画像本来就不一样。
KV Cache 为什么是推理的关键
如果没有 KV Cache,模型每生成一个新 token,都要把历史序列从头算一遍 Attention。这在长上下文下几乎不可接受。
KV Cache 的思路很直接:历史 token 已经算出来的 K 和 V 不要丢,按层缓存起来。下一个 token 到来时,只计算它自己的新 Q/K/V,然后用当前 Q 去读取历史缓存里的 K/V。
也就是说,缓存的不是整个 block 的全部中间结果,而是 Attention 阶段最值得复用的那部分状态。
为什么只存 K/V,不存 Q
因为 Q 只为“当前这一步的查询”服务,它做完这一步就没有复用价值了。真正会在后续每一步都被重复访问的是历史 K/V。
所以 KV Cache 的本质不是一个“顺手优化”,而是把原本无法承受的重复计算,转换成了更可管理的显存占用和访存开销。
为什么 Decode 往往会变成 memory-bound
很多人第一次接触推理优化时,会默认以为大模型慢是因为“算得太多”。这句话只说对一半。
在 Decode 阶段,每一步新增的计算量并不一定夸张,但每一层都要去读很长的历史 K/V。上下文一长,这件事对显存带宽和访存模式的压力就会迅速上来。
因此 Decode 常见的真实瓶颈是:
- 不是 ALU 不够忙;
- 而是 GPU 在等数据;
- 不是矩阵乘本身算不动;
- 而是历史 KV 的读取代价越来越高。
这也是为什么很多 LLM 推理优化,到最后都会落到缓存布局、访存路径和调度策略上。
FlashAttention 和 PagedAttention,不是一回事
这两个名字经常被一起提,但它们解决的问题并不相同。
FlashAttention:减少 Attention 中间结果的显存读写
FlashAttention 的核心不是“换了一个公式”,而是重新组织计算顺序,尽量把中间结果留在更快的片上存储里,减少大规模的显存往返。
它主要解决的是:标准 Attention 在长序列下会产生大量中间张量,显存访问成本很高。
所以 FlashAttention 更偏向算子层优化,重点是让 Attention 这一步本身更高效。
PagedAttention:把 KV Cache 管理得更像虚拟内存
PagedAttention 的重点不在于单次 Attention 的数学计算,而在于怎么组织和管理 KV Cache。
如果把每个请求的 KV 都按连续大块显存分配,动态 batch 和不同长度请求混在一起时,很容易出现碎片化和空间浪费。PagedAttention 的思路,是把 KV Cache 拆成页,用更灵活的方式映射和复用。
所以它更偏向系统层优化,重点是让多请求场景下的缓存管理更稳定、更省显存。
可以把两者区别压缩成一句话:
FlashAttention 优化的是“怎么高效算 Attention”,PagedAttention 优化的是“怎么高效管 KV Cache”。
如果顺着一次前向计算往下看,瓶颈会落在哪里
站在推理视角,一个 block 的执行过程大致是这样的:
- 输入 hidden state 先做 Norm。
- 投影出当前步的
Q/K/V。 - 新产生的
K/V写入 KV Cache。 - 当前
Q去读取历史K/V,完成注意力计算。 - Attention 输出做线性投影并加上残差。
- 再做一次 Norm。
- 进入 MLP。
- MLP 输出再加上残差,送往下一层。
如果是 Prefill 阶段,Attention 和 MLP 往往都比较“像矩阵乘问题”,吞吐更受算子效率影响。
如果是 Decode 阶段,最敏感的地方通常变成了:
- KV Cache 写得是否规整;
- 历史
K/V读得是否高效; - 不同请求能不能顺畅拼 batch;
- 显存是不是被碎片和冗余缓存拖垮。
所以你会看到一个很典型的现象:很多人把 Transformer 结构背得很熟,但一谈推理优化就只会复述“用 FlashAttention、用 KV Cache”。真正有用的理解,是能把优化名词和执行路径一一对应起来。
站在 AI Infra 视角,真正要掌握到什么程度
如果你的目标是推理工程、模型部署或系统优化岗位,对 Transformer 的理解至少应该过四层:
第一层:能把 block 讲顺
知道输入如何经过 Norm、Attention、MLP 和残差,能解释每个模块在做什么。
第二层:能把训练和推理分开讲
明白为什么训练能并行,推理却要逐 token 展开;知道 Prefill 和 Decode 的差异。
第三层:能把瓶颈落到硬件行为
知道为什么有时是 compute-bound,有时是 memory-bound;知道 KV Cache 对显存和带宽意味着什么。
第四层:能把优化方法和问题本身对应起来
知道 FlashAttention 在解决算子层访存问题,PagedAttention 在解决缓存组织和显存管理问题,而不是把两个词当成“性能优化黑话”。
这四层一旦打通,很多系统问题会自然连起来。你再去看 continuous batching、prefix caching、speculative decoding,理解会顺很多,因为你已经知道推理系统到底在绕着哪些硬约束做取舍。
写在最后
Transformer 真正难的地方,不在于公式多,而在于它一头连着模型表达能力,一头连着真实推理系统。只盯着结构,很容易停在“我知道它由 Attention 和 MLP 组成”;只盯着优化,又容易把 FlashAttention、PagedAttention、KV Cache 这些词学成零散技巧。
更好的理解方式,是始终沿着同一条线往下走:
从一个 block 的前向计算开始,看到 Attention 如何读取上下文,看到 MLP 如何加工表示,看到推理为什么拆成 Prefill 和 Decode,看到 KV Cache 为什么成为显存核心,再看到 FlashAttention 和 PagedAttention 分别在解决哪类问题。
把这条线走通之后,Transformer 对你来说就不再是一张结构图,而是一套真正能落到推理系统上的执行逻辑。



