Triton:04 Matmul、Autotune 与 PyTorch 集成
这篇文章解决什么问题
前两篇 OpenAI Triton 文章解决的是两个基础问题:
- kernel 是怎样按
program和tile展开的。 - 为什么一些算子更接近
memory-bound,以及fusion为什么有效。
继续往前走,下一步自然会遇到一个更有代表性的算子:matmul。它的重要性不只是因为矩阵乘法常见,而是因为它代表了另一类完全不同的性能直觉。
和向量加法、softmax 这类更偏带宽受限的算子相比,matmul 往往具备更高的数据复用空间和更高的算术强度。因此,讨论 matmul 时,重点会逐渐从“少搬几次数据”转向“怎样让读进来的数据多参与几次有效计算”。
这篇文章围绕三个问题展开:
- 为什么 matmul 是理解 Triton 进阶编程模型的合适入口。
autotune究竟在选择什么,而不是简单“自动调参”。- 手写 Triton kernel 与 PyTorch、
torch.compile的关系到底是什么。
为什么 matmul 和前面的例子不同
先把最核心的区别说清楚。
向量加法的执行模式是:
- 读一次
x - 读一次
y - 做一次加法
- 写一次
out
它对每次全局内存访问的计算利用极低,因此更接近 memory-bound。
matmul 的思路完全不同。对 C = A @ B 来说,一个从 A 或 B 中读进来的数据块,并不是只参与一次计算,而是会在一个 tile 内被反复用于乘加累积。只要分块方式合理,同一批加载进来的数据就能被复用多次。
这意味着,matmul 的优化重点会转向:
- tile 如何切分;
- K 维如何分段推进;
- 累加器如何组织;
- 哪组 block 参数更适合当前矩阵形状。
也正因为如此,matmul 非常适合作为理解 autotune 的第一个真实场景。
先建立一个 tile 视角
以最常见的二维矩阵乘法为例:
1 | C[M, N] = A[M, K] @ B[K, N] |
在 Triton 里,一个常见思路是让单个 program 负责 C 中的一个子块,而不是只负责单个元素。
可以把它理解成下面这个过程:
1 | 一个 program 负责输出矩阵中的一个 BLOCK_M x BLOCK_N 子块 |
因此,matmul kernel 的关键不再只是“当前 program 是谁”,还包括:
- 当前
program负责输出矩阵的哪个 tile; - 它沿 K 维每次取多厚的切片;
- 这组 tile 参数是否能在当前 GPU 和当前矩阵形状下带来更高效率。
一个教学型 matmul kernel
下面这份代码只保留最核心结构,不追求达到生产级性能。它的目标是把 program、二维 tile 和 autotune 的关系讲清楚。
1 | import torch |
这段代码里,最重要的不是每个 API,而是三个结构性设计:
- 用一维
program_id解码出二维 tile 坐标。 - 沿 K 维循环推进,每次处理一个
BLOCK_K厚度切片。 - 把不同 tile 参数和执行参数交给
autotune选择。
这里的 program 到底负责什么
在向量加法里,一个 program 负责一段一维数据。到了 matmul,一个 program 的职责变成了:
- 负责输出矩阵中的一个
BLOCK_M x BLOCK_N子块; - 沿着 K 维逐段读取输入 tile;
- 把多个局部乘加结果累积到
acc中; - 最后把这个局部 tile 写回输出矩阵。
也就是说,当前 program 不再直接对应“一个元素”或“一段向量”,而是对应“输出矩阵中的一个二维块”。
这一步是从基础 Triton 编程模型走向更真实算子的关键转折。
为什么要有 acc
acc 是局部累加器。matmul 不会在每次 tl.dot(a, b) 后立刻把中间结果写回全局内存,而是先在更局部的上下文中不断累积,等所有 K 维切片都处理完,再统一写回。
这种设计有两个直接好处:
- 避免把每一轮 K 维局部结果都落回全局内存。
- 让单次加载进来的 tile 数据参与更多乘加计算。
从性能角度看,这正是 matmul 和前面 memory-bound 示例最大的差异之一。它更强调复用与累积,而不是只做一次轻量操作就立即写回。
autotune 到底在选什么
很多人第一次看到 @triton.autotune,会把它简单理解成“自动调参”。这不算错,但太粗。
更准确的说法是:
autotune在给定输入规模下,对多组 kernel 配置进行实测,然后缓存其中表现最好的那一组。
这里的“配置”通常包括:
BLOCK_MBLOCK_NBLOCK_Knum_warpsnum_stages
这些参数共同决定了:
- 一个
program处理多大的 tile; - K 维推进的粒度;
- 并行协作粒度;
- 软件流水深度;
- 寄存器和其他资源的占用方式。
因此,autotune 本质上不是魔法,而是把“凭经验猜参数”改造成“对候选配置做系统化 benchmark”。
为什么不同形状可能需要不同配置
假设有三种情况:
M很大而N较小;N很大而M较小;M、N、K都比较均衡;
这三种输入下,最合适的 tile 形状并不一定相同。更大的 BLOCK_M 可能让行方向复用更充分,但也可能带来更高资源压力;更大的 BLOCK_N 也是同理。
因此,原稿里给出的三组候选配置本质上是在覆盖不同的 tile 偏好:
64 x 64 x 32128 x 64 x 3264 x 128 x 32
这些配置不是“谁一定更先进”,而是在不同输入规模和不同硬件条件下各有适用区间。
num_warps 和 num_stages 不只是附属参数
很多入门文章只强调 BLOCK_M、BLOCK_N、BLOCK_K,而把 num_warps 和 num_stages 当成附带选项。实际上它们同样会明显影响性能。
num_warps
它决定一个 program 使用多少 warp 参与执行。warp 太少,可能无法把并行度拉起来;warp 太多,又可能导致资源占用上升,反而降低整体效率。
num_stages
它通常可以理解为软件流水深度,影响数据预取与计算重叠的程度。更高的 num_stages 可能带来更好的延迟隐藏,但代价往往是更大的资源压力。
这也是为什么真正做调优时,autotune 往往会把这两个参数与 block 参数一起纳入搜索空间。
教学型 matmul 为什么未必比 PyTorch 快
这一点需要明确说明。一个教学型 Triton matmul kernel,即使结构正确,也很可能明显慢于 PyTorch 直接调用的成熟实现。
原因并不神秘。PyTorch 在矩阵乘法场景里通常会调用高度优化的底层库,例如 cuBLAS。相比之下,教学型 Triton kernel 往往省略了大量真正影响峰值性能的工程细节,例如:
- 更复杂的数据布局与 swizzle;
- 更成熟的预取与流水设计;
- 更深的 autotune 搜索空间;
- 针对特定硬件的细化优化。
因此,这一节的目标不是“立刻写出一个超过 cuBLAS 的 matmul”,而是把 Triton 里 tile、局部累积和 autotune 的结构关系讲清楚。
Triton 和 PyTorch 的关系,不是二选一
很多人在接触 Triton 后,会自然问一个问题:既然 torch.compile 也会生成 Triton kernel,那手写 Triton 还有什么意义?
这个问题不能用一句话回答,但可以先明确两层关系。
第一层:PyTorch 会在部分场景下自动生成 Triton kernel
当使用 torch.compile 且后端为 inductor 时,PyTorch 在某些图优化路径里会生成 Triton kernel。这意味着:
- 不是所有 Triton kernel 都需要手写;
- 对标准算子组合,编译器有机会自动下沉并生成较优实现;
- 这条路径更偏“让常规 PyTorch 代码自动获得更好的执行计划”。
第二层:手写 Triton 解决的是编译器默认路径之外的问题
当你需要:
- 明确控制 tile 设计;
- 为特殊形状或特殊数据流编写定制 kernel;
- 对某个热点路径做更激进的局部优化;
手写 Triton 仍然有明确价值。它不是为了替代 PyTorch,而是为了在默认编译路径不足以满足目标时,提供更低层的可控性。
torch.compile 和手写 Triton 怎样协作
一个更实际的理解是:
torch.compile负责优化它能识别、能融合、能下沉的标准图。- 手写 Triton 负责你明确想接管的局部热点。
二者并不是互斥关系。一个完整的推理 pipeline 很可能同时包含:
- 一部分由 PyTorch 图编译自动优化的普通算子;
- 一部分由手写 Triton kernel 接管的热点路径;
- 其余仍由标准库或底层后端处理的执行部分。
因此,在工程里更值得问的问题不是“到底该用哪一个”,而是“当前这个热点,默认编译路径已经足够了吗”。
看待这篇 matmul 的正确方式
如果只把这一篇看成“写出了一个矩阵乘法 kernel”,价值会比较有限。更合理的理解方式是:
- 它把一维分块扩展成了二维 tile。
- 它第一次引入了沿 K 维迭代累积的结构。
- 它第一次引入了
autotune作为系统化选配置的机制。 - 它把 Triton kernel 与 PyTorch 编译路径的关系放回了同一张图里。
这几件事合起来,才是本篇真正的核心。
常见误区
误区一:把 matmul 的性能全部归因于 tl.dot
tl.dot 很重要,但它只是局部乘加原语。真正决定整体效率的,仍然是 tile 设计、K 维推进方式和配置选择。
误区二:把 autotune 当成最终答案
autotune 只能在给定候选配置中选一个当前更优的结果。候选空间本身设计得不好,最终结果也不会理想。
误区三:因为 PyTorch 更快,就认为教学型 Triton kernel 没有意义
教学型 kernel 的目标是建立结构认知,而不是直接替代成熟库。只要这个边界没有混淆,这类示例仍然非常有价值。
结论
matmul 是理解 Triton 进阶模型的关键案例,因为它把注意力从一维块处理推进到了二维 tile、局部累积和配置搜索。
这一篇最重要的结论不是“如何立刻写出最快的 matmul”,而是:
- 一个
program可以负责输出矩阵中的一个二维 tile; - K 维循环是 matmul kernel 的核心骨架;
autotune的本质是对候选配置做实测选择;- 手写 Triton 与 PyTorch 编译路径是协作关系,不是简单替代关系。
下一篇会继续往前走,把注意力从通用 matmul 转向推理路径里更常见的热点算子,讨论什么样的算子值得单独用 Triton 接管,以及判断依据是什么。

