这篇文章解决什么问题

前两篇 OpenAI Triton 文章解决的是两个基础问题:

  • kernel 是怎样按 programtile 展开的。
  • 为什么一些算子更接近 memory-bound,以及 fusion 为什么有效。

继续往前走,下一步自然会遇到一个更有代表性的算子:matmul。它的重要性不只是因为矩阵乘法常见,而是因为它代表了另一类完全不同的性能直觉。

和向量加法、softmax 这类更偏带宽受限的算子相比,matmul 往往具备更高的数据复用空间和更高的算术强度。因此,讨论 matmul 时,重点会逐渐从“少搬几次数据”转向“怎样让读进来的数据多参与几次有效计算”。

这篇文章围绕三个问题展开:

  • 为什么 matmul 是理解 Triton 进阶编程模型的合适入口。
  • autotune 究竟在选择什么,而不是简单“自动调参”。
  • 手写 Triton kernel 与 PyTorch、torch.compile 的关系到底是什么。

为什么 matmul 和前面的例子不同

先把最核心的区别说清楚。

向量加法的执行模式是:

  • 读一次 x
  • 读一次 y
  • 做一次加法
  • 写一次 out

它对每次全局内存访问的计算利用极低,因此更接近 memory-bound

matmul 的思路完全不同。对 C = A @ B 来说,一个从 AB 中读进来的数据块,并不是只参与一次计算,而是会在一个 tile 内被反复用于乘加累积。只要分块方式合理,同一批加载进来的数据就能被复用多次。

这意味着,matmul 的优化重点会转向:

  • tile 如何切分;
  • K 维如何分段推进;
  • 累加器如何组织;
  • 哪组 block 参数更适合当前矩阵形状。

也正因为如此,matmul 非常适合作为理解 autotune 的第一个真实场景。

先建立一个 tile 视角

以最常见的二维矩阵乘法为例:

1
C[M, N] = A[M, K] @ B[K, N]

在 Triton 里,一个常见思路是让单个 program 负责 C 中的一个子块,而不是只负责单个元素。

可以把它理解成下面这个过程:

1
2
3
4
5
6
7
8
9
一个 program 负责输出矩阵中的一个 BLOCK_M x BLOCK_N 子块

它不会一次把整个 K 维都算完,
而是沿着 K 维每次取一个 BLOCK_K 厚度的切片:

acc += A_tile @ B_tile
acc += A_tile @ B_tile
acc += A_tile @ B_tile
...

因此,matmul kernel 的关键不再只是“当前 program 是谁”,还包括:

  • 当前 program 负责输出矩阵的哪个 tile;
  • 它沿 K 维每次取多厚的切片;
  • 这组 tile 参数是否能在当前 GPU 和当前矩阵形状下带来更高效率。

一个教学型 matmul kernel

下面这份代码只保留最核心结构,不追求达到生产级性能。它的目标是把 program、二维 tile 和 autotune 的关系讲清楚。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import triton
import triton.language as tl


@triton.autotune(
configs=[
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
num_warps=4,
num_stages=2,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32},
num_warps=4,
num_stages=3,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32},
num_warps=4,
num_stages=3,
),
],
key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
a_ptr,
b_ptr,
c_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
grid_n = tl.cdiv(N, BLOCK_N)
pid_m = pid // grid_n
pid_n = pid % grid_n

offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

for k_start in range(0, K, BLOCK_K):
offs_k = k_start + tl.arange(0, BLOCK_K)

a = tl.load(
a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak,
mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
other=0.0,
)
b = tl.load(
b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn,
mask=(offs_k[:, None] < K) & (offs_n[None, :] < N),
other=0.0,
)

acc += tl.dot(a, b)

c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))


def triton_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
M, K = a.shape
_, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.float32)

grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
)

matmul_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
)
return c

这段代码里,最重要的不是每个 API,而是三个结构性设计:

  • 用一维 program_id 解码出二维 tile 坐标。
  • 沿 K 维循环推进,每次处理一个 BLOCK_K 厚度切片。
  • 把不同 tile 参数和执行参数交给 autotune 选择。

这里的 program 到底负责什么

在向量加法里,一个 program 负责一段一维数据。到了 matmul,一个 program 的职责变成了:

  • 负责输出矩阵中的一个 BLOCK_M x BLOCK_N 子块;
  • 沿着 K 维逐段读取输入 tile;
  • 把多个局部乘加结果累积到 acc 中;
  • 最后把这个局部 tile 写回输出矩阵。

也就是说,当前 program 不再直接对应“一个元素”或“一段向量”,而是对应“输出矩阵中的一个二维块”。

这一步是从基础 Triton 编程模型走向更真实算子的关键转折。

为什么要有 acc

acc 是局部累加器。matmul 不会在每次 tl.dot(a, b) 后立刻把中间结果写回全局内存,而是先在更局部的上下文中不断累积,等所有 K 维切片都处理完,再统一写回。

这种设计有两个直接好处:

  • 避免把每一轮 K 维局部结果都落回全局内存。
  • 让单次加载进来的 tile 数据参与更多乘加计算。

从性能角度看,这正是 matmul 和前面 memory-bound 示例最大的差异之一。它更强调复用与累积,而不是只做一次轻量操作就立即写回。

autotune 到底在选什么

很多人第一次看到 @triton.autotune,会把它简单理解成“自动调参”。这不算错,但太粗。

更准确的说法是:

autotune 在给定输入规模下,对多组 kernel 配置进行实测,然后缓存其中表现最好的那一组。

这里的“配置”通常包括:

  • BLOCK_M
  • BLOCK_N
  • BLOCK_K
  • num_warps
  • num_stages

这些参数共同决定了:

  • 一个 program 处理多大的 tile;
  • K 维推进的粒度;
  • 并行协作粒度;
  • 软件流水深度;
  • 寄存器和其他资源的占用方式。

因此,autotune 本质上不是魔法,而是把“凭经验猜参数”改造成“对候选配置做系统化 benchmark”。

为什么不同形状可能需要不同配置

假设有三种情况:

  • M 很大而 N 较小;
  • N 很大而 M 较小;
  • MNK 都比较均衡;

这三种输入下,最合适的 tile 形状并不一定相同。更大的 BLOCK_M 可能让行方向复用更充分,但也可能带来更高资源压力;更大的 BLOCK_N 也是同理。

因此,原稿里给出的三组候选配置本质上是在覆盖不同的 tile 偏好:

  • 64 x 64 x 32
  • 128 x 64 x 32
  • 64 x 128 x 32

这些配置不是“谁一定更先进”,而是在不同输入规模和不同硬件条件下各有适用区间。

num_warpsnum_stages 不只是附属参数

很多入门文章只强调 BLOCK_MBLOCK_NBLOCK_K,而把 num_warpsnum_stages 当成附带选项。实际上它们同样会明显影响性能。

num_warps

它决定一个 program 使用多少 warp 参与执行。warp 太少,可能无法把并行度拉起来;warp 太多,又可能导致资源占用上升,反而降低整体效率。

num_stages

它通常可以理解为软件流水深度,影响数据预取与计算重叠的程度。更高的 num_stages 可能带来更好的延迟隐藏,但代价往往是更大的资源压力。

这也是为什么真正做调优时,autotune 往往会把这两个参数与 block 参数一起纳入搜索空间。

教学型 matmul 为什么未必比 PyTorch 快

这一点需要明确说明。一个教学型 Triton matmul kernel,即使结构正确,也很可能明显慢于 PyTorch 直接调用的成熟实现。

原因并不神秘。PyTorch 在矩阵乘法场景里通常会调用高度优化的底层库,例如 cuBLAS。相比之下,教学型 Triton kernel 往往省略了大量真正影响峰值性能的工程细节,例如:

  • 更复杂的数据布局与 swizzle;
  • 更成熟的预取与流水设计;
  • 更深的 autotune 搜索空间;
  • 针对特定硬件的细化优化。

因此,这一节的目标不是“立刻写出一个超过 cuBLAS 的 matmul”,而是把 Triton 里 tile、局部累积和 autotune 的结构关系讲清楚。

Triton 和 PyTorch 的关系,不是二选一

很多人在接触 Triton 后,会自然问一个问题:既然 torch.compile 也会生成 Triton kernel,那手写 Triton 还有什么意义?

这个问题不能用一句话回答,但可以先明确两层关系。

第一层:PyTorch 会在部分场景下自动生成 Triton kernel

当使用 torch.compile 且后端为 inductor 时,PyTorch 在某些图优化路径里会生成 Triton kernel。这意味着:

  • 不是所有 Triton kernel 都需要手写;
  • 对标准算子组合,编译器有机会自动下沉并生成较优实现;
  • 这条路径更偏“让常规 PyTorch 代码自动获得更好的执行计划”。

第二层:手写 Triton 解决的是编译器默认路径之外的问题

当你需要:

  • 明确控制 tile 设计;
  • 为特殊形状或特殊数据流编写定制 kernel;
  • 对某个热点路径做更激进的局部优化;

手写 Triton 仍然有明确价值。它不是为了替代 PyTorch,而是为了在默认编译路径不足以满足目标时,提供更低层的可控性。

torch.compile 和手写 Triton 怎样协作

一个更实际的理解是:

  • torch.compile 负责优化它能识别、能融合、能下沉的标准图。
  • 手写 Triton 负责你明确想接管的局部热点。

二者并不是互斥关系。一个完整的推理 pipeline 很可能同时包含:

  • 一部分由 PyTorch 图编译自动优化的普通算子;
  • 一部分由手写 Triton kernel 接管的热点路径;
  • 其余仍由标准库或底层后端处理的执行部分。

因此,在工程里更值得问的问题不是“到底该用哪一个”,而是“当前这个热点,默认编译路径已经足够了吗”。

看待这篇 matmul 的正确方式

如果只把这一篇看成“写出了一个矩阵乘法 kernel”,价值会比较有限。更合理的理解方式是:

  • 它把一维分块扩展成了二维 tile。
  • 它第一次引入了沿 K 维迭代累积的结构。
  • 它第一次引入了 autotune 作为系统化选配置的机制。
  • 它把 Triton kernel 与 PyTorch 编译路径的关系放回了同一张图里。

这几件事合起来,才是本篇真正的核心。

常见误区

误区一:把 matmul 的性能全部归因于 tl.dot

tl.dot 很重要,但它只是局部乘加原语。真正决定整体效率的,仍然是 tile 设计、K 维推进方式和配置选择。

误区二:把 autotune 当成最终答案

autotune 只能在给定候选配置中选一个当前更优的结果。候选空间本身设计得不好,最终结果也不会理想。

误区三:因为 PyTorch 更快,就认为教学型 Triton kernel 没有意义

教学型 kernel 的目标是建立结构认知,而不是直接替代成熟库。只要这个边界没有混淆,这类示例仍然非常有价值。

结论

matmul 是理解 Triton 进阶模型的关键案例,因为它把注意力从一维块处理推进到了二维 tile、局部累积和配置搜索。

这一篇最重要的结论不是“如何立刻写出最快的 matmul”,而是:

  • 一个 program 可以负责输出矩阵中的一个二维 tile;
  • K 维循环是 matmul kernel 的核心骨架;
  • autotune 的本质是对候选配置做实测选择;
  • 手写 Triton 与 PyTorch 编译路径是协作关系,不是简单替代关系。

下一篇会继续往前走,把注意力从通用 matmul 转向推理路径里更常见的热点算子,讨论什么样的算子值得单独用 Triton 接管,以及判断依据是什么。

系列导航