Triton:02 OpenAI Triton 基础编程模型
这篇文章解决什么问题
完成环境准备之后,下一步不是马上去写复杂算子,而是先建立对 OpenAI Triton 的最小认知框架。
很多人第一次看 Triton 代码时,最容易出现两个误解:
- 误以为 Triton kernel 和普通 Python 函数只是“语法长得不一样”。
- 误以为只要把
for循环改成tl.arange,就算理解了并行执行。
这两种理解都不够。Triton 代码真正需要先想清楚的,不是某一行语法,而是下面四个问题:
- 一个 kernel 会被拆成多少个执行单元。
- 每个执行单元负责哪一块数据。
- 这块数据如何映射到显存地址。
- 边界位置如何避免越界访问。
本文会围绕 program、grid、tile、mask 这四个概念展开,并用一个向量加法 kernel 把它们串起来。目标不是覆盖全部 Triton 语法,而是让后续看到任意一个 Triton kernel 时,知道应该先从哪里读起。
不要把 Triton kernel 当成串行函数
先给出一个更接近事实的理解:
Triton kernel 不是“在 GPU 上执行一次的函数”,而是一份会被复制成很多个并行执行单元的模板。
这句话里最关键的词是“很多个并行执行单元”。在 Triton 里,这个执行单元通常叫 program。每个 program 处理输入中的一块数据,所有 program 共同完成整个张量的计算。
如果把一个长度为 N 的一维向量切成若干块,执行过程可以抽象成下面这样:
1 | 输入数据: [ 0..1023 | 1024..2047 | 2048..3071 | ... ] |
这就是理解 Triton kernel 的起点。只要这个视角没有建立起来,后面看到 program_id、arange、mask 时就很容易停留在语法层面。
四个核心概念
1. kernel
@triton.jit 修饰的函数就是 Triton kernel。它描述的是“每个 program 应该怎样处理自己负责的数据块”。
kernel 里写的是局部计算逻辑,不是整个输入的全局遍历逻辑。全局遍历的展开,是通过多个 program 并行完成的。
2. program
program 可以理解为一次独立的并行工作单元。它会通过 tl.program_id(axis=...) 获得自己的编号,然后根据这个编号计算自己该处理哪一段数据。
如果只处理一维向量,通常只关心 axis=0。如果处理矩阵或更复杂的 tile 布局,就可能同时使用多个维度。
3. grid
grid 决定一共要启动多少个 program。这个信息不是写死在 kernel 函数体里的,而是由 host 侧在启动 kernel 时提供。
因此,kernel 负责描述“单个 program 怎么做”,grid 负责描述“总共要启动多少个 program”。
4. tile
tile 是单个 program 负责处理的数据块。对于最简单的一维向量加法,它可以是一段长度为 BLOCK_SIZE 的连续元素。
在更复杂的场景里,tile 也可以是矩阵中的一个子块。例如 matmul 场景中,一个 program 可能负责输出矩阵中的一个二维块。
5. mask
mask 解决的是边界问题。因为输入长度通常不是 BLOCK_SIZE 的整数倍,最后一个 program 很可能只处理半满的一块数据。
如果没有 mask,最后这块就会读写到非法地址。对 Triton 而言,mask 不是附加细节,而是最基础的安全机制之一。
先记住一条读代码顺序
看到一个新的 Triton kernel,优先按下面顺序读,而不是从上到下机械阅读:
- 先看
program_id是怎么取的。 - 再看 offsets 是怎么从
pid推出来的。 - 再看
load/store的地址表达式。 - 最后再看
mask保护了哪些边界。
这个顺序比背语法表更有用。因为大多数 Triton kernel 的本质,都是“某个 program 如何根据自己的编号定位一块数据,然后在这块数据上完成局部计算”。
第一个完整例子:向量加法
向量加法不是最有代表性的高性能算子,但它非常适合作为第一个例子,因为它几乎不包含额外数学复杂度,读者可以把注意力集中在执行模型本身。
下面是一份最小可运行代码:
1 | import torch |
这段代码的价值不在于“完成了加法”,而在于它把 Triton 最核心的四个概念完整串起来了。
逐行理解这个 kernel
pid = tl.program_id(axis=0)
这一行拿到当前 program 在第 0 维 grid 中的编号。可以把它理解为“我是哪一个执行单元”。
如果 grid 是一维的,那么 pid 通常就是 0, 1, 2, 3, ... 这样递增。
block_start = pid * BLOCK_SIZE
有了 pid 之后,就可以把“我是第几个 program”转换成“我负责的数据从哪里开始”。
如果 BLOCK_SIZE=1024:
pid=0时,负责[0, 1023]pid=1时,负责[1024, 2047]pid=2时,负责[2048, 3071]
这一步本质上就是把执行编号映射为数据分块起点。
offsets = block_start + tl.arange(0, BLOCK_SIZE)
这一步生成当前 tile 内所有元素的全局下标。
如果当前 block_start=2048,那么 offsets 就会是:
1 | [2048, 2049, 2050, ..., 3071] |
因此,offsets 才是真正把局部 tile 和全局显存地址联系起来的核心变量。
mask = offsets < n_elements
这一步负责处理最后一个 tile 的边界。
假设总长度是 2500,而 BLOCK_SIZE=1024,那么一共需要 3 个 program:
- 第一个处理
0..1023 - 第二个处理
1024..2047 - 第三个理论上会处理
2048..3071
但真实有效范围只有 2048..2499。如果没有 mask,后面超出的部分就会越界。mask 的作用就是告诉 Triton:只有条件为真的那些位置才允许参与读写。
tl.load(...) 与 tl.store(...)
tl.load(x_ptr + offsets, mask=mask) 可以理解为“从一组地址批量读取一块数据”。tl.store(...) 则是对应的批量写回。
这里需要注意,Triton 的思考方式不是“取单个元素”,而是“围绕一组 offsets 成批处理一块 tile”。这也是它和普通 Python 数组访问差异最大的地方之一。
host 侧为什么还要写一个 wrapper
很多初学者会只盯着 kernel 函数体看,忽略 host 侧的包装函数。但在 Triton 中,host 侧和 kernel 侧是分工明确的。
在上面的例子里,host 侧承担三件事:
- 分配输出张量。
- 读取输入规模,例如
n = x.numel()。 - 根据输入规模和
BLOCK_SIZE推导 grid。
这一点很重要。kernel 只负责单个 program 的局部逻辑,而“需要多少个 program”这个问题属于 host 侧。
grid 到底在做什么
这一行是整个例子里最值得停下来理解的一句:
1 | grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),) |
它的意思是:
- 如果每个
program处理BLOCK_SIZE个元素, - 那么总共需要
ceil(n / BLOCK_SIZE)个program。
例如:
n=1024,只需要1个 programn=1025,需要2个 programn=4096,需要4个 program
triton.cdiv 的作用就是向上整除。它经常出现在一维分块和二维 tiling 场景里,是非常常见的工具函数。
BLOCK_SIZE 为什么写成 tl.constexpr
BLOCK_SIZE 不是普通运行时参数,而是编译期常量。把它声明为 tl.constexpr,意味着 Triton 在生成 kernel 时会把它作为静态信息来处理。
这样做有两个直接结果:
- 编译器可以围绕这个固定块大小做优化。
- host 侧可以根据不同
BLOCK_SIZE选择不同 grid 或不同 launch 方式。
在后续更复杂的 kernel 中,BLOCK_SIZE、BLOCK_M、BLOCK_N、BLOCK_K 这类参数都会直接影响 kernel 形态,因此它们通常会以 tl.constexpr 的形式出现。
读 Triton kernel 时的三个固定问题
看到任何一个陌生 Triton kernel,可以先问自己三个问题:
当前 program 是谁
也就是先看 tl.program_id(...)。如果连这一点都没有弄清楚,后面的 offsets 很容易看成无意义的索引运算。
当前 program 负责哪块数据
也就是看 pid 如何映射为 block_start 或二维 tile 坐标。这一步决定了 kernel 的分块策略。
这块数据怎样映射到真实地址
也就是看 offsets 如何与指针相加,以及 mask 在保护什么边界。这里往往直接决定了访存模式是否合理。
这三个问题并不只适用于向量加法。后面看 softmax、RMSNorm、matmul、attention 路径时,依然是同一套读法。
这一节故意没有讲什么
为了把第一层认知压实,这一节刻意没有展开以下内容:
- 共享内存和寄存器占用的深入分析。
- 多维 grid 的更复杂映射。
- autotune 与不同 block 配置的选择。
- 更接近真实推理热点的 softmax、norm 或 matmul。
这些内容都会在后续文章继续展开。如果在这一节就同时引入性能分析和复杂算子,反而容易把最核心的执行模型冲淡。
常见误区
误区一:把 program 直接等同于 CUDA 里的某个固定层级
OpenAI Triton 借鉴了 GPU 编程模型,但它的抽象目标不是让用户逐项复刻 CUDA 的线程层级,而是让用户围绕 tile 和数据块表达 kernel。学习时如果一直试图做一一映射,通常会比直接理解 Triton 自身抽象更吃力。
误区二:只背 API,不看地址映射
如果一个例子只记住了 tl.load、tl.store、tl.arange 这些名字,而没有真正看懂 offsets 是怎么来的,那么对 kernel 的理解仍然停留在表面。
误区三:把边界保护当成附加细节
在 Triton 里,mask 不是可选装饰。只要输入规模和 tile 大小之间存在不整除关系,mask 就是基本组成部分。
结论
OpenAI Triton 的入门关键,不是先背完所有语法,而是建立一套稳定的读取顺序:
- 先看
program如何编号。 - 再看 tile 如何切分。
- 再看 offsets 如何映射到地址。
- 最后看 mask 如何保护边界。
只要这条主线建立起来,后面看到更复杂的 kernel,就不再只是“看见了一堆语法”,而是能把它还原成一套明确的数据分块与执行逻辑。
下一篇会继续往前走,把这一层执行模型放进性能分析框架里,讨论为什么很多 Triton kernel 的瓶颈不是算不动,而是搬不动。

