这篇文章解决什么问题

写到这里,OpenAI Triton 这条线已经有了三个基础支点:

  • 知道 kernel 是怎样按 programtile 展开的。
  • 知道 memory-boundcompute-bound 的差异。
  • 知道 matmul 和 autotune 为什么代表另一类优化问题。

接下来的自然问题是:在真实推理链路里,什么样的算子值得单独用 Triton 接管。

这不是一个只靠 FLOPs 就能回答的问题。很多初学者会默认把注意力全部放在 matmul 上,因为它的计算量最大。但在实际推理系统里,高频热点不只来自 FLOPs 最高的算子,还来自那些:

  • 每层都会重复出现;
  • 访存占比高;
  • 位于关键前向路径上;
  • 可以通过局部重写显著减少中间开销的算子。

这篇文章的目标就是建立这套判断框架,并用 RMSNorm 作为一个更贴近真实 LLM 推理路径的例子。

哪些算子值得重点关注

如果只从“总计算量”看,matmul 的确通常是最重的算子之一。但“值得单独优化”并不等价于“FLOPs 最大”。

从推理工程视角看,更有用的判断方式通常是这三个维度:

  • 是否高频出现。
  • 是否对访存敏感。
  • 是否位于关键路径。

把这三条合在一起,一个更实用的判断可以写成:

值得单独接管的算子,通常是高频出现、带宽敏感、又位于关键前向路径上的那一类。

沿着这条逻辑去看推理链路,常见热点大致可以分成下面几类。

matmul / GEMM

它仍然是推理链路中的核心算子,尤其出现在:

  • 线性层;
  • QKV 投影;
  • attention 中的矩阵乘法;
  • 前馈网络中的大矩阵计算。

这类算子的特点通常是算术强度高,更容易接近 compute-bound。它们非常重要,但优化手段往往也更复杂,通常依赖成熟库、硬件特性或更深入的 tile 设计。

softmax

softmax 通常出现在 attention score 归一化或 logits 归一化路径中。它的典型特征是:

  • 有 reduction;
  • 中间值需要数值稳定处理;
  • 更偏 memory-bound
  • 非常适合讨论 fusion

norm 类算子

这里主要指 LayerNormRMSNorm 等归一化算子。它们的共同特点是:

  • 几乎每层都会出现;
  • 单次计算不复杂;
  • 更敏感于访存与中间结果处理;
  • 很适合通过局部 kernel 重写降低开销。

纯 elementwise 算子

例如:

  • 激活函数;
  • 残差加法;
  • 一些逐元素缩放或偏移。

这类算子单独看往往不重,但如果紧挨着其他算子出现,就常常适合与相邻步骤一起融合。

attention 路径本身

attention 真正的热点并不只是其中某一个单点,而是一整条依赖链:

  • Q @ K^T
  • mask
  • softmax
  • scores @ V

如果这条链被拆成很多个独立 kernel,中间张量就会频繁进出全局内存。也正因为如此,attention 更适合被看成“路径级优化问题”,而不只是单个算子问题。

一个更稳妥的判断框架

把前面的经验压缩一下,可以得到一套更适合实际工程的筛选方法。

第一,看频率

一个算子如果只在局部偶尔出现,即使单次不慢,也未必值得优先接管。相反,如果它在每一层、每个 token 或每次前向中都会反复执行,那么累计成本就会迅速放大。

第二,看数据流

如果一个算子本身算术操作不重,但需要频繁读写大块数据,或者依赖中间结果反复落回显存,那么它往往更值得从访存角度分析,而不是只看数学表达式。

第三,看路径位置

处在关键路径上的算子更值得优先优化。关键路径的含义不是“数学上看起来重要”,而是“它会直接影响主前向过程的整体时延与吞吐”。

沿着这三个维度去看,RMSNorm 是一个非常合适的例子。

为什么是 RMSNorm

RMSNorm 之所以值得单独拿出来讲,不是因为它比 matmul 更重,而是因为它同时命中了上面的三条判断标准。

高频

在主流 LLM 中,RMSNorm 通常会在每个 Transformer block 中出现 1 到 2 次。层数一旦变多,它的累计调用次数会非常可观。

带宽敏感

RMSNorm 并不依赖复杂的高算术强度操作。它主要做的是:

  • 读取一整行 hidden state;
  • 计算平方和;
  • 求均方根;
  • 做归一化;
  • 再乘以权重。

这类流程更像一个带 reduction 的内存密集型算子,而不是一个以高 FLOPs 为主的计算密集型算子。

位于关键路径

RMSNorm 不在旁支上,它就位于主前向路径里。也就是说,哪怕它单次开销没有 matmul 那么显眼,只要出现频率高、位置关键,整体影响就不能忽略。

RMSNorm 在做什么

RMSNorm 的核心公式是:

1
y_i = x_i / sqrt(mean(x^2) + eps) * w_i

LayerNorm 相比,它省去了均值中心化,只保留基于均方根的归一化。这种结构在多种 LLM 中被广泛采用。

从 kernel 视角看,RMSNorm 的关键点有两个:

  • 它需要对整行做 reduction。
  • reduction 之后还要立刻继续做逐元素变换。

这意味着它非常适合复用 softmax 里已经建立过的那套思路:尽量把整行读取、reduction 和最终输出写回放在同一个 kernel 中完成。

一个教学型 RMSNorm kernel

下面这份代码不追求覆盖生产级优化细节,它的目标是把结构讲清楚。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import triton
import triton.language as tl


@triton.jit
def rmsnorm_kernel(
output_ptr,
input_ptr,
weight_ptr,
input_row_stride,
output_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
row_idx = tl.program_id(axis=0)

row_start = input_ptr + row_idx * input_row_stride
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < n_cols

x = tl.load(row_start + offsets, mask=mask, other=0.0)
x_sq_sum = tl.sum(x * x, axis=0)
rms = tl.sqrt(x_sq_sum / n_cols + eps)

w = tl.load(weight_ptr + offsets, mask=mask, other=1.0)
output = (x / rms) * w

out_start = output_ptr + row_idx * output_row_stride
tl.store(out_start + offsets, output, mask=mask)


def triton_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
n_rows, n_cols = x.shape
block_size = triton.next_power_of_2(n_cols)
y = torch.empty_like(x)

rmsnorm_kernel[(n_rows,)](
y,
x,
weight,
x.stride(0),
y.stride(0),
n_cols,
eps,
BLOCK_SIZE=block_size,
)
return y

这段代码最重要的地方,不是数学公式本身,而是它沿用了前面两篇已经出现过的结构模式:

  • 一个 program 负责一整行。
  • 整行数据一次读入。
  • reduction 在局部上下文中完成。
  • 最终只写回一次输出。

对 norm 类算子来说,这种“整行装入、局部 reduction、单次写回”的结构几乎就是第一层优化直觉。

这里为什么要把整行交给一个 program

原因和 softmax 很接近。RMSNorm 的 reduction 目标是整行 hidden dimension。如果把同一行拆给多个 program,就会立即引入跨 program 归约问题,结构复杂度会明显上升。

在教学型实现里,让一个 program 处理一整行,虽然会把 BLOCK_SIZEn_cols 绑定,但好处是:

  • 模型清晰;
  • mask 边界处理直接;
  • reduction 逻辑集中;
  • 更容易和 softmax 做类比。

这也是为什么很多入门阶段的 row-wise 算子都适合从这种设计开始理解。

other=0.0 和 softmax 中的 -inf 为什么不同

这里有一个很容易忽略但很重要的细节。

RMSNorm 里,被 mask 掉的无效位置会被填成 0.0。这样做是正确的,因为:

  • 它们不会影响平方和;
  • 它们不会给均方根贡献额外值;
  • 最终也不会被写回输出。

而在 softmax 里,填充值通常是 -inf,因为 softmax 更关心 maxexp 的数值行为。

两者看起来都在做“无效位置填充”,但填充值的选择并不是固定模板,而是要服务于该算子的数学结构。

RMSNorm 回头看前几篇

到这里可以把前面几篇 OpenAI Triton 文章串起来看。

Part 02 提供的是执行模型

也就是:

  • program
  • grid
  • tile
  • mask

没有这一层,后面看到 row-wise norm kernel 只会觉得代码“很密”,但看不出结构。

Part 03 提供的是性能判断

也就是:

  • 这个算子更偏 memory-bound
  • 优化重点通常不是增加算术,而是减少全局读写
  • benchmark 需要结合数据流来解释

没有这一层,很容易把 RMSNorm 误判成“不复杂,所以不重要”的算子。

Part 04 提供的是另一种性能范式

matmul 代表的是高算术强度、强 tile 复用和 autotune。而 RMSNorm 则恰好提供了一个对照:不是所有热点都长得像 matmul。

这也是 OpenAI Triton 线索里很关键的一点:不同热点算子,优化逻辑并不相同。

attention 为什么更像一条链路

讲完 RMSNorm 之后,可以顺着这个视角再看 attention。

attention 的关键问题往往不在某一个局部点,而在整条路径的中间结果如何流动:

  • Q @ K^T
  • mask
  • softmax
  • scores @ V

如果这条链被拆成多个独立 kernel,中间张量就会频繁写回显存、再从显存读出。也正因为如此,像 FlashAttention 这类实现的价值,不是“单独把某个 softmax 写得更快”,而是尽量在更长的路径上减少中间内存往返。

所以,理解 RMSNorm 这样的行级热点算子,其实也是在为理解更复杂的路径级优化做准备。

什么样的算子暂时不需要第一时间手写 Triton

为了避免过度优化,也需要反过来说清楚:并不是推理链路里的每个算子都值得立刻手写 Triton。

优先级较低的情况通常包括:

  • 出现频率不高;
  • 输入规模较小,累计收益有限;
  • 已经被成熟后端很好覆盖;
  • 优化之后很难减少关键数据流;
  • 不处于主前向时延的关键路径。

这个判断很重要。否则很容易陷入“能写 Triton kernel 的都想接管”的误区。

常见误区

误区一:只盯着 FLOPs 最高的算子

FLOPs 很重要,但不是唯一标准。像 RMSNorm 这种高频、带宽敏感、位于关键路径的算子,即使单次计算量不大,也可能是值得优化的对象。

误区二:把 norm 类算子当成“没什么可讲的预处理”

在真实模型里,norm 类算子往往处在高频路径上。只因为它公式比 matmul 短,就忽略它的累计成本,这是很常见的判断失误。

误区三:把 attention 只看成一个 softmax

attention 的瓶颈经常来自整条路径的数据流,而不是单个 softmax 节点本身。只盯着一个节点,通常看不见中间张量读写带来的真正代价。

结论

从推理工程视角看,值得单独用 OpenAI Triton 接管的算子,通常不是简单按 FLOPs 排序出来的,而是要同时看:

  • 高频程度;
  • 访存敏感性;
  • 所处路径位置。

RMSNorm 是一个典型例子。它不是推理链路里最显眼的算子,但它高频、带宽敏感,又处在关键前向路径上,因此非常适合用来建立“热点算子筛选”这层认知。

到这里,OpenAI Triton 这条主线已经形成了一个完整闭环:

  • 从执行模型入门;
  • 到性能判断;
  • 到 matmul 与 autotune
  • 再到推理路径中的热点算子筛选。

下一步就可以切到 NVIDIA Triton Inference Server 这条线,开始讨论怎样把模型组织成可调用、可调优的在线服务。

系列导航