mini-infer系统实战-03-向量化 KV Gather:为什么 batch 吞吐能从 49% 拉到 88%
mini-infer系统实战-03-向量化 KV Gather:为什么 batch 吞吐能从 49% 拉到 88%
这篇文章是 mini-infer 项目的第三篇技术复盘。Phase 2 跑通了 Paged KV Cache + Batch Decode,但 batch=8 throughput 只有 HF baseline 的 49.1%。Phase 3 的目标是找到根本瓶颈并消除它。实测结果:batch=8 throughput 从 201 tok/s 涨到 361 tok/s,达到 HF baseline 的 88.4%。
一、Phase 2 的残余问题
Phase 2 的 gather_batch_kv() 是这样的:
1 | # Phase 2:嵌套 Python 循环 |
三层循环:num_layers × batch_size × seq_len。以 Qwen2.5-7B(28 层)、batch=8、seq_len=128 为例,这是 28 × 8 × 128 = 28672 次 Python 迭代,每次都是一次小的 GPU tensor 索引操作。
每次小索引就是一次独立的 CUDA kernel 调用,或者至少是一次 Python 层面的同步点。 Python 解释器的开销不在计算上,在调度上——每次 .item()、每次标量索引,CPU 和 GPU 之间都要协调一次。
这解释了 Phase 2 的数据:batch=1 时没有 gather(单请求无需聚合),throughput 88%;batch=8 时 28672 次 Python 循环,throughput 跌到 49%。瓶颈不在 GPU,在 CPU 的 Python 解释器。
二、向量化方案
核心思路:把 3 层嵌套循环变成一次 PyTorch advanced indexing。
2.1 把 BlockTable 变成 Tensor
1 | # [batch, max_num_blocks] |
这里有一个 Python 循环,但它只遍历 batch(8 次),而不是 seq_len(128 次)。后续不再需要访问 self._block_tables 字典。
2.2 计算每个输出位置的物理地址
每个请求长度不同,gather 后需要左填充对齐到 max_seq_len。左填充的含义是:对于 seq_len=3、max_seq_len=6 的请求,output position 0-2 是填充(零),position 3-5 才是真实 token。
1 | seq_lens_t = torch.tensor(seq_lens, dtype=torch.long, device=device).unsqueeze(1) # [batch, 1] |
然后推算物理块号和块内 slot:
1 | block_indices = token_pos_clamped // self.block_size # [batch, max_seq_len] |
2.3 一次 advanced indexing 完成 gather
1 | valid_mask_f = valid_mask.unsqueeze(-1).unsqueeze(-1).to(dtype=cache_dtype) # [batch, max_seq_len, 1, 1] |
仍然有一个 for l in range(num_layers) 循环,但这层只有 28 次(固定的层数),不随 batch 或 seq_len 增长。每层内部是一次 k_cache[l][phys_blocks, slot_indices],这是一次 PyTorch advanced indexing,最终对应一次 CUDA kernel。
相比 Phase 2 的 28 × 8 × 128 = 28672 次 Python 循环 → Phase 3 的 28 次 Python 循环,量级差了 1000 倍。
2.4 填充区为什么不会出错
token_pos_clamped 对填充位 clamp 到 0,这意味着填充位会去 gather block 0 的 slot 0——也就是某个请求的第一个 token 的 KV(或者是未初始化的零值块)。这个值本身是错的,但 valid_mask_f = 0 会把它乘零,最终输出为零。所以填充位的 gather 目标是什么并不重要,只要乘以 mask 就能正确置零。
这个细节在测试里验证过:
1 | # req_b seq_len=3, max_seq_len=6:前 3 位应为零 |
三、DynamicCache:警告出在哪里
Phase 2 还有一个问题:每次运行都会打印:
1 | UserWarning: `past_key_values` must be a `DynamicCache` instance, using a tuple is deprecated... |
自然的判断是在 decode_batch() 里——那里把 gathered KV 构造成 tuple 传给模型。于是先修了 decode_batch:
1 | # 改造后的 decode_batch |
跑起来,警告还在。
加上 -W error::UserWarning 让警告变成异常并打印 traceback:
1 | File ".../model_runner.py", line 121, in prefill |
警告出在 prefill(),不是 decode_batch()。
根因:transformers 4.43.4 里,当 model() 收到 past_key_values=None(即不传这个参数),内部会走一条旧路径,把 past_key_values 构造成 tuple 然后再发出弃用警告——不是调用方传了 tuple,是模型内部生成了 tuple 并警告自己。
修复:显式传一个空的 DynamicCache:
1 | out = self.model(input_ids=input_ids, past_key_values=DynamicCache(), use_cache=True) |
这样模型知道调用方期望 DynamicCache 格式,走新路径,不触发警告。两个字的修复,但不读源码找不到根因。
四、实测数据
环境:Ubuntu 24.04 + RTX 4090,Qwen2.5-7B-Instruct,float16,transformers 4.43.4,flash-attn 2.3.6。
标准 benchmark(max_new_tokens=128)
| batch | Phase 2 | Phase 3 | HF baseline |
|---|---|---|---|
| 1 | 49.4 tok/s | 53.7 tok/s | 56.2 tok/s |
| 4 | 135.2 tok/s | 194.2 tok/s | 210.5 tok/s |
| 8 | 201.0 tok/s | 361.3 tok/s | 408.9 tok/s |
Phase 3 / HF:batch=1 为 95.5%,batch=4 为 92.3%,batch=8 为 88.4%。
batch=1 改善有限(+8.7%),因为 batch=1 时没有 gather 压力,Phase 2 已接近最优。批次越大,向量化的收益越显著。
混合长度 benchmark
8 条请求:短 prompt(“一句话介绍 KV Cache”)× 4,长 prompt(“请详细解释……”)× 4,统一 max_new_tokens=256。
Throughput:356.0 tok/s,Peak Mem:17.00 GB。
注:当前 generate() 不支持 per-request max_new_tokens,短 prompt 不会因为答完提前释放 KV 块,实际以最大值 256 统一运行。混合 benchmark 的意义更多在于验证混合长度场景的稳定性,而不是完整体现 continuous batching 的调度优势。
五、剩余差距在哪里
Phase 3 与 HF baseline 仍有 ~10% 的差距(batch=8),主要来自两个地方:
1. gather 本身仍有每步 KV 复制
向量化消除了 Python 循环,但 gather 操作本身仍需把分散的 block 数据复制到连续 dense tensor,然后传给 HF 模型。HF 的 KV 是连续分配的,不需要这次复制。
彻底消除需要在 attention 计算内直接支持分散的 block 寻址——即 flash_attn 2.5+ 的 block_tables 参数(PagedAttention 的真正核心)。当前环境 flash_attn 2.3.6 不具备这个能力。升级到 2.5+ 是 Phase 4 的选项之一。
2. 预分配 block tensor pool
mini-infer 预分配 512 个 block(约 1.8 GB),HF 按需分配。这部分不影响 throughput,但使 Peak Mem 比 HF 高约 1 GB。
六、一些工程细节
valid_mask 的 dtype:valid_mask = token_positions >= 0 得到 bool tensor。直接 k_tokens * valid_mask_f 会触发 PyTorch 隐式类型转换。正确做法是 valid_mask.unsqueeze(-1).unsqueeze(-1).to(dtype=cache_dtype) 显式转换为 float16,避免隐式转换带来的精度和性能问题。
从 out.past_key_values 提取新 KV:DynamicCache forward 后,out.past_key_values.key_cache[l] 的 shape 是 [batch, num_kv_heads, max_seq_len+1, head_dim],最后一个位置([:, :, -1, :])才是本次新 token 的 KV。
及时释放大张量:decode_batch 里 gathered KV 是 [num_layers, batch, num_kv_heads, max_seq_len, head_dim] 量级的大 tensor。写回 block tensor 后立刻 del k_batch, v_batch, cache, out, k_new, v_new,避免在下一步采样时峰值显存叠加。
七、小结
Phase 3 做了两件事:
- 向量化 gather_batch_kv():3 层 Python 嵌套循环 → 28 次 Python + 每层 1 次 CUDA kernel,batch=8 throughput 从 Phase 2 的 49.1% HF 提升到 88.4% HF
- DynamicCache 迁移:消除 transformers 弃用警告,根因在 prefill 不在 decode,修复是传空实例
DynamicCache()
剩余约 12% 的差距来自 gather 本身的 KV 复制,解决路径是 flash_attn 2.5+ 或自写 Triton kernel。这留给 Phase 4。
从另一个角度看:Phase 2 的瓶颈实际上是 CPU Python 解释器,而不是 GPU 算力。一个看似"GPU 密集"的推理系统,完全可以被 Python 层面的循环拖垮。理解这一点,是做推理优化的基本前提。
