PyTorch 推理工程(08):批处理、KV Cache 与 Serving 视角

1. 本节定位

单机脚本中的 model(x) 仅覆盖推理链路的一小段。服务化场景下需同时考虑并发请求、变长输入、延迟与吞吐指标、显存中的 KV Cache,以及 Prefill 与 Decode 在计算与访存上的差异;执行后端也可能是 ORT、TensorRT、vLLM 等而非 Eager PyTorch。本篇从系统约束出发,将前文中的张量与测量概念映射到上述工程语境。


2. 推理系统优化的核心矛盾

真实推理系统不是单目标优化,而是在约束条件下的多目标平衡:

1
2
3
4
5
6
7
8
9
┌─────────────────────────────────────────────────────┐
│ 推理系统的五个核心维度 │
│ │
│ Throughput(吞吐) ──┐ │
│ Latency(延迟) ──┤ │
│ Memory(显存) ──┼── 需要在业务目标下综合权衡 │
│ Quality(输出质量)──┤ │
│ Cost(成本) ──┘ │
└─────────────────────────────────────────────────────┘

典型的 trade-off 关系:

优化动作 收益 代价
增大 batch size 吞吐↑、GPU 利用率↑ 单请求延迟↑、显存↑
开启低精度(FP16) 速度↑、显存↓ 数值稳定性需验证
增大 KV cache 容量 并发数↑、decode 速度↑ 显存↑
严格 latency SLA 用户体验↑ 吞吐↓(不能等太多请求凑批)

所以真正的 inference 工程不是"找一个最快的设置",而是在特定业务目标下找到最合适的方案。


3. LLM 推理的两阶段模型

自回归大模型从 prompt 到完整回复的执行过程,通常可划分为两个阶段,其计算与访存特征不同:

1
2
3
4
5
6
7
8
9
10
11
12
用户输入:"请解释 KV Cache 是什么,并举例说明"(假设 20 个 token)

阶段 1:Prefill(1 次大计算)
→ 把这 20 个 token 一次性喂进模型
→ 计算所有 20 个 token 的 attention
→ 建立好历史 Key/Value 缓存

阶段 2:Decode(反复小计算)
→ 生成 token 1:"KV" ← 只计算 1 个新 token
→ 生成 token 2:"Cache" ← 只计算 1 个新 token
→ 生成 token 3:"是" ← 只计算 1 个新 token
→ ... 重复直到生成完整回答(可能 200 个 token)

理解了这个,后面所有内容都能挂上去。


4. 什么是 Prefill,为什么它更好优化

Prefill 做什么

把整段已有输入 prompt 一次性喂进模型,计算出所有 token 的中间表示,并把历史 Key/Value 缓存建立好。

Prefill 的计算特征

假设 prompt 长度为 L 个 token,Self-Attention 的计算量:

1
2
3
4
Attention 计算量 ∝ L² × d_model

L=512 时:512² = 262,144 次操作
L=2048 时:2048² = 4,194,304 次操作(16 倍!)

Prefill 的特点:

  • 大块矩阵计算L×L 的 attention)
  • 计算密集,GPU 容易吃满
  • 吞吐敏感:一次大计算,受益于大 batch
  • 相对好优化:大连续矩阵乘,低精度收益明显

Prefill 性能定性示意

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import time

# 模拟 prefill:一次性输入 L 个 token
def prefill_sim(L, d=512, batch=1):
"""简化模拟:矩阵乘代表 QK attention 计算"""
Q = torch.randn(batch, L, d, device="cuda")
K = torch.randn(batch, L, d, device="cuda")

torch.cuda.synchronize()
t0 = time.time()

# Q @ K^T 是 attention 的核心计算
attn = torch.bmm(Q, K.transpose(1, 2)) / (d ** 0.5)

torch.cuda.synchronize()
return (time.time() - t0) * 1000

if torch.cuda.is_available():
for L in [128, 512, 1024, 2048]:
t = prefill_sim(L)
print(f"L={L:4d}: {t:.2f} ms (计算量 ∝ L² = {L**2:,})")

输出(示意):

1
2
3
4
L= 128:  0.05 ms  (计算量 ∝ L² =    16,384)
L= 512: 0.15 ms (计算量 ∝ L² = 262,144)
L=1024: 0.48 ms (计算量 ∝ L² = 1,048,576)
L=2048: 1.82 ms (计算量 ∝ L² = 4,194,304)

prefill 阶段计算量随序列长度近似平方增长,故长 prompt 的 prefill 成本显著偏高。


5. 什么是 Decode,为什么它更难优化

Decode 做什么

每次只新增 1 个 token,但需要重复 output_length 次:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Decode 的逻辑(极简示意)
def decode_loop(model, prompt_tokens, max_new_tokens=100):
tokens = prompt_tokens.copy()

for step in range(max_new_tokens):
# 每次只往前走一步:只计算最新 token 和历史的 attention
x = torch.tensor([tokens], device="cuda")

with torch.inference_mode():
logits = model(x) # 实际上利用了 KV cache

next_token = logits[0, -1].argmax().item()
tokens.append(next_token)

if next_token == EOS_TOKEN:
break

return tokens

Decode 的计算特征

在每一步 decode 中,新 token 只需要和历史 K/V 做 attention:

1
2
3
新 token(1 个)的 attention 计算量 ∝ 1 × L_history × d_model

L_history=512 时:1 × 512 次操作(远小于 prefill 的 512²!)

Decode 的特点:

  • 每步计算量小(不是 而是 1×L 量级)
  • 步数多(生成 200 token 就要循环 200 次)
  • 延迟敏感:每一步是一个等待轮次,用户能感受到逐 token 速度
  • 难优化:每步矩阵太小,GPU 计算单元不容易吃满

Prefill vs Decode 对比

特征 Prefill Decode
输入量 整段 prompt(L token) 每步只有 1 个新 token
计算量/次 L² × d(大) 1 × L_history × d(小)
执行次数 1 次 output_length 次(几十~几百次)
GPU 利用率 容易吃高 难吃满(单步太小)
优化重点 吞吐、大块计算 延迟、KV cache 效率
是否适合大 batch 有限(受显存限制)

6. KV Cache:用显存换掉重复计算

这是 LLM inference 最核心的概念之一,必须理解透。

如果没有 KV Cache

在 decode 阶段,每生成一个新 token,理论上需要重新计算所有历史 token 的 Key 和 Value:

1
2
3
4
5
6
7
8
9
10
11
12
13
步骤 1:生成 token_1
→ 计算 token_1 与 [prompt_all] 的 attention
→ 完成

步骤 2:生成 token_2
→ 计算 token_2 与 [prompt_all + token_1] 的 attention

[prompt_all] 的 K/V 上一步不是算过了吗?白算了!

步骤 3:生成 token_3
→ 计算 token_3 与 [prompt_all + token_1 + token_2] 的 attention

[prompt_all + token_1] 的 K/V 上两步都算过了!每步都白算!

不用 KV Cache 的代价:随着 decode 步骤增加,重复计算越来越多,decoder 越来越慢。

有了 KV Cache

在 Prefill 阶段计算完 prompt 的 K/V 后,把它们存进显存,之后 decode 每步只需要:

  1. 计算新 token 的 Q、K、V
  2. 把新 token 的 K/V 追加到缓存里
  3. 用新 token 的 Q 和全部历史缓存的 K/V 做 attention
1
2
3
4
5
6
7
Prefill:计算 prompt 的 K/V → 存入 KV Cache

Decode 步骤 N:
新 token 的 Q → attend 到 KV Cache(历史所有 K/V)

直接读缓存,不重复算!
新 token 的 K/V → 追加到 KV Cache 末尾

KV Cache 的显存消耗:有量化感才能理解系统约束

KV Cache 的大小是多维的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
KV Cache 大小 = 2           # K 和 V 各一份
× num_layers # 每一层都有自己的 KV
× num_heads # 多头 attention
× seq_len # 历史序列长度
× head_dim # 每个 head 的维度
× dtype 字节数 # float16 = 2 字节

以 GPT-2(小模型)为例:
layers=12, heads=12, head_dim=64, float16
seq_len=1024, batch=1

KV Cache = 2 × 12 × 12 × 1024 × 64 × 2 字节
= 2 × 12 × 12 × 1024 × 64 × 2
= 37,748,736 字节
≈ 36 MB

真实 LLM(如 LLaMA-7B):
layers=32, heads=32, head_dim=128, float16
seq_len=4096, batch=1

KV Cache = 2 × 32 × 32 × 4096 × 128 × 2
≈ 2 GB!(仅 1 个请求)

KV 体量可按层数、头数、序列长度与 dtype 估算,例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def estimate_kv_cache_mb(
num_layers, num_heads, head_dim, seq_len, batch_size,
dtype_bytes=2 # float16 = 2 bytes
):
"""估算 KV Cache 显存占用(MB)"""
size = (
2 # K and V
* num_layers
* num_heads
* seq_len
* head_dim
* batch_size
* dtype_bytes
)
return size / 1024**2

# GPT-2
print(f"GPT-2 (seq=1024, batch=1): {estimate_kv_cache_mb(12, 12, 64, 1024, 1):.1f} MB")
print(f"GPT-2 (seq=1024, batch=8): {estimate_kv_cache_mb(12, 12, 64, 1024, 8):.1f} MB")

# LLaMA 7B 估算
print(f"LLaMA-7B (seq=4096, batch=1): {estimate_kv_cache_mb(32, 32, 128, 4096, 1):.0f} MB")
print(f"LLaMA-7B (seq=4096, batch=8): {estimate_kv_cache_mb(32, 32, 128, 4096, 8):.0f} MB")
print(f"LLaMA-7B (seq=4096, batch=32): {estimate_kv_cache_mb(32, 32, 128, 4096, 32):.0f} MB")

输出:

1
2
3
4
5
GPT-2 (seq=1024, batch=1):        36.0 MB
GPT-2 (seq=1024, batch=8): 288.0 MB
LLaMA-7B (seq=4096, batch=1): 2048 MB ← 仅 1 个请求就占 2GB!
LLaMA-7B (seq=4096, batch=8): 16384 MB ← 8 个并发请求占 16GB!
LLaMA-7B (seq=4096, batch=32): 65536 MB ← 32 个并发请求占 64GB!远超显存

这个数字解释了一切

  • 为什么大模型并发数受限(显存被 KV Cache 吃满)
  • 为什么上下文越长越贵(KV Cache 线性增长)
  • 为什么系统要做 KV Cache 管理(不能无限堆叠)
  • 为什么量化(INT8/INT4 KV)很重要(可以把 Cache 压到原来 1/2 ~ 1/4)

7. KV Cache 带来的工程问题

KV Cache 省了计算,但引入了一系列工程问题:

问题 1:显存碎片化

如果每个请求的序列长度不同,KV Cache 的内存需求也不同。用固定块分配容易碎片化,导致实际能服务的请求比理论少。

解决方案Paged Attention(vLLM 的核心思路)——把 KV Cache 分成固定大小的 Page(类似操作系统的内存页),按需分配,支持非连续存储。

问题 2:多个请求共享 Prompt 的重复计算

如果很多用户都问同一个系统 prompt(比如同一个角色扮演的开场),每次 prefill 都要重算这部分 KV,浪费。

解决方案Prefix Caching——把相同 prefix 的 KV Cache 共享复用。

问题 3:Cache 回收与调度

一个请求完成后,它的 KV Cache 要释放;新请求来了要分配;显存不够时要决定拒绝还是等待。

这已经变成了类似操作系统内存管理的问题。


8. Batching:让 GPU 不浪费在"零散请求"上

为什么需要 Batching

GPU 是"喜欢大块连续计算"的硬件。如果每来一个请求就单独跑一次 forward,GPU 利用率可能只有个位数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
import time

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(1024, 1024).to(device).eval()

def measure_throughput(batch_size, n_runs=200):
x = torch.randn(batch_size, 1024, device=device)

with torch.inference_mode():
for _ in range(20): # warmup
model(x)

torch.cuda.synchronize()
t0 = time.time()

with torch.inference_mode():
for _ in range(n_runs):
model(x)

torch.cuda.synchronize()
elapsed = time.time() - t0

latency_ms = elapsed / n_runs * 1000
throughput = batch_size * n_runs / elapsed
return latency_ms, throughput

if torch.cuda.is_available():
print(f"{'batch':>6} {'latency':>10} {'throughput':>18}")
for bs in [1, 4, 16, 64, 256]:
lat, tput = measure_throughput(bs)
print(f"{bs:>6} {lat:>8.2f}ms {tput:>14.0f} samples/s")

典型输出(示意,实际值取决于 GPU 型号):

1
2
3
4
5
6
batch    latency         throughput
1 0.08ms 12,500 samples/s
4 0.09ms 44,400 samples/s
16 0.11ms 145,000 samples/s
64 0.25ms 256,000 samples/s
256 0.90ms 284,000 samples/s

关键观察

  • batch 从 1 → 64:延迟只增加了 3 倍,但吞吐增加了 20 倍
  • batch 从 64 → 256:延迟增加 3.6 倍,吞吐只增加 1.1 倍(收益递减,GPU 已吃满)

9. Static Batching 的局限

最简单的 batching 策略是 Static Batching:等凑够 N 个请求,再一起跑。

1
2
3
4
请求队列:[A, B, C, D, E, F, G, H](batch=4)

第 1 批:[A, B, C, D] → 一起 prefill + decode 直到全部完成
第 2 批:[E, F, G, H] → 再一起运行

问题:不同请求的生成长度不同。

1
2
3
4
5
6
7
8
9
      time →
批次 [A, B, C, D]:
A: ████████████░░░░░░ ← A 生成了 60 token 就完了,但在等 D
B: ████████░░░░░░░░░░ ← B 生成了 40 token 就完了,在等 D
C: ████████████████░░ ← C 生成了 80 token 就完了,在等 D
D: ████████████████████ ← D 生成了 100 token

A、B、C 完成后,GPU 有大量时间在空转等 D
这些 ░ 时间是纯浪费的 GPU 时间

10. Continuous Batching:动态管理 Batch

Continuous Batching(也叫 in-flight batching)的思路:

不把一批请求绑定到底,而是当某个请求完成时,立刻让新请求填进来。

1
2
3
4
5
6
7
     time →
槽位 1:[A A A A A A A ][E E E E E E E E E E E]
槽位 2:[B B B B ][F F F F F F F F F ]
槽位 3:[C C C C C C ][G G G G G G G ]
槽位 4:[D D D D D D D D D D D D][H H H H]

当 B 完成时(槽位 2),立刻把 F 插进来,而不是等 A/C/D 都完成

收益:GPU 利用率大幅提升,不再有大块的等待空洞。

这也是 vLLM、TGI(Text Generation Inference)等现代 LLM 推理框架的核心特性。


11. Padding 的浪费:为什么长度对齐有代价

在一个 batch 里,不同请求的序列长度不同,但 GPU 需要整齐的矩阵计算。通常的解法是 padding:把短序列补 0 到最长序列的长度。

1
2
3
4
5
6
7
8
9
10
11
batch 内的请求:
request A: [token1, token2, token3] 长度 3
request B: [token1, token2, token3, token4, token5] 长度 5
request C: [token1, token2] 长度 2

Padding 后(对齐到长度 5):
request A: [token1, token2, token3, PAD, PAD] ← 2 个 PAD
request B: [token1, token2, token3, token4, token5]
request C: [token1, token2, PAD, PAD, PAD] ← 3 个 PAD

这些 PAD 位置的计算是无效的,纯浪费

浪费比例 = (总 padding 数) / (总 token 数)

如果请求长度差异很大,浪费可能非常严重。

解决思路

  • 排序和分组:把相近长度的请求尽量分到同一批
  • Bucket batching:把请求按长度分桶,每桶内 padding 浪费最小
  • Flash Attention / Variable-Length Attention:某些算子实现可以支持 packed(无 padding)格式

12. 推理系统的请求生命周期

把所有概念串起来:一个请求的完整生命周期:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
用户发送请求

① 排队(Queue)
- 系统决定何时处理这个请求
- 可能等待凑批

② 预处理(Preprocessing)
- Tokenization
- CPU 侧处理
- 数据搬到 GPU(CPU→GPU,见第 04 篇)

③ Prefill
- 整段 prompt 一次性 forward
- 建立 KV Cache
- 计算量大,吞吐敏感

④ Decode 循环(重复直到 EOS)
- 每步生成 1 个 token
- 利用 KV Cache,只算新 token 的 attention
- 延迟敏感

⑤ 后处理(Postprocessing)
- Detokenization
- 结果返回给用户

⑥ Cache 回收
- 释放这个请求占用的 KV Cache
- 供下一个请求使用

用户感受到的 latency = 排队时间 + 预处理时间 + prefill 时间 + decode 时间 + 后处理时间

结论benchmark 得到的 model(x) 时间通常仅对应流水线中设备前向一段;端到端延迟还包含前后处理与调度等环节。


13. 用代码理解 Batch Size 对 Latency 和 Throughput 的影响

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import torch.nn as nn
import time

device = "cuda" if torch.cuda.is_available() else "cpu"

# 模拟一个简化的 LLM transformer block
class SimpleBlock(nn.Module):
def __init__(self, d_model=512, n_heads=8):
super().__init__()
self.attn_qkv = nn.Linear(d_model, 3 * d_model)
self.attn_out = nn.Linear(d_model, d_model)
self.ffn1 = nn.Linear(d_model, 4 * d_model)
self.ffn2 = nn.Linear(4 * d_model, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)

def forward(self, x):
# 简化 attention(不做真正的 masked attention,仅演示)
qkv = self.attn_qkv(x)
out = self.attn_out(qkv[:, :, :x.shape[-1]])
x = self.norm1(x + out)
ffn = self.ffn2(torch.relu(self.ffn1(x)))
return self.norm2(x + ffn)

model = SimpleBlock().to(device).eval()

# 模拟 decode 场景:batch=N,seq_len=1(每步只新增 1 token)
# 和 prefill 场景:batch=1,seq_len=L(一次性处理整段)

def measure(batch, seq_len, n_runs=100):
x = torch.randn(batch, seq_len, 512, device=device)
with torch.inference_mode():
for _ in range(10):
model(x)
torch.cuda.synchronize()
t0 = time.time()
with torch.inference_mode():
for _ in range(n_runs):
model(x)
torch.cuda.synchronize()
elapsed = (time.time() - t0) / n_runs * 1000
token_throughput = batch * seq_len * n_runs / (time.time() - t0 + elapsed * n_runs / 1000)
return elapsed, batch * seq_len * 1000 / elapsed # latency, tokens/s

if torch.cuda.is_available():
print("=== Decode 场景(seq_len=1,模拟每步只有 1 个新 token)===")
print(f"{'batch':>6} {'latency':>10} {'tokens/s':>12}")
for bs in [1, 4, 16, 64]:
lat, tput = measure(bs, seq_len=1)
print(f"{bs:>6} {lat:>8.2f}ms {tput:>10.0f}")

print("\n=== Prefill 场景(batch=1,seq_len=L,一次性处理 prompt)===")
print(f"{'seq_len':>8} {'latency':>10} {'tokens/s':>12}")
for L in [64, 256, 1024, 2048]:
lat, tput = measure(batch=1, seq_len=L)
print(f"{L:>8} {lat:>8.2f}ms {tput:>10.0f}")

14. 为什么真实岗位不是直接跑原生 PyTorch Eager

原生 eager 的优势是灵活好调试。但真实线上更关心:

需求 PyTorch Eager 专用推理后端
高并发 dynamic batching 需要手动实现 内置支持
KV Cache 管理 需要手动实现 优化内置
paged KV / prefix caching 不支持 vLLM 等支持
极致算子性能(如 FlashAttention) 可以调用 深度集成
多机多卡推理 需要额外工程 内置支持
跨语言服务接口 需要额外封装 内置 HTTP 接口

所以常见路径是:

1
2
3
4
5
6
开发 / 训练(PyTorch eager)

验证和性能分析(PyTorch profiler / benchmark)
↓ 根据目标选择
torch.compile() ONNX Runtime / TensorRT vLLM / TGI 等
(PyTorch 内提速) (传统 CNN 部署) (LLM 服务化推理)

工程上宜了解各后端在图格式、动态 shape 与调度上的差异,例如:

PyTorch 是模型的"源表达"和"验证入口",不一定是最终线上形态。学透 PyTorch 是理解所有后端的基础。


15. Tail Latency(尾延迟):为什么平均值不够看

服务化推理里,平均延迟不是全部。SLA 通常用 P99 描述:

“99% 的请求在 100ms 内完成” → P99 latency ≤ 100ms

如果有 1% 的请求因为排队太久、batch 太大、或者遇到长序列,延迟飙到 5 秒,即使平均 latency 很漂亮,这个系统也不合格。

推理系统里的 Tail Latency 来源

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 简单模拟等待时间对尾延迟的影响
import numpy as np

np.random.seed(42)
n_requests = 1000

# 模拟两种调度策略
# 策略 1:等凑满 batch=16 再执行(等待时间最多 50ms)
# 策略 2:来了就跑(batch=1,无等待)

compute_time = np.random.exponential(scale=20, size=n_requests) # 均值 20ms

# 策略 1
wait_time_1 = np.random.uniform(0, 50, n_requests)
total_1 = compute_time + wait_time_1

# 策略 2
total_2 = compute_time # 无等待,但吞吐低

print("策略 1(batch=16,等待最多 50ms):")
print(f" 平均: {total_1.mean():.1f}ms P95: {np.percentile(total_1, 95):.1f}ms P99: {np.percentile(total_1, 99):.1f}ms")

print("\n策略 2(batch=1,立即执行):")
print(f" 平均: {total_2.mean():.1f}ms P95: {np.percentile(total_2, 95):.1f}ms P99: {np.percentile(total_2, 99):.1f}ms")

输出(示意):

1
2
3
4
5
策略 1(batch=16,等待最多 50ms):
平均: 45.2ms P95: 67.3ms P99: 74.8ms

策略 2(batch=1,立即执行):
平均: 19.8ms P95: 57.2ms P99: 79.4ms

这个例子说明:吞吐和延迟的 trade-off 在不同百分位上表现不同,需要根据业务 SLA 选择策略。


16. 从张量语义到系统约束

以下将前文中的 shape、显存与调度概念与 LLM 服务中的典型瓶颈对应;数字仅为示意,实际占用随模型配置与实现变化。

KV Cache 与并发上限

KV Cache 为每个进行中的请求按层保存历史 Key/Value,体量可近似为 2 × num_layers × num_heads × seq_len × head_dim × dtype_bytes(再乘并发数)。在固定显存预算下,该线性项常先于算力成为并发瓶颈;PagedAttention 等方案通过分页与碎片管理提高显存利用率,属于在同一预算下提升有效并发的工程手段。

Decode 相对 Prefill 的优化难点

Decode 每步仅扩展一个 token,单步 GEMM 规模小,硬件利用率常低于 Prefill 的大块矩阵乘;多步循环叠加调度与内核启动开销。批内各请求生成长度不一易造成空转,因此常与连续批处理、针对小批 attention 的算子优化等配合使用。


17. 常见问题(技术要点)

Q:Prefill 和 Decode 的区别是什么?
A:Prefill 是把整段 prompt 一次性送进模型计算,计算量大(∝ L²),是大块矩阵乘,GPU 容易吃满,偏吞吐型。Decode 是每步只新增 1 个 token 的生成循环,单步计算量小(∝ 1×L),需要反复执行 output_length 次,偏延迟敏感,更依赖 KV Cache 和调度效率。

Q:KV Cache 的本质是什么?它的代价是什么?
A:把每一层历史 token 的 Key 和 Value 缓存在显存里,避免 decode 时重复计算历史 attention,本质是"用显存换计算"。代价是显存消耗与层数、头数、序列长度、并发数成正比,大模型下单请求 KV Cache 就能占几 GB,严重限制并发能力。

Q:KV Cache 大小怎么估算?
A:2 × num_layers × num_heads × seq_len × head_dim × dtype_bytes × batch_size。以 LLaMA-7B(32层、32头、128 head_dim)、4096 token、FP16 为例:2 × 32 × 32 × 4096 × 128 × 2 ≈ 2GB,8 个并发就是 16GB。

Q:Static batching 和 Continuous batching 的区别?
A:Static batching 把一批请求绑定在一起从头跑到尾,Decode 步数较少的请求完成后 GPU 空转等较慢的请求。Continuous batching 允许请求完成后立刻从队列里取新请求填充,GPU 利用率更高,是现代 LLM 推理框架的标配。

Q:为什么 GPU batch size 越大不总是越好?
A:增大 batch 能提高吞吐和 GPU 利用率,但会增加等待凑批时间(延迟↑)、padding 浪费(长度不同的请求混批)、KV Cache 显存压力(并发数↑),在延迟 SLA 严格的场景里反而有害。

Q:为什么 PyTorch 基础对推理系统工程很重要?
A:所有推理后端的模型都来自 PyTorch,行为对齐、性能复现、问题定位通常都从 PyTorch start。KV Cache 的存储布局、dtype 选择、batch 的 Tensor 组织、profiling 发现瓶颈——这些都要落回 Tensor 操作层面。学透 PyTorch 是理解所有后端的基础。


18. 思考题

练习 1:动手算 KV Cache 显存

用本节的 estimate_kv_cache_mb() 函数,计算以下场景的 KV Cache 大小:

  • GPT-2 medium(24层、16头、head_dim=64,seq=2048,batch=16)
  • LLaMA-7B(32层、32头、head_dim=128,seq=4096,batch=8)
  • LLaMA-13B(40层、40头、head_dim=128,seq=4096,batch=4)

思考:假设 GPU 有 40GB 显存,模型自身占去一半,每个场景最多能同时服务多少个请求?

练习 2:Batch Size vs Latency/Throughput 实验

运行第 13 节的 measure_throughput 函数,测出不同 batch size 的 latency 和 throughput,画出趋势。

思考:在哪个 batch 区间吞吐增长明显放缓?与 GPU 算力/带宽规格如何对应?

练习 3:模拟 Prefill 的 O(L²) 特性

运行第 4 节的 prefill_sim 函数,验证 attention 计算时间随序列长度 L 的增长趋势。

思考:如果 prefill 时间确实是 ∝ L²,那么 L 从 512 增加到 2048(4 倍),时间应该增加几倍?实测结果符合吗?

练习 4:写一个简单的请求调度模拟

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import time
import random
import threading
from queue import Queue

# 模拟一个极简的 static batch 调度器
def static_batch_scheduler(request_queue, batch_size=4, model_latency_ms=50):
"""
batch_size:等凑满多少个请求才发出
model_latency_ms:假设每 batch 计算耗时固定
"""
results = []
batch = []

while True:
req = request_queue.get()
if req is None:
break
batch.append(req)

if len(batch) >= batch_size:
time.sleep(model_latency_ms / 1000) # 模拟计算
for r in batch:
end_time = time.time()
results.append(end_time - r["arrive_time"])
batch = []

return results

# 模拟请求到来
q = Queue()
for i in range(20):
time.sleep(random.uniform(0.005, 0.02)) # 随机到达
q.put({"id": i, "arrive_time": time.time()})
q.put(None) # 结束信号

改造:实现一个"超时就发"的版本——如果等了 20ms 还没凑满 batch,就把当前的送出去。观察对 tail latency 的影响。

练习 5:把 KV Cache 概念和 Tensor 操作对应起来

用纯 PyTorch 写一个极简的 KV Cache append 操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch

# 模拟 KV Cache 的追加(decode 每步的操作)
class SimpleKVCache:
def __init__(self, max_seq_len, d_model):
self.k_cache = torch.zeros(1, max_seq_len, d_model)
self.v_cache = torch.zeros(1, max_seq_len, d_model)
self.cur_len = 0

def append(self, new_k, new_v):
"""追加新 token 的 K/V"""
# new_k shape: (1, 1, d_model) - 当前新 token 的 K
pos = self.cur_len
self.k_cache[:, pos:pos+1, :] = new_k
self.v_cache[:, pos:pos+1, :] = new_v
self.cur_len += 1

def get(self):
"""获取所有历史 K/V"""
return self.k_cache[:, :self.cur_len, :], self.v_cache[:, :self.cur_len, :]

# 测试
cache = SimpleKVCache(max_seq_len=512, d_model=64)
for step in range(5):
new_k = torch.randn(1, 1, 64)
new_v = torch.randn(1, 1, 64)
cache.append(new_k, new_v)
k, v = cache.get()
print(f"Step {step+1}: KV Cache shape = {k.shape}") # (1, step+1, 64)

思考:真实系统里,KV Cache 会跨层(num_layers)、跨头(num_heads),shape 应该是什么?每步 append 后 shape 如何变化?


19. 本节要点与自检

  • 清晰解释 Prefill 和 Decode 的计算差异(计算量、执行次数、优化重点各不同)
  • 能估算 KV Cache 的显存大小,并解释它如何限制并发能力
  • 理解为什么 decode 比 prefill 难优化(矩阵太小,GPU 难吃满)
  • 知道 Static Batching 的空转问题和 Continuous Batching 的解决思路
  • 理解 Padding 浪费的来源,知道为什么变长序列 batching 更复杂
  • 能把 PyTorch 的 Tensor/shape/device/dtype 知识联系到推理系统问题上
  • 能用显存与 shape 估算将 KV Cache、并发与批调度串成可核对表述

20. 小结

系统设计讨论通常要求将单请求路径上的 shape、KV 显存、批调度与 Prefill/Decode 差异一并纳入,而非仅描述单次 model(x) 调用。


系列导航