Triton:05 推理热点算子与 RMSNorm
这篇文章解决什么问题
写到这里,OpenAI Triton 这条线已经有了三个基础支点:
- 知道 kernel 是怎样按
program和tile展开的。 - 知道
memory-bound与compute-bound的差异。 - 知道 matmul 和
autotune为什么代表另一类优化问题。
接下来的自然问题是:在真实推理链路里,什么样的算子值得单独用 Triton 接管。
这不是一个只靠 FLOPs 就能回答的问题。很多初学者会默认把注意力全部放在 matmul 上,因为它的计算量最大。但在实际推理系统里,高频热点不只来自 FLOPs 最高的算子,还来自那些:
- 每层都会重复出现;
- 访存占比高;
- 位于关键前向路径上;
- 可以通过局部重写显著减少中间开销的算子。
这篇文章的目标就是建立这套判断框架,并用 RMSNorm 作为一个更贴近真实 LLM 推理路径的例子。
哪些算子值得重点关注
如果只从“总计算量”看,matmul 的确通常是最重的算子之一。但“值得单独优化”并不等价于“FLOPs 最大”。
从推理工程视角看,更有用的判断方式通常是这三个维度:
- 是否高频出现。
- 是否对访存敏感。
- 是否位于关键路径。
把这三条合在一起,一个更实用的判断可以写成:
值得单独接管的算子,通常是高频出现、带宽敏感、又位于关键前向路径上的那一类。
沿着这条逻辑去看推理链路,常见热点大致可以分成下面几类。
matmul / GEMM
它仍然是推理链路中的核心算子,尤其出现在:
- 线性层;
- QKV 投影;
- attention 中的矩阵乘法;
- 前馈网络中的大矩阵计算。
这类算子的特点通常是算术强度高,更容易接近 compute-bound。它们非常重要,但优化手段往往也更复杂,通常依赖成熟库、硬件特性或更深入的 tile 设计。
softmax
softmax 通常出现在 attention score 归一化或 logits 归一化路径中。它的典型特征是:
- 有 reduction;
- 中间值需要数值稳定处理;
- 更偏
memory-bound; - 非常适合讨论
fusion。
norm 类算子
这里主要指 LayerNorm、RMSNorm 等归一化算子。它们的共同特点是:
- 几乎每层都会出现;
- 单次计算不复杂;
- 更敏感于访存与中间结果处理;
- 很适合通过局部 kernel 重写降低开销。
纯 elementwise 算子
例如:
- 激活函数;
- 残差加法;
- 一些逐元素缩放或偏移。
这类算子单独看往往不重,但如果紧挨着其他算子出现,就常常适合与相邻步骤一起融合。
attention 路径本身
attention 真正的热点并不只是其中某一个单点,而是一整条依赖链:
Q @ K^T- mask
- softmax
scores @ V
如果这条链被拆成很多个独立 kernel,中间张量就会频繁进出全局内存。也正因为如此,attention 更适合被看成“路径级优化问题”,而不只是单个算子问题。
一个更稳妥的判断框架
把前面的经验压缩一下,可以得到一套更适合实际工程的筛选方法。
第一,看频率
一个算子如果只在局部偶尔出现,即使单次不慢,也未必值得优先接管。相反,如果它在每一层、每个 token 或每次前向中都会反复执行,那么累计成本就会迅速放大。
第二,看数据流
如果一个算子本身算术操作不重,但需要频繁读写大块数据,或者依赖中间结果反复落回显存,那么它往往更值得从访存角度分析,而不是只看数学表达式。
第三,看路径位置
处在关键路径上的算子更值得优先优化。关键路径的含义不是“数学上看起来重要”,而是“它会直接影响主前向过程的整体时延与吞吐”。
沿着这三个维度去看,RMSNorm 是一个非常合适的例子。
为什么是 RMSNorm
RMSNorm 之所以值得单独拿出来讲,不是因为它比 matmul 更重,而是因为它同时命中了上面的三条判断标准。
高频
在主流 LLM 中,RMSNorm 通常会在每个 Transformer block 中出现 1 到 2 次。层数一旦变多,它的累计调用次数会非常可观。
带宽敏感
RMSNorm 并不依赖复杂的高算术强度操作。它主要做的是:
- 读取一整行 hidden state;
- 计算平方和;
- 求均方根;
- 做归一化;
- 再乘以权重。
这类流程更像一个带 reduction 的内存密集型算子,而不是一个以高 FLOPs 为主的计算密集型算子。
位于关键路径
RMSNorm 不在旁支上,它就位于主前向路径里。也就是说,哪怕它单次开销没有 matmul 那么显眼,只要出现频率高、位置关键,整体影响就不能忽略。
RMSNorm 在做什么
RMSNorm 的核心公式是:
1 | y_i = x_i / sqrt(mean(x^2) + eps) * w_i |
和 LayerNorm 相比,它省去了均值中心化,只保留基于均方根的归一化。这种结构在多种 LLM 中被广泛采用。
从 kernel 视角看,RMSNorm 的关键点有两个:
- 它需要对整行做 reduction。
- reduction 之后还要立刻继续做逐元素变换。
这意味着它非常适合复用 softmax 里已经建立过的那套思路:尽量把整行读取、reduction 和最终输出写回放在同一个 kernel 中完成。
一个教学型 RMSNorm kernel
下面这份代码不追求覆盖生产级优化细节,它的目标是把结构讲清楚。
1 | import torch |
这段代码最重要的地方,不是数学公式本身,而是它沿用了前面两篇已经出现过的结构模式:
- 一个
program负责一整行。 - 整行数据一次读入。
- reduction 在局部上下文中完成。
- 最终只写回一次输出。
对 norm 类算子来说,这种“整行装入、局部 reduction、单次写回”的结构几乎就是第一层优化直觉。
这里为什么要把整行交给一个 program
原因和 softmax 很接近。RMSNorm 的 reduction 目标是整行 hidden dimension。如果把同一行拆给多个 program,就会立即引入跨 program 归约问题,结构复杂度会明显上升。
在教学型实现里,让一个 program 处理一整行,虽然会把 BLOCK_SIZE 与 n_cols 绑定,但好处是:
- 模型清晰;
- mask 边界处理直接;
- reduction 逻辑集中;
- 更容易和 softmax 做类比。
这也是为什么很多入门阶段的 row-wise 算子都适合从这种设计开始理解。
other=0.0 和 softmax 中的 -inf 为什么不同
这里有一个很容易忽略但很重要的细节。
在 RMSNorm 里,被 mask 掉的无效位置会被填成 0.0。这样做是正确的,因为:
- 它们不会影响平方和;
- 它们不会给均方根贡献额外值;
- 最终也不会被写回输出。
而在 softmax 里,填充值通常是 -inf,因为 softmax 更关心 max 和 exp 的数值行为。
两者看起来都在做“无效位置填充”,但填充值的选择并不是固定模板,而是要服务于该算子的数学结构。
从 RMSNorm 回头看前几篇
到这里可以把前面几篇 OpenAI Triton 文章串起来看。
Part 02 提供的是执行模型
也就是:
programgridtilemask
没有这一层,后面看到 row-wise norm kernel 只会觉得代码“很密”,但看不出结构。
Part 03 提供的是性能判断
也就是:
- 这个算子更偏
memory-bound - 优化重点通常不是增加算术,而是减少全局读写
- benchmark 需要结合数据流来解释
没有这一层,很容易把 RMSNorm 误判成“不复杂,所以不重要”的算子。
Part 04 提供的是另一种性能范式
matmul 代表的是高算术强度、强 tile 复用和 autotune。而 RMSNorm 则恰好提供了一个对照:不是所有热点都长得像 matmul。
这也是 OpenAI Triton 线索里很关键的一点:不同热点算子,优化逻辑并不相同。
attention 为什么更像一条链路
讲完 RMSNorm 之后,可以顺着这个视角再看 attention。
attention 的关键问题往往不在某一个局部点,而在整条路径的中间结果如何流动:
Q @ K^T- mask
- softmax
scores @ V
如果这条链被拆成多个独立 kernel,中间张量就会频繁写回显存、再从显存读出。也正因为如此,像 FlashAttention 这类实现的价值,不是“单独把某个 softmax 写得更快”,而是尽量在更长的路径上减少中间内存往返。
所以,理解 RMSNorm 这样的行级热点算子,其实也是在为理解更复杂的路径级优化做准备。
什么样的算子暂时不需要第一时间手写 Triton
为了避免过度优化,也需要反过来说清楚:并不是推理链路里的每个算子都值得立刻手写 Triton。
优先级较低的情况通常包括:
- 出现频率不高;
- 输入规模较小,累计收益有限;
- 已经被成熟后端很好覆盖;
- 优化之后很难减少关键数据流;
- 不处于主前向时延的关键路径。
这个判断很重要。否则很容易陷入“能写 Triton kernel 的都想接管”的误区。
常见误区
误区一:只盯着 FLOPs 最高的算子
FLOPs 很重要,但不是唯一标准。像 RMSNorm 这种高频、带宽敏感、位于关键路径的算子,即使单次计算量不大,也可能是值得优化的对象。
误区二:把 norm 类算子当成“没什么可讲的预处理”
在真实模型里,norm 类算子往往处在高频路径上。只因为它公式比 matmul 短,就忽略它的累计成本,这是很常见的判断失误。
误区三:把 attention 只看成一个 softmax
attention 的瓶颈经常来自整条路径的数据流,而不是单个 softmax 节点本身。只盯着一个节点,通常看不见中间张量读写带来的真正代价。
结论
从推理工程视角看,值得单独用 OpenAI Triton 接管的算子,通常不是简单按 FLOPs 排序出来的,而是要同时看:
- 高频程度;
- 访存敏感性;
- 所处路径位置。
RMSNorm 是一个典型例子。它不是推理链路里最显眼的算子,但它高频、带宽敏感,又处在关键前向路径上,因此非常适合用来建立“热点算子筛选”这层认知。
到这里,OpenAI Triton 这条主线已经形成了一个完整闭环:
- 从执行模型入门;
- 到性能判断;
- 到 matmul 与
autotune; - 再到推理路径中的热点算子筛选。
下一步就可以切到 NVIDIA Triton Inference Server 这条线,开始讨论怎样把模型组织成可调用、可调优的在线服务。
