CUDA系统拆解-15-Tensor Core、WMMA 与 MMA:矩阵乘指令路径怎么打通
CUDA系统拆解-15-Tensor Core、WMMA 与 MMA:矩阵乘指令路径怎么打通
本文是「CUDA系统拆解」系列第 15 篇。
系列导读:CUDA系统拆解-00-导读:从编程模型到 AI 推理系统的学习路线
上一篇:CUDA系统拆解-14-PTX、SASS 与编译链:CUDA代码如何落到机器指令
下一篇:CUDA系统拆解-16-CUTLASS、Triton、cuBLAS 与 FlashAttention:高性能实现都在做什么
1. 这篇解决什么问题
这一篇要讲清 5 件事:
- Tensor Core 到底是什么,它和普通 CUDA Core 的区别是什么。
Tensor Core、MMA、WMMA三者分别处在哪一层。- 为什么 Tensor Core 对 AI 推理特别重要,但不是所有 kernel 都能自动吃满它。
fragment和 warp 级矩阵计算到底该怎么理解。- 一个最小
WMMA骨架在做什么,尤其是mma_sync的含义。
如果这篇只记住一句话,那就是:
Tensor Core 提供的是矩阵乘加的硬件峰值能力;真正性能能不能出来,还要看数据类型、tile 组织、内存 feeding 和整体数据流是否都围着它设计。
2. 先记住的核心结论
Tensor Core是专门做矩阵乘加的硬件执行单元,不是“更强的 CUDA Core”。MMA是更低层的矩阵乘加指令语义,表达的是D = A * B + C这类 warp 级 tile 计算。WMMA是 CUDA C++ 提供的较高层 warp 级编程接口,用来更方便地走 Tensor Core 路径。- Tensor Core 的基本工作粒度不是单线程标量运算,而是 warp 协作下的小块矩阵计算。
fragment不是普通二维数组,而是一个逻辑 tile 在 warp 多个线程寄存器中的分布式表示。- 只有算子模式、精度、shape、layout、对齐和 feeding 都比较合适时,Tensor Core 才能真正带来明显加速。
3. 正文讲解
3.1 Tensor Core 到底是什么
Tensor Core 本质上是 GPU 上专门为矩阵乘加设计的硬件单元。它最擅长的不是通用标量计算,而是高度规则、重复很多、可批量执行的小块矩阵乘加。
可以把它抽象成:
[
D = A \times B + C
]
这里重点不是公式本身,而是它的计算粒度:
- 不是一个线程做一次
a * b + c - 而是一组线程协作完成一个 tile 的矩阵乘加
- 输入和输出通常都是分块组织的
- 底层硬件和指令路径都围绕这种模式做了专门优化
所以 Tensor Core 快,不是因为“单个线程更强”,而是因为它把深度学习里最常见、最规则的一类重计算单独拉出来,用专门硬件做高吞吐处理。
3.2 为什么 AI 推理特别需要 Tensor Core
AI 推理里最重的部分,很多都能映射成 GEMM 或近似 GEMM:
- 线性层
- attention 里的
QK^T - attention 里的
PV - 很多卷积实现
- MoE 中的专家 MLP
这类计算有两个特点:
- 乘加密度高
- 结构比较规则
这正好适合 Tensor Core。再加上推理很依赖低精度计算,比如 FP16、BF16、TF32、INT8、FP8,而 Tensor Core 又正是这些低精度高吞吐路径的核心硬件支点,所以它和 AI 推理天然绑得很紧。
但要注意一个常见误区:推理系统里不是所有阶段都同样受益。比如很多在线 LLM decode 阶段常常更偏 memory-bound,这时候即使 Tensor Core 峰值很高,实际收益也可能被访存和调度开销吃掉。
3.3 Tensor Core、MMA、WMMA 三者的关系
这三者要分层理解:
Tensor Core:硬件层MMA:更低层的矩阵乘加指令语义WMMA:CUDA C++ 暴露出来的较高层编程接口
一句话记忆:
Tensor Core 是硬件,MMA 是低层运算语义,WMMA 是程序员更容易使用的接口。
WMMA 的价值是把 warp 级矩阵乘加这件事抽象成了几个高层操作:
- 定义 fragment
- 加载 tile
- 调用
mma_sync - 存回结果
但 WMMA 不是极致性能的终点。很多高性能库和定制 kernel 会进一步下探到更底层的 mma 指令路径,因为那里对 tile、layout、流水和寄存器分布的控制更细。
3.4 为什么是 warp 级计算,而不是 thread 级计算
普通 CUDA 编程里,你经常把 thread 当成主要思考单位;到了 Tensor Core 路径,更合适的思考单位通常是 warp。
原因很直接:一个线程自己并不能高效完成一个矩阵 tile 的乘加。更典型的方式是:
- 一个 warp 负责一个或多个输出 tile
- warp 内线程分工装载输入块
- 矩阵乘加在 Tensor Core 路径上执行
- warp 内线程共同维护累加结果
这就是 Warp Matrix Multiply and Accumulate 这个名字的含义。它已经在名字里告诉你:这不是 thread-level API,而是 warp-level API。
3.5 fragment 到底是什么
fragment 最容易被误解成“一个普通小矩阵”。这种理解不完全错,但不够准确。
更准确地说:
fragment 是一个逻辑矩阵 tile 在 warp 多个线程寄存器中的分布式表示。
这意味着:
- 你逻辑上在操作一个 tile
- 物理上它并不是整块连续地放在某个地方
- 数据按照硬件定义的映射方式分散在 warp 内各线程寄存器里
所以 fragment 不是你能随意按二维数组访问的数据结构。你通常需要通过专门接口来操作它:
load_matrix_syncmma_syncstore_matrix_sync
这个抽象的目的,是把“适合 Tensor Core 的数据组织方式”隐藏在 API 和编译器后面,让程序员不用手工管理最底层的寄存器映射细节。
3.6 一个最小 WMMA 骨架
下面这段代码不是完整工程代码,而是保留理解主线的最小骨架:
1 |
|
看这段代码时,不要先陷进每个模板参数,而要先抓住主流程:
- 定义 A、B、C 对应的 fragment
- 初始化 accumulator
- 按
K维循环加载输入 tile - 调用
wmma::mma_sync做乘加 - 把结果 tile 写回全局内存
这里最关键的是 wmma::mma_sync。它表达的不是普通标量乘加,而是:
warp 协作地对一个 tile 执行 C = A * B + C。
再注意一个常见设计:
- 输入常是
half/bf16 - 累加常是
float
这就是混合精度的核心直觉:输入用低精度提高吞吐,累加用更高精度维持数值稳定性。
3.7 为什么不是所有算子都能自动吃满 Tensor Core
这是工程里最重要的现实问题之一。Tensor Core 很强,但它不是“打开就自动满速”的开关。
通常要同时满足几类条件:
- 算子模式适合映射到矩阵乘加
- 数据类型落在支持范围内
- tile shape 和维度比较友好
- layout、对齐和加载路径比较合适
- 周围的数据搬运、变形、mask、归约、epilogue 开销不能太大
所以以下情况常常会削弱收益:
- kernel 本身更偏
memory-bound - shape 太小或太碎
- 动态形状过多
- layout 变换成本高
- 访存 feeding 跟不上计算峰值
这也是为什么真正高性能的 GEMM、attention、FlashAttention、TensorRT-LLM kernel,不是“加一条 Tensor Core 指令”就结束,而是会把多级 tiling、shared memory staging、寄存器累加、双缓冲和流水线一起设计好。
4. 和 AI 推理的关系
理解 Tensor Core 之后,你会更容易看懂很多推理优化为什么长那样:
- 为什么推理框架如此重视
FP16、BF16、INT8、FP8 - 为什么高性能 GEMM 会反复强调 tile、layout、pipeline
- 为什么 FlashAttention 不只是一个“调用 Tensor Core 的 kernel”,而是同时在解决访存和 Tensor Core feeding
- 为什么量化之后还要继续做 fusion、dequant 优化和数据流重排
从系统角度看,Tensor Core 给的是高计算峰值;推理工程真正要做的,是把这个峰值尽可能变成稳定的实际吞吐。
5. 常见误区
Tensor Core不是“更强的 CUDA Core”,而是更专用的矩阵乘加硬件。- 不是只要用了
FP16,程序就一定会自动走到 Tensor Core 的最佳路径。 WMMA不是 Tensor Core 的全部,它只是一个较高层的编程入口。- 有了 Tensor Core,不代表访存、layout、shared memory、流水线这些问题就不重要了;通常反而更重要。
- 不是所有 GEMM 或 attention kernel 都能同样吃满 Tensor Core,shape、精度、对齐和 feeding 条件都很关键。
6. 复习自测
- Tensor Core、MMA、WMMA 分别处在哪一层,各自解决什么问题?
- 为什么 Tensor Core 的自然思考单位是 warp,而不是单线程?
fragment为什么不能当普通二维数组来理解?wmma::mma_sync在逻辑上做了什么?为什么常配合低精度输入和高精度累加?- 为什么 Tensor Core 峰值很高,但某些推理 kernel 实际收益却不明显?
- 如果一个 kernel 已经偏
memory-bound,你应该优先检查 Tensor Core 使用率,还是优先检查 feeding 和访存路径?为什么?

