这篇文章解决什么问题

学会写一个 Triton kernel,只是 OpenAI Triton 的起点。真正进入算子优化之后,第一件事不是继续堆语法,而是判断瓶颈到底在哪里。

如果这一步判断错误,后续所有优化都容易偏离方向。例如:

  • 一个明显受显存带宽限制的算子,却试图通过增加算术操作去提速。
  • 一个本应通过减少中间读写来优化的算子,却只关注单条指令本身。
  • benchmark 结果看起来变快了,但没有建立对“为什么会快”的解释。

这一篇只做一件事:建立 OpenAI Triton 的第一层性能思维。核心问题有三个:

  • 什么叫 memory-bound,什么叫 compute-bound
  • 为什么 softmax 是理解 fusion 的合适例子。
  • benchmark 结果应该怎样读,才不至于只剩下一串数字。

先区分两类瓶颈

GPU kernel 的性能瓶颈,通常可以先粗分为两类:

  • memory-bound:瓶颈主要来自数据搬运速度,算术单元并没有被充分压满。
  • compute-bound:瓶颈主要来自计算能力本身,数据已经足够快地送到了执行单元。

这不是 Triton 特有概念,但 Triton kernel 优化时几乎绕不开这组判断。因为优化方向会完全不同:

  • memory-bound 算子,更关注读写次数、访存模式、算子融合和中间结果是否落回显存。
  • compute-bound 算子,更关注 tile 复用、流水化、Tensor Core 利用率和并行映射效率。

如果连这一层都没有先分清,就很容易出现“在错误方向上努力”的情况。

为什么向量加法通常是 memory-bound

先看一个最简单的例子:向量加法。

对于每个元素,向量加法通常只做一件很轻的计算:

1
out[i] = x[i] + y[i]

但它至少涉及三次全局内存访问:

  • x[i]
  • y[i]
  • out[i]

也就是说,它做的计算极少,搬的数据却不少。此时性能上限通常首先由显存带宽决定,而不是由加法指令本身决定。

这类算子有一个典型特征:即使 GPU 理论算力很高,实际也未必能把计算单元压满。因为问题根本不在“算不动”,而在“搬不动”。

为什么矩阵乘法更容易接近 compute-bound

与向量加法不同,矩阵乘法对单次数据加载的复用更高。

C = A @ B 为例,一个数据块从显存读入后,往往会参与多次乘加运算。数据复用一旦提升,单次内存访问对应的计算量就会上升,算术强度也会随之提高。

这时 kernel 更有机会接近 compute-bound,优化重点也会转向:

  • tile 设计是否合理;
  • 数据复用是否充分;
  • 执行单元是否被持续喂饱;
  • Tensor Core 或指令级并行是否发挥出来。

这也是为什么在 Triton 里,向量加法和 matmul 虽然都能作为示例,但它们代表的是两种完全不同的性能直觉。

不要把“快”理解成单一概念

在实际分析里,“快”至少可能指三件事:

  • 单次 kernel latency 降低;
  • 有效带宽接近理论带宽;
  • 相同输入规模下吞吐提升。

memory-bound 算子来说,“快”的核心常常不是让计算更复杂,而是让显存读写更少、更顺、更集中。

因此,在 Triton 场景里看 benchmark,不能只盯着最终耗时。至少还要问:

  • 这个算子主要在搬数据还是做计算。
  • 当前实现减少了哪些读写。
  • 测到的结果更接近带宽上限,还是仍有较大差距。

Softmax 为什么适合作为第一个性能案例

softmax 是理解 Triton 性能思维的一个合适入口,因为它同时具备三个特点:

  • 它在推理链路中高频出现。
  • 它包含 row-wise reduction,结构上比纯 elementwise 稍复杂。
  • 它通常更偏 memory-bound,因此非常适合讨论 fusion 的收益。

以一行 softmax 为例,公式是:

1
softmax(x_i) = exp(x_i - max(x)) / sum(exp(x - max(x)))

这个过程至少包含以下逻辑:

  • 找到整行最大值;
  • 做数值稳定处理;
  • 计算指数;
  • 求和;
  • 再做归一化。

如果把这些步骤拆成多个独立 kernel,中间结果就会反复写回显存,再从显存读出来。对于本来就偏带宽受限的算子来说,这种中间读写本身就是明显成本。

softmax 里真正昂贵的是什么

很多人第一次看 softmax,会自然把注意力放在 exp 上,觉得“指数运算看起来最贵”。这不是完全错,但在 GPU 推理场景里,更值得先看的往往是数据流。

一个拆分实现的 softmax,往往会经历类似过程:

1
2
3
读输入 -> 求 max -> 写中间结果
读中间结果 -> 求 exp -> 写中间结果
读中间结果 -> 求 sum / div -> 写输出

这里最大的问题不是某一条算术指令,而是同一批数据被反复在全局内存和执行单元之间搬运。

因此,对这种算子来说,fusion 的直接意义不是“让数学变简单”,而是“减少不必要的全局读写”。

fusion 在这里减少的到底是什么

以 fused softmax 为例,一个更紧凑的实现会把一整行数据读入后,在寄存器或更局部的执行上下文中完成:

  • max
  • shift
  • exp
  • sum
  • div

最后再把结果写回一次。

如果从数据流角度看,差异可以粗略表示为:

1
2
3
4
5
6
7
naive:
global memory -> kernel1 -> global memory
global memory -> kernel2 -> global memory
global memory -> kernel3 -> global memory

fused:
global memory -> kernel -> global memory

这里节省下来的,主要是中间结果反复落回全局内存的成本。对 memory-bound 算子来说,这往往比去微调几条算术表达式更重要。

一个最小的 fused softmax 示例

下面是一份最小可运行思路。它不追求覆盖所有边界情况,而是用来说明单行 softmax 如何在一个 program 内完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(
output_ptr,
input_ptr,
input_row_stride,
output_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
row_idx = tl.program_id(axis=0)
row_start = input_ptr + row_idx * input_row_stride

offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < n_cols

row = tl.load(row_start + offsets, mask=mask, other=-float("inf"))
row_shifted = row - tl.max(row, axis=0)
numerator = tl.exp(row_shifted)
denominator = tl.sum(numerator, axis=0)
output = numerator / denominator

out_start = output_ptr + row_idx * output_row_stride
tl.store(out_start + offsets, output, mask=mask)


def triton_softmax(x: torch.Tensor) -> torch.Tensor:
n_rows, n_cols = x.shape
block_size = triton.next_power_of_2(n_cols)
y = torch.empty_like(x)

softmax_kernel[(n_rows,)](
y,
x,
x.stride(0),
y.stride(0),
n_cols,
BLOCK_SIZE=block_size,
)
return y

这段代码最值得注意的不是具体 API,而是两个结构性选择:

  • 一个 program 负责一整行。
  • 整行的 reduction 与归一化在同一个 kernel 中完成。

这两个选择合起来,才构成了后面性能分析的前提。

为什么这里是“一行一个 program”

softmax 的 reduction 是按行做的。也就是说,一行中的所有列共同决定:

  • 这一行的最大值;
  • 这一行的指数和;
  • 这一行每个元素的归一化结果。

因此,如果把一行拆给多个 program,就会立刻引入更复杂的跨 program 通信或中间同步问题。对这个入门版本来说,最直接的方式就是让一个 program 处理一整行。

这样做的代价是,BLOCK_SIZE 必须能够覆盖 n_cols。这会把列宽和单个 program 的资源占用绑定在一起。后面处理更宽的行时,就必须开始关注寄存器压力和更复杂的分块策略。

但在入门阶段,这样的设计正好能把核心逻辑看清楚。

other=-inf 和数值稳定性为什么关键

mask 只告诉 Triton 哪些位置有效,但被屏蔽的位置仍然需要一个填充值。对 softmax 而言,把无效位置填成 -inf 非常合适,因为:

  • max 时它不会被选中;
  • exp 后它会变成 0;
  • sum 时不会对有效值产生干扰。

此外,softmax 里还必须先减去整行最大值。这不是“经验技巧”,而是数值稳定性的基本要求。否则当输入较大时,exp 很容易出现溢出。

从性能角度看,数值稳定性和性能并不是对立关系。一个会溢出的 kernel,即使跑得快,也不具备工程价值。

怎样读 benchmark,才不至于只剩一串数字

假设现在测到一个 softmax kernel 的耗时是 0.5 ms,这个数字本身其实信息不够。至少还要同时问三件事。

第一,输入规模是什么

没有输入形状,单个耗时数字几乎不能比较。4096 x 4096 的 softmax 和 512 x 128 的 softmax,不应该放在同一语境下解读。

第二,数据流大致是多少

对一个 fused softmax,可以先用很粗的方式估算有效带宽:

1
有效带宽 = 读写总字节数 / 耗时

如果一个实现主要目标是减少全局读写,那么带宽视角就比单看毫秒数更有解释力。

第三,和谁比

一个 benchmark 至少需要一个明确基线。常见基线包括:

  • PyTorch 默认实现;
  • 一个更 naive 的拆分实现;
  • 同一实现下不同 BLOCK_SIZE 或不同输入形状的结果。

如果只报“优化后更快”,但没有基线,就很难判断收益究竟来自哪里。

一个更实用的 benchmark 解释框架

看到一个 Triton benchmark 结果时,可以用下面这套顺序解释:

  1. 先说输入形状和数据类型。
  2. 再说算子更偏 memory-bound 还是 compute-bound
  3. 再说当前优化减少了什么。
  4. 最后再说耗时和对比结果。

例如,对 fused softmax 来说,一个更完整的描述应该类似于:

这是一个行归一化 softmax,输入为 4096 x 4096float32。算子整体更偏 memory-bound。当前实现的主要收益来自把 maxexpsumdiv 合并进单个 kernel,减少中间结果写回全局内存,因此在相同输入下比拆分实现或默认实现更接近带宽上限。

这样的描述,比单纯报一句“快了 2 倍”有用得多。

这一节故意不展开的内容

这一篇只建立第一层性能直觉,因此没有展开更深入的话题:

  • num_warpsnum_stages 的调节逻辑;
  • 更复杂的 occupancy 分析;
  • matmul 与 Tensor Core 相关优化;
  • autotune 搜索空间;
  • 更接近真实推理路径的 RMSNorm、attention 或 matmul fusion。

这些内容要么会在后续文章中单独展开,要么已经超出“第一层性能判断”的范围。

常见误区

误区一:一看到 GPU 算子就默认追求 FLOPS

并不是所有 kernel 的核心目标都是把算力压满。对很多 elementwise、norm、softmax 类算子来说,更常见的限制是带宽和中间读写成本。

误区二:把 fusion 理解成“把代码写得更长”

fusion 的重点不是把多段逻辑机械拼到一起,而是减少不必要的全局内存往返。如果拼接后的实现没有减少关键数据流,性能收益未必成立。

误区三:只看耗时,不看解释

没有输入规模、没有基线、没有瓶颈判断的 benchmark,通常只能作为现象记录,不能直接支持工程决策。

结论

OpenAI Triton 的性能分析,第一步不是调参数,而是先判断算子更接近 memory-bound 还是 compute-bound

对 softmax 这类更偏带宽受限的算子来说,核心优化思路通常不是增加算术复杂度,而是减少中间结果的全局读写。fusion 之所以重要,根本原因也在这里。

因此,看一个 Triton kernel 是否“优化成功”,至少要能回答三件事:

  • 它的主要瓶颈是什么。
  • 当前实现减少了什么成本。
  • benchmark 数字为什么能够支持这个结论。

下一篇会继续沿着这条主线往前走,从基础性能判断进入更接近工程实践的内容,讨论 matmul、autotune 以及 Triton 与 PyTorch 的集成方式。

系列导航