mini-infer系统实战-11-Chunked Prefill:长 prompt 为什么会饿死并发请求

本文是 mini-infer 系列第 9 篇。前 8 篇依次实现了 Paged KV Cache、True PagedAttention(flash_attn)、Preemption 调度和 OpenAI 兼容 HTTP API。这一篇关注一个在服务化场景中容易被忽视的延迟问题:当一个长 prompt 到达时,正在 decode 的其他请求会被完全卡住多久?

问题:prefill 是个"大礼包"

Phase 8 的引擎在处理一个新请求时,会把整个 prompt 一次性做完 prefill,然后进入 decode 循环。对于短 prompt(32 tokens)来说,一次 prefill 耗时约 10ms,几乎感知不到。但对于 1024 tokens 的长 prompt,单次 prefill 耗时约 138ms。

这 138ms 的问题在于:prefill 期间,所有正在 decode 的请求都停摆了。它们拿不到 GPU 时间,无法生成新 token。用户感受到的就是流式输出突然停了 138ms,然后继续。

用数字说:7 个短请求正在流式 decode,每步约 18ms 出一个 token。此时一个 1024-token 长请求到达。在无分块情况下,这 7 个请求的下一个 token 会延迟整整 138ms 才出现——这是正常步长的 7.6 倍。这就是 ITL spike(token 间隔峰值)。

方案:把大块切成小块

Chunked Prefill 的思路很直接:不一次把 1024 个 token 全做完,而是每步只做 chunk_size 个 token,剩下的留到下一步。这样每步的 prefill 代价从 138ms 降到 138ms × (chunk_size / 1024),decode 请求在两次 chunk 之间有机会插入运行。

直觉上等价于:原来是"我先用 138ms 把这条路全堵死,你们等着",现在变成"我每步多占一点时间,但你们每步都能运行"。

这里有一个显而易见的 trade-off:长请求的 TTFT 必然增加。原来 138ms 出第一个 token,现在分 8 步(1024/128)完成 prefill,每步还要和 decode 竞争,TTFT 可能增至 400ms。这不是 bug,是设计内的权衡。

实现:在调度器里加一个 PREFILLING 状态

新增状态机节点

原来请求的生命周期是:WAITING → RUNNING → FINISHED。
Phase 9 在中间插入了一个新状态:

1
WAITING → PREFILLING → RUNNING → FINISHED

PREFILLING 是一个独立队列(_prefilling: dict[str, RequestState]),一次最多 1 个请求在里面。每步 engine 推进它一个 chunk,最后一个 chunk 完成后立即移入 RUNNING,当步参与 decode_batch。

1
2
3
4
5
6
7
8
9
10
11
12
# scheduler.py — 新增 7 个方法,核心是这三个
def add_to_prefilling(self, state: RequestState) -> None:
self._prefilling[state.request.request_id] = state

def move_prefilling_to_running(self, state: RequestState) -> None:
self._prefilling.pop(state.request.request_id, None)
self._running[state.request.request_id] = state

def get_next_prefilling(self) -> RequestState | None:
if self._prefilling:
return next(iter(self._prefilling.values()))
return None

DynamicCache 的跨 chunk 累积

每个 chunk 只处理 prompt_token_ids[t_start:t_end],但 attention 需要看到所有前缀的 KV。解决方案是利用 HuggingFace 的 DynamicCache

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# model_runner.py — prefill_chunk()
def prefill_chunk(self, state, token_start, token_end, past_cache, is_last_chunk):
state.prefilled_tokens = token_end # 更新进度

chunk_ids = state.prompt_token_ids[token_start:token_end]
input_ids = torch.tensor([chunk_ids], device=self.config.device)

if past_cache is None:
past_cache = DynamicCache()

with torch.no_grad():
out = self.model(input_ids=input_ids, past_key_values=past_cache, use_cache=True)

if is_last_chunk:
# 所有 chunk 的 KV 都在 out.past_key_values 里了
self.kv_cache.write_prefill_kv(state.request.request_id, out.past_key_values)
next_token_id = _sample_token(out.logits[0, -1], state.request.sampling_params)
state.append_generated(next_token_id, "")
state.prefilled = True
return None
else:
return out.past_key_values # 中间状态,下次 chunk 继续用

关键设计:非最后 chunk 返回积累的 DynamicCache(存在 CPU 内存 _prefilling_caches 里),最后一个 chunk 调用 write_prefill_kv 把完整 KV 写入 GPU block tensor,然后请求进入 decode 路径。这样整个 chunking 过程对 block tensor 和 decode 路径完全透明。

engine 主循环的改动

chunk_prefill_size > 0 时,engine 的每步逻辑变成:

1
2
3
4
5
6
1. 若 PREFILLING 空 + WAITING 有请求 + batch 未满 → 准入一个到 PREFILLING
2. 推进 PREFILLING 一个 chunk
- 非最后 chunk:保存中间 DynamicCache,继续下一步
- 最后 chunk:写 KV、移入 RUNNING
3. 对所有 RUNNING 请求做 decode_batch(这步包含刚完成 prefill 的请求)
4. 清理完成请求,尝试换入 swapped 请求

步骤 3 是关键:decode_batch 和 chunked prefill 在同一步内完成,短请求不会跳过任何一步。

一个不起眼的 bug:just_prefilled 变量

实现中有个需要小心的细节。decode 完成后,需要把本步新生成的 token 返回给调用方。用的是"增量 decode"逻辑:new_text = decode(all_ids) - decode(all_ids[:pre_len])

对于刚完成 prefill 进入 RUNNING 的请求,它的 generated_token_ids 里已经有一个 prefill 采样的 token(来自 prefill_chunk 最后一步),然后 decode_batch 又生成了第二个 token。如果 pre_len = 1,只会返回第二个 token;实际上第一个 token(prefill 产物)也应该在这一步返回。

所以需要对刚完成 prefill 的请求,把 pre_len 强制置 0:

1
2
3
4
5
6
7
8
9
10
11
12
# 原始路径(chunk_prefill_size=0)用 newly_admitted
# chunked 路径需要同样的机制
just_prefilled: list[RequestState] = []

# ... chunked 分支:
if is_last:
self.scheduler.move_prefilling_to_running(pf_state)
just_prefilled.append(pf_state)

# ... 收集 token 前:
for s in just_prefilled:
pre_lens[s.request.request_id] = 0 # 从 0 开始,捕获 prefill token

这个变量统一了 chunked 和非 chunked 两条路径,避免了重复逻辑。

坑:benchmark 连写三遍才对

这是这次实现里花时间最多的部分,记录一下三次返工的原因,希望对做同类工作的人有参考价值。

第一版:场景设计反了

最初设计:7 个短请求在 WAITING,同时提交 1 个长请求,看谁先完成 prefill。结果是:chunk=0 时长请求独占一步,短请求 TTFT 略快;chunk=128 时长请求分 8 步才完成 prefill,短请求反而等更久——效果是负的。

问题在于:Chunked Prefill 解决的不是"哪个请求先完成 prefill",而是"已经在 decode 的请求不被新到的长请求卡住"。

正确场景:先让 7 个短请求完成 prefill 进入 decode(warmup 阶段),再提交长请求,然后测量短请求的 ITL spike

第二版:时间戳记录在 step() 之前

修正场景之后,测到所有 TTFT ≈ 0ms。查了一下,发现记录时间戳的代码在 engine.step() 调用之前:

1
2
3
4
5
6
7
# 错误:
t_now = time.perf_counter() # ← 在 step 前!
new_tokens = engine.step()

# 正确:
new_tokens = engine.step()
t_now = time.perf_counter() # ← 在 step 后

第三版:max ITL 的跨阶段间隔被漏掉

即使时间戳对了,最大 ITL 还是不准。原因:max_itl_spike 只统计了 phase 3(长请求到达后)的 token 间隔,但真正最大的间隔是"warmup 最后一个 token 到 phase 3 第一个 token"——这个间隔横跨了整个 prefill 过程,正是 chunked prefill 要压缩的目标。

修复:全程记录所有短请求的 token 时间戳(warmup 阶段也记),然后在完整时间序列上找包含 t_long_start 跨越间隔的最大值。

实验结果

环境:RTX 4090,Qwen2.5-7B-Instruct float16,7 个 ~32-token 短请求 decode 中 + 1 个 ~1024-token 长请求到达。

chunk_size 短请求 ITL spike 变化 长请求 TTFT 总吞吐
0(基准) 138.4 ms 138.4 ms 290 tok/s
128 45.9 ms −66.8% 399.3 ms 297 tok/s
256 58.6 ms −57.5% 269.8 ms 301 tok/s

两组均无吞吐回归。chunk=256 是相对均衡的选择:ITL spike 降低 57%,TTFT 增加约 1 倍(138→270ms),适合大多数场景。chunk=128 在 ITL 平稳性上更激进,但长请求 TTFT 增至近 3 倍。

为什么吞吐没有下降? 每步 decode_batch 的代价不变(batch 大小不变),chunked prefill 只是把 prefill 成本分摊到多步,总计算量不增加。吞吐轻微上升(+2-4%)属于测量波动。

设计取舍小结

取舍点 选择 理由
同时允许几个 PREFILLING 请求 仅 1 个 简化调度逻辑;多个同时 prefill 需要多个中间 DynamicCache,显存压力增加
中间 DynamicCache 存 GPU 还是 CPU CPU 内存 中间状态不需要立即用于 GPU 计算;避免占用额外 GPU 显存
chunk 完成后立即 decode 还是下一步 立即(当步参与 decode_batch) 减少延迟,prefill 和 decode 在同一步执行
PREFILLING 请求能否被抢占 中间 DynamicCache 重置代价高,简化实现;当前 Phase 不支持

和 vLLM 的对比视角

vLLM 的 chunked prefill(max_num_batched_tokens 参数)允许同时有多个请求处于 prefilling 状态,并把 prefill token 和 decode token 打包进同一个 batch forward,共享 GPU 计算。mini-infer 的实现更简单:每步只有一个 PREFILLING 请求,prefill 和 decode 是分开的两次 forward(一次 prefill_chunk,一次 decode_batch)。

这意味着 mini-infer 版本每步有两次模型 forward,而 vLLM 可以合并为一次。吞吐上有差距,但 ITL 控制效果是等价的——对于理解机制已经足够。

总结

Chunked Prefill 本质上是把一个不可分割的"大块 GPU 占用"拆成多个小块,让 decode 请求在间隙中运行。实现上最重要的是:

  1. 状态机扩展:PREFILLING 是独立于 WAITING 和 RUNNING 的第三种状态,调度器要正确处理这三个队列的转换
  2. DynamicCache 累积:每个 chunk 的 KV 通过 past_key_values 参数传递,最后一个 chunk 完成后一次性写入 block tensor
  3. Benchmark 口径:ITL spike 的测量需要全程连续记录时间戳,包含跨 warmup 和主循环阶段的最大间隔;仅统计长请求到达后的间隔会低估真实 spike

实测(Qwen2.5-7B,1024-token 长 prompt):chunk=256 时短请求 ITL spike 从 138ms 降至 59ms(−57%),总吞吐无回归。


延伸阅读


系列导航