PyTorch推理工程:06 torch.compile、导出与 ONNX
PyTorch 推理工程(06):torch.compile、导出与 ONNX
1. 本节定位
Eager 执行便于调试,但部署与融合优化常依赖图表示:torch.compile 在 PyTorch 运行时内做图捕获与优化;torch.export 与 ONNX 等路径则面向跨运行时交付。本篇说明各路径的适用边界、动态 shape 在导出中的常见失效形式,以及 eager 可通过而 compile/export 失败的一类原因(如控制流与 Python 侧副作用)。
2. 全局视图
在展开讲之前,先把四种东西的关系说清楚,之后细节才能挂上去。
1 | ┌─────────────────────────────────────────────────────────────────┐ |
记住这个层次,后面的内容都是在展开它。
3. 为什么 eager 模式不够用:从"为什么需要图"说起
Eager 模式的优势
常见的逐行执行即 Eager 模式:
1 | y = model(x) # Python 一行一行执行,GPU 指令一条一条发出 |
这很适合:
- 调试(加一行 print 就能看中间结果)
- 研究(随时改网络结构)
- 动态控制流(if/for/while 完全自由)
Eager 模式的代价
但在推理部署时,eager 的灵活性变成了负担:
1 | Eager 执行的每一步: |
更关键的是:当每个 op 单独执行时,很多跨 op 的优化根本做不了。
举个例子:
1 | x = x * 2 # op 1:乘法 |
如果这三步分开执行,每次都要把数据从 GPU 显存读出来、处理完再写回去(3 次读写)。
若将上述三个 op 融合(fuse) 为单一 kernel,可将对全局内存的读写次数降为约一次;此即 kernel fusion,带宽受限场景下收益显著。
但 eager 模式下,PyTorch 不知道下一行是什么,没法做这种"看几步之后"的优化。
动机:在图表示上先做全局观察,再统一做融合、常量折叠等变换。
4. torch.compile():在 PyTorch 内部做图优化
一句话理解
torch.compile()在保留近似 Eager 调用方式的前提下,对可捕获子图做编译与后端优化。
最小例子
1 | import torch |
注意:
torch.compile(model)返回一个新的可调用对象,用法和原来的model一样eval()和inference_mode()仍然需要写- 第一次调用会有编译开销(可能要几秒),后续调用才是优化后的速度
调用之后 PyTorch 做了什么
1 | `model(x)` 调用 |
TorchDynamo 负责将 Python 前向捕获为图;TorchInductor 负责由图生成融合后的 kernel(名称便于对照官方文档与日志)。
5. Graph Break:为什么 compile 有时没用
torch.compile() 在运行期分析前向代码,并将可图化片段合并为更大子图。
但它不是万能的。有些 Python 写法没办法被安全地图化,遇到这种地方就会"断图"(graph break):
1 | 示例 `forward()`: |
断图意味着:
- 被分成多个小图分别编译优化
- 断点处退回到 eager 执行
- 整体优化收益可能下降
常见触发 graph break 的写法
1 | def forward(self, x): |
如何诊断 graph break:设置环境变量
TORCH_LOGS=graph_breaks运行,PyTorch 会打印出断图的位置和原因。
6. Guard Failure:为什么 compile 后换个 shape 又慢了
torch.compile() 编译后会做"缓存"——下次跑同样的代码,如果输入条件满足,就直接用缓存的编译结果。
但这份编译结果是有前提条件的(叫做 guard),比如:
- “batch size 是 32”
- “dtype 是 float32”
- “x 在 cuda:0 上”
如果下次调用时某个条件不满足,就会发生 guard failure:缓存失效,重新编译。
1 | compiled_model = torch.compile(model) |
实际推理中的影响
若各请求 batch 变化剧烈且未妥善声明动态维,torch.compile() 可能频繁重编译,稳态性能或劣于 Eager。
解决方案:声明某些维度是动态的:
1 | compiled_model = torch.compile(model, dynamic=True) |
7. torch.compile() 的实际收益与边界
什么时候收益明显
- 大模型:参数多,单次前向时间长,优化空间大
- 重复执行同一 shape:编译缓存能充分利用
- 计算密集:大量 matmul/conv,kernel fusion 效果好
- backend 对这个 GPU 支持好
什么时候收益不明显
- 首次调用:有编译开销,通常慢几秒
- 频繁变 shape:反复触发 guard failure 和重编译
- 大量 graph break:编译出的 graph 太碎,收益有限
- 小模型、小 batch:Python 开销本来就不是瓶颈
工程原则:
torch.compile()不是"无代价的加速开关",先 benchmark 再决定用不用,首次调用的慢不代表最终性能。
8. torch.export:更严格的图捕获
如果说 torch.compile() 是"在 PyTorch 里先试着跑快一点",
那么 torch.export 是在做一件更严格的事:
把模型捕获为一个"只有 Tensor 计算,没有 Python 运行时依赖"的规范图。
为什么要这么严格
torch.export 的首要目标是产出可供后续工具链消费的图表示,而非直接优化单次 Eager 延迟。
如果这个图里还夹杂着 Python 变量、动态分支、随机对象引用……后续的优化工具、部署工具根本没法处理。
一个最小例子
1 | import torch |
理解 ExportedProgram
导出的结果不是模型对象,而是一个 ExportedProgram,可以理解为:
1 | ExportedProgram 包含: |
“参数被提升为图的输入"是什么意思?在 eager 里,权重是模型对象的属性,forward 里直接用 self.weight。但在导出图里,权重被当作"这张图需要的输入之一”,和输入 x 一起传入,图里面只有纯粹的张量计算。这样图就完全自包含了。
9. 动态 shape:为什么它是导出的最大难点
默认情况:shape 被固定死
torch.export 默认会把 example inputs 的 shape 当成"固定的"。
1 | example_inputs = (torch.randn(2, 8),) # batch=2, dim=8 |
这意味着导出后的图只对 shape=(2, 8) 的输入有效。换一个 batch size,这个图可能直接失效。
真实推理里哪些维度天然是动态的
1 | 文本模型:(batch_size, seq_len, hidden_dim) |
显式声明动态 shape
1 | from torch.export import Dim |
若导出模型仅接受固定序列长度或固定 batch,常见原因之一是未在导出接口中声明动态维;现象上表现为 shape 被固化,需在导出参数或后续图中显式放开对应维。
10. torch.onnx.export:把模型带到更广泛的生态
ONNX 是什么
ONNX(Open Neural Network Exchange)是一个开放的模型格式标准。很多推理框架都支持 ONNX:
- ONNX Runtime(微软,跨平台,CPU/GPU/NPU)
- TensorRT(NVIDIA,GPU 推理,常通过 ONNX 导入)
- OpenVINO(Intel)
- 各种边缘设备 SDK
把模型导出成 ONNX,就可以脱离 PyTorch 运行环境,用这些专用推理引擎执行。
当前推荐的导出方式
1 | import torch |
dynamo=True 的意思是:走基于 torch.export 的新导出路径(而不是旧的 TorchScript 路径)。这是现在的官方推荐方式。
验证导出是否成功
1 | import onnxruntime as ort |
11. 为什么 ONNX 导出也有动态 shape 问题
和 torch.export 一样,ONNX 导出默认会把 example inputs 的 shape 固定死。
解决方法是在导出时声明 dynamic_shapes(新 API):
1 | from torch.export import Dim |
这样导出的 ONNX 模型,就能接受不同 batch size 的输入了。
12. 四种方式对比一览
| 特性 | Eager | torch.compile() |
torch.export |
ONNX 导出 |
|---|---|---|---|---|
| 主要目标 | 灵活执行 | PyTorch 内提速 | 规范化图捕获 | 跨运行时部署 |
| 需要改代码 | 否 | 几乎不需要 | 可能需要 | 可能需要 |
| 对动态控制流 | 完全支持 | 部分支持 | 有限支持 | 有限支持 |
| 动态 shape | 完全支持 | 通过 dynamic=True |
需显式声明 | 需显式声明 |
| 首次开销 | 无 | 有(编译时间) | 有(导出时间) | 有(导出时间) |
| 跨框架运行 | 否 | 否 | 否 | 是 |
| 适合场景 | 开发调试 | 在线推理加速 | 部署前准备 | 跨平台部署 |
13. 为什么"eager 能跑"不等于"compile/export 一定成功"
这是非常重要的工程现实。
Eager 的优势正是它的灵活性:任何 Python 写法都能运行,控制流完全自由,“边跑边看”。
但 compile/export 要求"提前理解执行逻辑",所以以下写法都可能造成问题:
1 | # 反例: 依赖 Tensor 数值的 Python 控制流 |
结论:
- Eager 适合写法自由
- Compile/Export 适合"更规范、图化友好"的写法
- 两者有时需要在代码上做一些调整才能兼容
14. example_inputs 的重要性
torch.compile、torch.export 与 ONNX 导出均依赖 example_inputs(或等价占位)以推断张量形状、dtype 与静态结构;随意示例可能导致与线上分布不一致的图或导出失败。
它决定了:
- trace 的执行路径(哪些 if 分支被走到)
- shape 假设(哪些维度被认为是固定的)
- dtype 假设(输入是 float32 还是 float16)
- 哪些 guard 条件会被记录
工程原则:
example_inputs宜接近线上典型 shape/dtype;若线上为动态 batch,须在导出 API 中声明动态维,否则易在导出或下游运行时失败。
15. 常见坑速查
| 坑 | 现象 | 解决方向 |
|---|---|---|
把 compile 当"免费加速" |
首次慢,频繁重编译 | 先 benchmark,加 dynamic=True |
| graph break 太多 | 编译收益小 | 避免数据依赖控制流,用 TORCH_LOGS=graph_breaks 诊断 |
| 没有声明 dynamic shape | 导出后只能接受固定 shape | 用 Dim 显式声明哪些维度是动态的 |
| example_inputs 不代表真实场景 | 导出图对真实输入失效 | 让 example_inputs 贴近真实分布 |
| eager 能跑就以为 export 没问题 | 导出时报控制流相关错误 | 检查 forward 里的 Python 控制流写法 |
| 用旧 TorchScript 路径 ONNX 导出 | 缺少动态 shape 等新特性 | 改用 dynamo=True |
16. 四种技术的工程选择逻辑
想在 PyTorch 内部快一点
→ 用 torch.compile(model),比较 compile 前后的 benchmark 结果
想得到规范化图,准备后续分析或部署工具链
→ 用 torch.export.export(model, example_inputs)
想把模型给 ONNX Runtime 或 TensorRT 等工具
→ 用 torch.onnx.export(..., dynamo=True)
想最快速度评估模型效果,或者做动态结构实验
→ 直接用 eager,compile/export 留到真正需要部署时再考虑
17. 面试常见问题
Q:torch.compile() 是做什么的?
A:它用 TorchDynamo 分析 eager 代码,把其中能图化的部分提取成计算图,再交给编译 backend(如 TorchInductor)做 kernel fusion 等优化,缓存编译结果供后续复用。目标是在保留 eager 使用方式的前提下提升执行效率。
Q:什么是 graph break?
A:当 compile 遇到无法图化的 Python 写法(如依赖 Tensor 值的 if 分支、调用 .item() 等),就会在那个位置断图,前后分别编译成独立的小图,断点处回退 eager 执行。graph break 太多会导致 compile 收益降低。
Q:guard failure 是什么?
A:compile 编译后会缓存结果,并记录该结果在什么条件下有效(guard)。如果后续调用时条件不满足(比如 shape 变了),缓存失效,需要重新编译。
Q:torch.export 和 torch.compile() 的区别是什么?
A:compile 更偏运行时优化,在 PyTorch 内部提速;export 更偏生成一个严格的、只包含 Tensor 计算的规范图表示,去掉 Python 运行时依赖,适合后续部署工具链使用。
Q:为什么动态 shape 在 export 里是个大问题?
A:export 默认把 example inputs 的 shape 固定死,生成的图只对这一个 shape 有效。真实推理里 batch size 和序列长度通常是动态的,不声明动态 shape 就会导致导出的图对真实输入失效。需要用 Dim 显式声明哪些维度是动态的。
Q:为什么 ONNX 导出现在推荐 dynamo=True?
A:因为新的导出路径(dynamo=True)底层走 torch.export 的逻辑,支持更好的动态 shape 表达、更规范的图表示,是官方目前的推荐方式。旧的 TorchScript 路径功能更受限。
18. 思考题
练习 1:观察 torch.compile() 的首次编译开销
1 | import torch |
思考:第 1 次和第 2 次之后的时间差多少?为什么?
练习 2:触发 graph break
写一个包含数据依赖分支的 forward:
1 | class BranchModel(nn.Module): |
然后运行:
1 | TORCH_LOGS=graph_breaks python your_script.py |
思考:日志里提示 graph break 的原因是什么?怎样改写才能避免?
练习 3:最小 torch.export
1 | import torch |
思考:graph 里 op 的名字是什么风格(
aten.mm.default这类)?和 eager 时调用nn.Linear有什么区别?
练习 4:最小 ONNX 导出和验证
1 | import torch |
思考:如果现在用 batch=4 运行 onnx_runtime,会报错吗?怎么修改导出以支持动态 batch?
练习 5:声明动态 shape
在练习 4 的基础上,加入动态 batch 声明:
1 | from torch.export import Dim |
思考:用不同 batch size(如 1、4、64)分别运行 ONNX Runtime,结果正确吗?
19. 本节要点与自检
- 理解 eager / compile / export / ONNX 四者的角色和层次关系
- 理解 compile 的加速原理(图化 + kernel fusion),以及 graph break 和 guard failure 意味着什么
- 理解 export 为什么比 compile 更严格,
ExportedProgram包含什么 - 知道动态 shape 默认不被支持,如何用
Dim声明 - 会写基本的
torch.compile、torch.export、torch.onnx.export代码 - 明白为什么 eager 能跑不等于 compile/export 一定成功
- 知道
example_inputs的重要性,不能随意给
20. 小结
torch.compile()是在 PyTorch 里让相同代码跑更快,torch.export是把模型变成严格的规范图,ONNX 是把这个图带到更广泛的部署世界——三者是递进的工程链路,不是互相替代的关系。

