CUDA系统拆解-15-Tensor Core、WMMA 与 MMA:矩阵乘指令路径怎么打通

本文是「CUDA系统拆解」系列第 15 篇。
系列导读:CUDA系统拆解-00-导读:从编程模型到 AI 推理系统的学习路线
上一篇:CUDA系统拆解-14-PTX、SASS 与编译链:CUDA代码如何落到机器指令
下一篇:CUDA系统拆解-16-CUTLASS、Triton、cuBLAS 与 FlashAttention:高性能实现都在做什么

1. 这篇解决什么问题

这一篇要讲清 5 件事:

  • Tensor Core 到底是什么,它和普通 CUDA Core 的区别是什么。
  • Tensor CoreMMAWMMA 三者分别处在哪一层。
  • 为什么 Tensor Core 对 AI 推理特别重要,但不是所有 kernel 都能自动吃满它。
  • fragment 和 warp 级矩阵计算到底该怎么理解。
  • 一个最小 WMMA 骨架在做什么,尤其是 mma_sync 的含义。

如果这篇只记住一句话,那就是:

Tensor Core 提供的是矩阵乘加的硬件峰值能力;真正性能能不能出来,还要看数据类型、tile 组织、内存 feeding 和整体数据流是否都围着它设计。

2. 先记住的核心结论

  • Tensor Core 是专门做矩阵乘加的硬件执行单元,不是“更强的 CUDA Core”。
  • MMA 是更低层的矩阵乘加指令语义,表达的是 D = A * B + C 这类 warp 级 tile 计算。
  • WMMA 是 CUDA C++ 提供的较高层 warp 级编程接口,用来更方便地走 Tensor Core 路径。
  • Tensor Core 的基本工作粒度不是单线程标量运算,而是 warp 协作下的小块矩阵计算。
  • fragment 不是普通二维数组,而是一个逻辑 tile 在 warp 多个线程寄存器中的分布式表示。
  • 只有算子模式、精度、shape、layout、对齐和 feeding 都比较合适时,Tensor Core 才能真正带来明显加速。

3. 正文讲解

3.1 Tensor Core 到底是什么

Tensor Core 本质上是 GPU 上专门为矩阵乘加设计的硬件单元。它最擅长的不是通用标量计算,而是高度规则、重复很多、可批量执行的小块矩阵乘加。

可以把它抽象成:

[
D = A \times B + C
]

这里重点不是公式本身,而是它的计算粒度:

  • 不是一个线程做一次 a * b + c
  • 而是一组线程协作完成一个 tile 的矩阵乘加
  • 输入和输出通常都是分块组织的
  • 底层硬件和指令路径都围绕这种模式做了专门优化

所以 Tensor Core 快,不是因为“单个线程更强”,而是因为它把深度学习里最常见、最规则的一类重计算单独拉出来,用专门硬件做高吞吐处理。

3.2 为什么 AI 推理特别需要 Tensor Core

AI 推理里最重的部分,很多都能映射成 GEMM 或近似 GEMM:

  • 线性层
  • attention 里的 QK^T
  • attention 里的 PV
  • 很多卷积实现
  • MoE 中的专家 MLP

这类计算有两个特点:

  • 乘加密度高
  • 结构比较规则

这正好适合 Tensor Core。再加上推理很依赖低精度计算,比如 FP16BF16TF32INT8FP8,而 Tensor Core 又正是这些低精度高吞吐路径的核心硬件支点,所以它和 AI 推理天然绑得很紧。

但要注意一个常见误区:推理系统里不是所有阶段都同样受益。比如很多在线 LLM decode 阶段常常更偏 memory-bound,这时候即使 Tensor Core 峰值很高,实际收益也可能被访存和调度开销吃掉。

3.3 Tensor Core、MMA、WMMA 三者的关系

这三者要分层理解:

  • Tensor Core:硬件层
  • MMA:更低层的矩阵乘加指令语义
  • WMMA:CUDA C++ 暴露出来的较高层编程接口

一句话记忆:

Tensor Core 是硬件,MMA 是低层运算语义,WMMA 是程序员更容易使用的接口。

WMMA 的价值是把 warp 级矩阵乘加这件事抽象成了几个高层操作:

  • 定义 fragment
  • 加载 tile
  • 调用 mma_sync
  • 存回结果

WMMA 不是极致性能的终点。很多高性能库和定制 kernel 会进一步下探到更底层的 mma 指令路径,因为那里对 tile、layout、流水和寄存器分布的控制更细。

3.4 为什么是 warp 级计算,而不是 thread 级计算

普通 CUDA 编程里,你经常把 thread 当成主要思考单位;到了 Tensor Core 路径,更合适的思考单位通常是 warp。

原因很直接:一个线程自己并不能高效完成一个矩阵 tile 的乘加。更典型的方式是:

  • 一个 warp 负责一个或多个输出 tile
  • warp 内线程分工装载输入块
  • 矩阵乘加在 Tensor Core 路径上执行
  • warp 内线程共同维护累加结果

这就是 Warp Matrix Multiply and Accumulate 这个名字的含义。它已经在名字里告诉你:这不是 thread-level API,而是 warp-level API。

3.5 fragment 到底是什么

fragment 最容易被误解成“一个普通小矩阵”。这种理解不完全错,但不够准确。

更准确地说:

fragment 是一个逻辑矩阵 tile 在 warp 多个线程寄存器中的分布式表示。

这意味着:

  • 你逻辑上在操作一个 tile
  • 物理上它并不是整块连续地放在某个地方
  • 数据按照硬件定义的映射方式分散在 warp 内各线程寄存器里

所以 fragment 不是你能随意按二维数组访问的数据结构。你通常需要通过专门接口来操作它:

  • load_matrix_sync
  • mma_sync
  • store_matrix_sync

这个抽象的目的,是把“适合 Tensor Core 的数据组织方式”隐藏在 API 和编译器后面,让程序员不用手工管理最底层的寄存器映射细节。

3.6 一个最小 WMMA 骨架

下面这段代码不是完整工程代码,而是保留理解主线的最小骨架:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#include <mma.h>
using namespace nvcuda;

__global__ void wmma_gemm(const half* A, const half* B, float* C,
int M, int N, int K) {
int warpM = blockIdx.y;
int warpN = blockIdx.x;

if ((warpM + 1) * 16 > M || (warpN + 1) * 16 > N) return;

wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

wmma::fill_fragment(c_frag, 0.0f);

for (int k = 0; k < K; k += 16) {
const half* a_ptr = A + warpM * 16 * K + k;
const half* b_ptr = B + k * N + warpN * 16;

wmma::load_matrix_sync(a_frag, a_ptr, K);
wmma::load_matrix_sync(b_frag, b_ptr, N);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

float* c_ptr = C + warpM * 16 * N + warpN * 16;
wmma::store_matrix_sync(c_ptr, c_frag, N, wmma::mem_row_major);
}

看这段代码时,不要先陷进每个模板参数,而要先抓住主流程:

  1. 定义 A、B、C 对应的 fragment
  2. 初始化 accumulator
  3. K 维循环加载输入 tile
  4. 调用 wmma::mma_sync 做乘加
  5. 把结果 tile 写回全局内存

这里最关键的是 wmma::mma_sync。它表达的不是普通标量乘加,而是:

warp 协作地对一个 tile 执行 C = A * B + C

再注意一个常见设计:

  • 输入常是 half / bf16
  • 累加常是 float

这就是混合精度的核心直觉:输入用低精度提高吞吐,累加用更高精度维持数值稳定性。

3.7 为什么不是所有算子都能自动吃满 Tensor Core

这是工程里最重要的现实问题之一。Tensor Core 很强,但它不是“打开就自动满速”的开关。

通常要同时满足几类条件:

  • 算子模式适合映射到矩阵乘加
  • 数据类型落在支持范围内
  • tile shape 和维度比较友好
  • layout、对齐和加载路径比较合适
  • 周围的数据搬运、变形、mask、归约、epilogue 开销不能太大

所以以下情况常常会削弱收益:

  • kernel 本身更偏 memory-bound
  • shape 太小或太碎
  • 动态形状过多
  • layout 变换成本高
  • 访存 feeding 跟不上计算峰值

这也是为什么真正高性能的 GEMM、attention、FlashAttention、TensorRT-LLM kernel,不是“加一条 Tensor Core 指令”就结束,而是会把多级 tiling、shared memory staging、寄存器累加、双缓冲和流水线一起设计好。

4. 和 AI 推理的关系

理解 Tensor Core 之后,你会更容易看懂很多推理优化为什么长那样:

  • 为什么推理框架如此重视 FP16BF16INT8FP8
  • 为什么高性能 GEMM 会反复强调 tile、layout、pipeline
  • 为什么 FlashAttention 不只是一个“调用 Tensor Core 的 kernel”,而是同时在解决访存和 Tensor Core feeding
  • 为什么量化之后还要继续做 fusion、dequant 优化和数据流重排

从系统角度看,Tensor Core 给的是高计算峰值;推理工程真正要做的,是把这个峰值尽可能变成稳定的实际吞吐。

5. 常见误区

  • Tensor Core 不是“更强的 CUDA Core”,而是更专用的矩阵乘加硬件。
  • 不是只要用了 FP16,程序就一定会自动走到 Tensor Core 的最佳路径。
  • WMMA 不是 Tensor Core 的全部,它只是一个较高层的编程入口。
  • 有了 Tensor Core,不代表访存、layout、shared memory、流水线这些问题就不重要了;通常反而更重要。
  • 不是所有 GEMM 或 attention kernel 都能同样吃满 Tensor Core,shape、精度、对齐和 feeding 条件都很关键。

6. 复习自测

  • Tensor Core、MMA、WMMA 分别处在哪一层,各自解决什么问题?
  • 为什么 Tensor Core 的自然思考单位是 warp,而不是单线程?
  • fragment 为什么不能当普通二维数组来理解?
  • wmma::mma_sync 在逻辑上做了什么?为什么常配合低精度输入和高精度累加?
  • 为什么 Tensor Core 峰值很高,但某些推理 kernel 实际收益却不明显?
  • 如果一个 kernel 已经偏 memory-bound,你应该优先检查 Tensor Core 使用率,还是优先检查 feeding 和访存路径?为什么?

系列导航