这篇文章解决什么问题

完成环境准备之后,下一步不是马上去写复杂算子,而是先建立对 OpenAI Triton 的最小认知框架。

很多人第一次看 Triton 代码时,最容易出现两个误解:

  • 误以为 Triton kernel 和普通 Python 函数只是“语法长得不一样”。
  • 误以为只要把 for 循环改成 tl.arange,就算理解了并行执行。

这两种理解都不够。Triton 代码真正需要先想清楚的,不是某一行语法,而是下面四个问题:

  • 一个 kernel 会被拆成多少个执行单元。
  • 每个执行单元负责哪一块数据。
  • 这块数据如何映射到显存地址。
  • 边界位置如何避免越界访问。

本文会围绕 programgridtilemask 这四个概念展开,并用一个向量加法 kernel 把它们串起来。目标不是覆盖全部 Triton 语法,而是让后续看到任意一个 Triton kernel 时,知道应该先从哪里读起。

不要把 Triton kernel 当成串行函数

先给出一个更接近事实的理解:

Triton kernel 不是“在 GPU 上执行一次的函数”,而是一份会被复制成很多个并行执行单元的模板。

这句话里最关键的词是“很多个并行执行单元”。在 Triton 里,这个执行单元通常叫 program。每个 program 处理输入中的一块数据,所有 program 共同完成整个张量的计算。

如果把一个长度为 N 的一维向量切成若干块,执行过程可以抽象成下面这样:

1
2
3
4
5
6
输入数据: [ 0..1023 | 1024..2047 | 2048..3071 | ... ]
pid=0 pid=1 pid=2

每个 pid 对应一个独立 program
每个 program 处理一段连续数据
所有 program 并行执行

这就是理解 Triton kernel 的起点。只要这个视角没有建立起来,后面看到 program_idarangemask 时就很容易停留在语法层面。

四个核心概念

1. kernel

@triton.jit 修饰的函数就是 Triton kernel。它描述的是“每个 program 应该怎样处理自己负责的数据块”。

kernel 里写的是局部计算逻辑,不是整个输入的全局遍历逻辑。全局遍历的展开,是通过多个 program 并行完成的。

2. program

program 可以理解为一次独立的并行工作单元。它会通过 tl.program_id(axis=...) 获得自己的编号,然后根据这个编号计算自己该处理哪一段数据。

如果只处理一维向量,通常只关心 axis=0。如果处理矩阵或更复杂的 tile 布局,就可能同时使用多个维度。

3. grid

grid 决定一共要启动多少个 program。这个信息不是写死在 kernel 函数体里的,而是由 host 侧在启动 kernel 时提供。

因此,kernel 负责描述“单个 program 怎么做”,grid 负责描述“总共要启动多少个 program”。

4. tile

tile 是单个 program 负责处理的数据块。对于最简单的一维向量加法,它可以是一段长度为 BLOCK_SIZE 的连续元素。

在更复杂的场景里,tile 也可以是矩阵中的一个子块。例如 matmul 场景中,一个 program 可能负责输出矩阵中的一个二维块。

5. mask

mask 解决的是边界问题。因为输入长度通常不是 BLOCK_SIZE 的整数倍,最后一个 program 很可能只处理半满的一块数据。

如果没有 mask,最后这块就会读写到非法地址。对 Triton 而言,mask 不是附加细节,而是最基础的安全机制之一。

先记住一条读代码顺序

看到一个新的 Triton kernel,优先按下面顺序读,而不是从上到下机械阅读:

  1. 先看 program_id 是怎么取的。
  2. 再看 offsets 是怎么从 pid 推出来的。
  3. 再看 load / store 的地址表达式。
  4. 最后再看 mask 保护了哪些边界。

这个顺序比背语法表更有用。因为大多数 Triton kernel 的本质,都是“某个 program 如何根据自己的编号定位一块数据,然后在这块数据上完成局部计算”。

第一个完整例子:向量加法

向量加法不是最有代表性的高性能算子,但它非常适合作为第一个例子,因为它几乎不包含额外数学复杂度,读者可以把注意力集中在执行模型本身。

下面是一份最小可运行代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
x_ptr,
y_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
tl.store(out_ptr + offsets, x + y, mask=mask)


def triton_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n = x.numel()
grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
return out


if __name__ == "__main__":
n = 1 << 20
x = torch.rand(n, device="cuda", dtype=torch.float32)
y = torch.rand(n, device="cuda", dtype=torch.float32)

out = triton_add(x, y)
assert torch.allclose(out, x + y, atol=1e-5)
print("vector add passed")

这段代码的价值不在于“完成了加法”,而在于它把 Triton 最核心的四个概念完整串起来了。

逐行理解这个 kernel

pid = tl.program_id(axis=0)

这一行拿到当前 program 在第 0 维 grid 中的编号。可以把它理解为“我是哪一个执行单元”。

如果 grid 是一维的,那么 pid 通常就是 0, 1, 2, 3, ... 这样递增。

block_start = pid * BLOCK_SIZE

有了 pid 之后,就可以把“我是第几个 program”转换成“我负责的数据从哪里开始”。

如果 BLOCK_SIZE=1024

  • pid=0 时,负责 [0, 1023]
  • pid=1 时,负责 [1024, 2047]
  • pid=2 时,负责 [2048, 3071]

这一步本质上就是把执行编号映射为数据分块起点。

offsets = block_start + tl.arange(0, BLOCK_SIZE)

这一步生成当前 tile 内所有元素的全局下标。

如果当前 block_start=2048,那么 offsets 就会是:

1
[2048, 2049, 2050, ..., 3071]

因此,offsets 才是真正把局部 tile 和全局显存地址联系起来的核心变量。

mask = offsets < n_elements

这一步负责处理最后一个 tile 的边界。

假设总长度是 2500,而 BLOCK_SIZE=1024,那么一共需要 3 个 program

  • 第一个处理 0..1023
  • 第二个处理 1024..2047
  • 第三个理论上会处理 2048..3071

但真实有效范围只有 2048..2499。如果没有 mask,后面超出的部分就会越界。mask 的作用就是告诉 Triton:只有条件为真的那些位置才允许参与读写。

tl.load(...)tl.store(...)

tl.load(x_ptr + offsets, mask=mask) 可以理解为“从一组地址批量读取一块数据”。tl.store(...) 则是对应的批量写回。

这里需要注意,Triton 的思考方式不是“取单个元素”,而是“围绕一组 offsets 成批处理一块 tile”。这也是它和普通 Python 数组访问差异最大的地方之一。

host 侧为什么还要写一个 wrapper

很多初学者会只盯着 kernel 函数体看,忽略 host 侧的包装函数。但在 Triton 中,host 侧和 kernel 侧是分工明确的。

在上面的例子里,host 侧承担三件事:

  • 分配输出张量。
  • 读取输入规模,例如 n = x.numel()
  • 根据输入规模和 BLOCK_SIZE 推导 grid。

这一点很重要。kernel 只负责单个 program 的局部逻辑,而“需要多少个 program”这个问题属于 host 侧。

grid 到底在做什么

这一行是整个例子里最值得停下来理解的一句:

1
grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)

它的意思是:

  • 如果每个 program 处理 BLOCK_SIZE 个元素,
  • 那么总共需要 ceil(n / BLOCK_SIZE)program

例如:

  • n=1024,只需要 1 个 program
  • n=1025,需要 2 个 program
  • n=4096,需要 4 个 program

triton.cdiv 的作用就是向上整除。它经常出现在一维分块和二维 tiling 场景里,是非常常见的工具函数。

BLOCK_SIZE 为什么写成 tl.constexpr

BLOCK_SIZE 不是普通运行时参数,而是编译期常量。把它声明为 tl.constexpr,意味着 Triton 在生成 kernel 时会把它作为静态信息来处理。

这样做有两个直接结果:

  • 编译器可以围绕这个固定块大小做优化。
  • host 侧可以根据不同 BLOCK_SIZE 选择不同 grid 或不同 launch 方式。

在后续更复杂的 kernel 中,BLOCK_SIZEBLOCK_MBLOCK_NBLOCK_K 这类参数都会直接影响 kernel 形态,因此它们通常会以 tl.constexpr 的形式出现。

读 Triton kernel 时的三个固定问题

看到任何一个陌生 Triton kernel,可以先问自己三个问题:

当前 program 是谁

也就是先看 tl.program_id(...)。如果连这一点都没有弄清楚,后面的 offsets 很容易看成无意义的索引运算。

当前 program 负责哪块数据

也就是看 pid 如何映射为 block_start 或二维 tile 坐标。这一步决定了 kernel 的分块策略。

这块数据怎样映射到真实地址

也就是看 offsets 如何与指针相加,以及 mask 在保护什么边界。这里往往直接决定了访存模式是否合理。

这三个问题并不只适用于向量加法。后面看 softmax、RMSNorm、matmul、attention 路径时,依然是同一套读法。

这一节故意没有讲什么

为了把第一层认知压实,这一节刻意没有展开以下内容:

  • 共享内存和寄存器占用的深入分析。
  • 多维 grid 的更复杂映射。
  • autotune 与不同 block 配置的选择。
  • 更接近真实推理热点的 softmax、norm 或 matmul。

这些内容都会在后续文章继续展开。如果在这一节就同时引入性能分析和复杂算子,反而容易把最核心的执行模型冲淡。

常见误区

误区一:把 program 直接等同于 CUDA 里的某个固定层级

OpenAI Triton 借鉴了 GPU 编程模型,但它的抽象目标不是让用户逐项复刻 CUDA 的线程层级,而是让用户围绕 tile 和数据块表达 kernel。学习时如果一直试图做一一映射,通常会比直接理解 Triton 自身抽象更吃力。

误区二:只背 API,不看地址映射

如果一个例子只记住了 tl.loadtl.storetl.arange 这些名字,而没有真正看懂 offsets 是怎么来的,那么对 kernel 的理解仍然停留在表面。

误区三:把边界保护当成附加细节

在 Triton 里,mask 不是可选装饰。只要输入规模和 tile 大小之间存在不整除关系,mask 就是基本组成部分。

结论

OpenAI Triton 的入门关键,不是先背完所有语法,而是建立一套稳定的读取顺序:

  • 先看 program 如何编号。
  • 再看 tile 如何切分。
  • 再看 offsets 如何映射到地址。
  • 最后看 mask 如何保护边界。

只要这条主线建立起来,后面看到更复杂的 kernel,就不再只是“看见了一堆语法”,而是能把它还原成一套明确的数据分块与执行逻辑。

下一篇会继续往前走,把这一层执行模型放进性能分析框架里,讨论为什么很多 Triton kernel 的瓶颈不是算不动,而是搬不动。

系列导航