PyTorch 推理工程(06):torch.compile、导出与 ONNX

1. 本节定位

Eager 执行便于调试,但部署与融合优化常依赖图表示:torch.compile 在 PyTorch 运行时内做图捕获与优化;torch.export 与 ONNX 等路径则面向跨运行时交付。本篇说明各路径的适用边界、动态 shape 在导出中的常见失效形式,以及 eager 可通过而 compile/export 失败的一类原因(如控制流与 Python 侧副作用)。


2. 全局视图

在展开讲之前,先把四种东西的关系说清楚,之后细节才能挂上去。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
┌─────────────────────────────────────────────────────────────────┐
│ 四种执行/导出方式 │
│ │
│ Eager Mode │
│ → Python 一行一行执行,灵活好调试,但有 Python 开销 │
│ ↓ 想在 PyTorch 内部加速 │
│ torch.compile() │
│ → 把 eager 代码尽量捕获为图,自动优化,仍然在 PyTorch 里跑 │
│ ↓ 想得到稳定的规范图表示 │
│ torch.export │
│ → 更严格地把模型捕获为只包含 Tensor 计算的规范图 │
│ ↓ 想跨运行时/工具链部署 │
│ torch.onnx.export │
│ → 把模型转成 ONNX 格式,交给 ONNX Runtime 等其他系统执行 │
└─────────────────────────────────────────────────────────────────┘

一句话总结:
compile = 在 PyTorch 里跑更快
export = 把模型变成更"规范"的图
ONNX = 把这个图交给更广泛的生态

记住这个层次,后面的内容都是在展开它。


3. 为什么 eager 模式不够用:从"为什么需要图"说起

Eager 模式的优势

常见的逐行执行即 Eager 模式:

1
y = model(x)   # Python 一行一行执行,GPU 指令一条一条发出

这很适合:

  • 调试(加一行 print 就能看中间结果)
  • 研究(随时改网络结构)
  • 动态控制流(if/for/while 完全自由)

Eager 模式的代价

但在推理部署时,eager 的灵活性变成了负担:

1
2
3
4
5
Eager 执行的每一步:
Python 解释器 → PyTorch 调度器 → CUDA 驱动 → GPU

每一个小 op 都要走这条链路,即使是极小的操作也逃不开。
Python 解释器和调度器的开销,在大量小 op 的场景下很显著。

更关键的是:当每个 op 单独执行时,很多跨 op 的优化根本做不了

举个例子:

1
2
3
x = x * 2       # op 1:乘法
x = x + 1 # op 2:加法
x = torch.relu(x) # op 3:激活

如果这三步分开执行,每次都要把数据从 GPU 显存读出来、处理完再写回去(3 次读写)。

若将上述三个 op 融合(fuse) 为单一 kernel,可将对全局内存的读写次数降为约一次;此即 kernel fusion,带宽受限场景下收益显著。

但 eager 模式下,PyTorch 不知道下一行是什么,没法做这种"看几步之后"的优化。

动机:在图表示上先做全局观察,再统一做融合、常量折叠等变换。


4. torch.compile():在 PyTorch 内部做图优化

一句话理解

torch.compile() 在保留近似 Eager 调用方式的前提下,对可捕获子图做编译与后端优化。

最小例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Sequential(
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, 64),
).to(device)

x = torch.randn(32, 128, device=device)

# 原来:直接用 model
# 现在:用编译后的版本
compiled_model = torch.compile(model)

compiled_model.eval()

with torch.inference_mode():
y = compiled_model(x)

print(y.shape) # torch.Size([32, 64])

注意

  • torch.compile(model) 返回一个新的可调用对象,用法和原来的 model 一样
  • eval()inference_mode() 仍然需要写
  • 第一次调用会有编译开销(可能要几秒),后续调用才是优化后的速度

调用之后 PyTorch 做了什么

1
2
3
4
5
6
7
`model(x)` 调用

TorchDynamo(字节码分析)
↓ 把 Python 代码转化为图
图优化器(TorchInductor 等 backend)
↓ 算子融合、内存布局优化
高效 CUDA kernel

TorchDynamo 负责将 Python 前向捕获为图;TorchInductor 负责由图生成融合后的 kernel(名称便于对照官方文档与日志)。


5. Graph Break:为什么 compile 有时没用

torch.compile() 在运行期分析前向代码,并将可图化片段合并为更大子图。

但它不是万能的。有些 Python 写法没办法被安全地图化,遇到这种地方就会"断图"(graph break):

1
2
3
4
示例 `forward()`:

[Linear] → [ReLU] → ←断了!→ [if branch] → [Linear]
━━━━━━━━━━━━━━━━━━━ graph 1 ↑无法图化 ━━━━━━━ graph 2

断图意味着:

  • 被分成多个小图分别编译优化
  • 断点处退回到 eager 执行
  • 整体优化收益可能下降

常见触发 graph break 的写法

1
2
3
4
5
6
7
8
9
10
11
12
def forward(self, x):
# 反例: 依赖 Tensor 值的 Python 分支
if x.sum() > 0: # x.sum() 的结果要等 GPU 算完才能知道,无法提前图化
return self.fc1(x)
else:
return self.fc2(x)

# 反例: 调用 .item() 把 GPU 值取回 CPU
n = x.shape[0].item() # 触发同步,无法图化后续逻辑

# 示例: 依赖 shape(静态属性),一般没问题
batch_size = x.shape[0] # shape 是编译时已知的,可以图化

如何诊断 graph break:设置环境变量 TORCH_LOGS=graph_breaks 运行,PyTorch 会打印出断图的位置和原因。


6. Guard Failure:为什么 compile 后换个 shape 又慢了

torch.compile() 编译后会做"缓存"——下次跑同样的代码,如果输入条件满足,就直接用缓存的编译结果。

但这份编译结果是有前提条件的(叫做 guard),比如:

  • “batch size 是 32”
  • “dtype 是 float32”
  • “x 在 cuda:0 上”

如果下次调用时某个条件不满足,就会发生 guard failure:缓存失效,重新编译。

1
2
3
4
5
6
7
8
9
10
11
12
13
compiled_model = torch.compile(model)

# 第 1 次调用,shape=(32, 128)
with torch.inference_mode():
y = compiled_model(torch.randn(32, 128, device="cuda")) # 编译(慢)

# 第 2 次,同样 shape
with torch.inference_mode():
y = compiled_model(torch.randn(32, 128, device="cuda")) # 用缓存(快)

# 第 3 次,换了 shape=(64, 128)
with torch.inference_mode():
y = compiled_model(torch.randn(64, 128, device="cuda")) # guard failure!重编译(慢)

实际推理中的影响

若各请求 batch 变化剧烈且未妥善声明动态维,torch.compile() 可能频繁重编译,稳态性能或劣于 Eager。

解决方案:声明某些维度是动态的:

1
2
compiled_model = torch.compile(model, dynamic=True)
# 告诉编译器:batch size 可以变,不要把它固定死

7. torch.compile() 的实际收益与边界

什么时候收益明显

  • 大模型:参数多,单次前向时间长,优化空间大
  • 重复执行同一 shape:编译缓存能充分利用
  • 计算密集:大量 matmul/conv,kernel fusion 效果好
  • backend 对这个 GPU 支持好

什么时候收益不明显

  • 首次调用:有编译开销,通常慢几秒
  • 频繁变 shape:反复触发 guard failure 和重编译
  • 大量 graph break:编译出的 graph 太碎,收益有限
  • 小模型、小 batch:Python 开销本来就不是瓶颈

工程原则torch.compile() 不是"无代价的加速开关",先 benchmark 再决定用不用,首次调用的慢不代表最终性能。


8. torch.export:更严格的图捕获

如果说 torch.compile() 是"在 PyTorch 里先试着跑快一点",

那么 torch.export 是在做一件更严格的事:

把模型捕获为一个"只有 Tensor 计算,没有 Python 运行时依赖"的规范图。

为什么要这么严格

torch.export 的首要目标是产出可供后续工具链消费的图表示,而非直接优化单次 Eager 延迟。

如果这个图里还夹杂着 Python 变量、动态分支、随机对象引用……后续的优化工具、部署工具根本没法处理。

一个最小例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn

class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(8, 16)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(16, 4)

def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))

model = MLP().eval()
example_inputs = (torch.randn(2, 8),) # 注意:是 tuple

# 导出
ep = torch.export.export(model, example_inputs)

print(type(ep))
# <class 'torch.export.exported_program.ExportedProgram'>

print(ep.graph)
# 输出为以规范化 ATen 算子表达的计算图

理解 ExportedProgram

导出的结果不是模型对象,而是一个 ExportedProgram,可以理解为:

1
2
3
4
5
6
7
8
9
10
11
ExportedProgram 包含:
┌─────────────────┐
│ 计算图 │ ← 所有 Tensor 计算,用规范 ATen op 表达
│ (graph) │ 没有 Python 控制流,没有随机依赖
├─────────────────┤
│ 参数/Buffer │ ← 模型的权重,和计算图关联
│ (state_dict) │ 被"提升"为图的输入
├─────────────────┤
│ Shape 约束 │ ← 什么 shape 的输入是合法的
│ (constraints) │
└─────────────────┘

“参数被提升为图的输入"是什么意思?在 eager 里,权重是模型对象的属性,forward 里直接用 self.weight。但在导出图里,权重被当作"这张图需要的输入之一”,和输入 x 一起传入,图里面只有纯粹的张量计算。这样图就完全自包含了。


9. 动态 shape:为什么它是导出的最大难点

默认情况:shape 被固定死

torch.export 默认会把 example inputs 的 shape 当成"固定的"。

1
2
3
4
5
example_inputs = (torch.randn(2, 8),)   # batch=2, dim=8
ep = torch.export.export(model, example_inputs)

# 这个导出的图是为 shape=(2, 8) 设计的
# 如果输入变成 (4, 8)...

这意味着导出后的图只对 shape=(2, 8) 的输入有效。换一个 batch size,这个图可能直接失效。

真实推理里哪些维度天然是动态的

1
2
3
4
5
6
7
文本模型:(batch_size, seq_len, hidden_dim)
↑ ↑
动态 动态(不同长度的句子)

图像模型:(batch_size, C, H, W)

动态(在线服务 batch 大小不固定)

显式声明动态 shape

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torch.export import Dim

# 声明 batch 维度可以在 1~64 之间变化
batch = Dim("batch", min=1, max=64)
seq = Dim("seq_len", min=1, max=512)

dynamic_shapes = {
"x": {0: batch, 1: seq} # 第 0 维是 batch,第 1 维是 seq_len
}

ep = torch.export.export(
model,
example_inputs,
dynamic_shapes=dynamic_shapes
)

若导出模型仅接受固定序列长度或固定 batch,常见原因之一是未在导出接口中声明动态维;现象上表现为 shape 被固化,需在导出参数或后续图中显式放开对应维。


10. torch.onnx.export:把模型带到更广泛的生态

ONNX 是什么

ONNX(Open Neural Network Exchange)是一个开放的模型格式标准。很多推理框架都支持 ONNX:

  • ONNX Runtime(微软,跨平台,CPU/GPU/NPU)
  • TensorRT(NVIDIA,GPU 推理,常通过 ONNX 导入)
  • OpenVINO(Intel)
  • 各种边缘设备 SDK

把模型导出成 ONNX,就可以脱离 PyTorch 运行环境,用这些专用推理引擎执行。

当前推荐的导出方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(8, 4)

def forward(self, x):
return self.fc(x)

model = SimpleModel().eval()
x = torch.randn(2, 8)

torch.onnx.export(
model,
(x,), # example inputs(tuple)
"model.onnx", # 输出文件名
input_names=["input"],
output_names=["output"],
dynamo=True, # 推荐:使用新的导出路径
)

dynamo=True 的意思是:走基于 torch.export 的新导出路径(而不是旧的 TorchScript 路径)。这是现在的官方推荐方式。

验证导出是否成功

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import onnxruntime as ort
import numpy as np

# 加载 onnx 模型
sess = ort.InferenceSession("model.onnx")

# 用 numpy 输入运行
x_np = np.random.randn(2, 8).astype(np.float32)
outputs = sess.run(None, {"input": x_np})
print(outputs[0].shape) # (2, 4)

# 和 PyTorch 输出对比
with torch.inference_mode():
pt_output = model(torch.tensor(x_np)).numpy()

max_err = abs(outputs[0] - pt_output).max()
print(f"最大误差: {max_err:.6f}") # 应该很小

11. 为什么 ONNX 导出也有动态 shape 问题

torch.export 一样,ONNX 导出默认会把 example inputs 的 shape 固定死。

解决方法是在导出时声明 dynamic_shapes(新 API):

1
2
3
4
5
6
7
8
9
10
11
12
13
from torch.export import Dim

batch = Dim("batch", min=1, max=64)

torch.onnx.export(
model,
(x,),
"model_dynamic.onnx",
input_names=["input"],
output_names=["output"],
dynamo=True,
dynamic_shapes={"x": {0: batch}}, # 声明 batch 维度是动态的
)

这样导出的 ONNX 模型,就能接受不同 batch size 的输入了。


12. 四种方式对比一览

特性 Eager torch.compile() torch.export ONNX 导出
主要目标 灵活执行 PyTorch 内提速 规范化图捕获 跨运行时部署
需要改代码 几乎不需要 可能需要 可能需要
对动态控制流 完全支持 部分支持 有限支持 有限支持
动态 shape 完全支持 通过 dynamic=True 需显式声明 需显式声明
首次开销 有(编译时间) 有(导出时间) 有(导出时间)
跨框架运行
适合场景 开发调试 在线推理加速 部署前准备 跨平台部署

13. 为什么"eager 能跑"不等于"compile/export 一定成功"

这是非常重要的工程现实。

Eager 的优势正是它的灵活性:任何 Python 写法都能运行,控制流完全自由,“边跑边看”。

但 compile/export 要求"提前理解执行逻辑",所以以下写法都可能造成问题:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 反例: 依赖 Tensor 数值的 Python 控制流
if x.sum() > 0: # 需要运行才知道,无法提前图化
...

# 反例: 调用 .item() 触发同步
n = len_tensor.item() # 把 GPU 值取回 Python,打断图捕获

# 反例: 对输入 shape 的隐式假设
x = x.view(x.shape[0], -1, 8) # 假设最后维能被 8 整除,export 时要检查

# 反例: 不规则的容器操作
outputs = {}
for i in range(n_layers):
outputs[f"layer_{i}"] = ... # 动态 key 的字典,难以图化

结论

  • Eager 适合写法自由
  • Compile/Export 适合"更规范、图化友好"的写法
  • 两者有时需要在代码上做一些调整才能兼容

14. example_inputs 的重要性

torch.compiletorch.export 与 ONNX 导出均依赖 example_inputs(或等价占位)以推断张量形状、dtype 与静态结构;随意示例可能导致与线上分布不一致的图或导出失败。

它决定了:

  • trace 的执行路径(哪些 if 分支被走到)
  • shape 假设(哪些维度被认为是固定的)
  • dtype 假设(输入是 float32 还是 float16)
  • 哪些 guard 条件会被记录

工程原则

example_inputs 宜接近线上典型 shape/dtype;若线上为动态 batch,须在导出 API 中声明动态维,否则易在导出或下游运行时失败。


15. 常见坑速查

现象 解决方向
compile 当"免费加速" 首次慢,频繁重编译 先 benchmark,加 dynamic=True
graph break 太多 编译收益小 避免数据依赖控制流,用 TORCH_LOGS=graph_breaks 诊断
没有声明 dynamic shape 导出后只能接受固定 shape Dim 显式声明哪些维度是动态的
example_inputs 不代表真实场景 导出图对真实输入失效 让 example_inputs 贴近真实分布
eager 能跑就以为 export 没问题 导出时报控制流相关错误 检查 forward 里的 Python 控制流写法
用旧 TorchScript 路径 ONNX 导出 缺少动态 shape 等新特性 改用 dynamo=True

16. 四种技术的工程选择逻辑

想在 PyTorch 内部快一点

→ 用 torch.compile(model),比较 compile 前后的 benchmark 结果

想得到规范化图,准备后续分析或部署工具链

→ 用 torch.export.export(model, example_inputs)

想把模型给 ONNX Runtime 或 TensorRT 等工具

→ 用 torch.onnx.export(..., dynamo=True)

想最快速度评估模型效果,或者做动态结构实验

→ 直接用 eager,compile/export 留到真正需要部署时再考虑


17. 面试常见问题

Q:torch.compile() 是做什么的?
A:它用 TorchDynamo 分析 eager 代码,把其中能图化的部分提取成计算图,再交给编译 backend(如 TorchInductor)做 kernel fusion 等优化,缓存编译结果供后续复用。目标是在保留 eager 使用方式的前提下提升执行效率。

Q:什么是 graph break?
A:当 compile 遇到无法图化的 Python 写法(如依赖 Tensor 值的 if 分支、调用 .item() 等),就会在那个位置断图,前后分别编译成独立的小图,断点处回退 eager 执行。graph break 太多会导致 compile 收益降低。

Q:guard failure 是什么?
A:compile 编译后会缓存结果,并记录该结果在什么条件下有效(guard)。如果后续调用时条件不满足(比如 shape 变了),缓存失效,需要重新编译。

Q:torch.exporttorch.compile() 的区别是什么?
A:compile 更偏运行时优化,在 PyTorch 内部提速;export 更偏生成一个严格的、只包含 Tensor 计算的规范图表示,去掉 Python 运行时依赖,适合后续部署工具链使用。

Q:为什么动态 shape 在 export 里是个大问题?
A:export 默认把 example inputs 的 shape 固定死,生成的图只对这一个 shape 有效。真实推理里 batch size 和序列长度通常是动态的,不声明动态 shape 就会导致导出的图对真实输入失效。需要用 Dim 显式声明哪些维度是动态的。

Q:为什么 ONNX 导出现在推荐 dynamo=True
A:因为新的导出路径(dynamo=True)底层走 torch.export 的逻辑,支持更好的动态 shape 表达、更规范的图表示,是官方目前的推荐方式。旧的 TorchScript 路径功能更受限。


18. 思考题

练习 1:观察 torch.compile() 的首次编译开销

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
import time

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).to(device)
x = torch.randn(64, 512, device=device)

compiled = torch.compile(model)
compiled.eval()

times = []
with torch.inference_mode():
for i in range(5):
if device == "cuda":
torch.cuda.synchronize()
t0 = time.time()
y = compiled(x)
if device == "cuda":
torch.cuda.synchronize()
t1 = time.time()
times.append((t1 - t0) * 1000)
print(f"第 {i+1} 次: {times[-1]:.1f} ms")

思考:第 1 次和第 2 次之后的时间差多少?为什么?

练习 2:触发 graph break

写一个包含数据依赖分支的 forward:

1
2
3
4
5
6
7
8
9
10
11
12
class BranchModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(8, 8)

def forward(self, x):
if x.sum() > 0: # ← 这里会触发 graph break
return self.fc(x)
return x

model = BranchModel()
compiled = torch.compile(model)

然后运行:

1
TORCH_LOGS=graph_breaks python your_script.py

思考:日志里提示 graph break 的原因是什么?怎样改写才能避免?

练习 3:最小 torch.export

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torch.nn as nn

class M(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(8, 16)
self.fc2 = nn.Linear(16, 4)

def forward(self, x):
return self.fc2(torch.relu(self.fc1(x)))

model = M().eval()
ep = torch.export.export(model, (torch.randn(2, 8),))

print(type(ep)) # ExportedProgram
print(ep.graph) # 规范化的计算图

思考:graph 里 op 的名字是什么风格(aten.mm.default 这类)?和 eager 时调用 nn.Linear 有什么区别?

练习 4:最小 ONNX 导出和验证

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn

model = nn.Linear(8, 4).eval()
x = torch.randn(2, 8)

torch.onnx.export(
model, (x,), "test_model.onnx",
input_names=["x"], output_names=["y"],
dynamo=True
)

# 验证
import onnxruntime as ort
import numpy as np

sess = ort.InferenceSession("test_model.onnx")
x_np = x.numpy()
onnx_out = sess.run(None, {"x": x_np})[0]

with torch.inference_mode():
pt_out = model(x).numpy()

print(f"最大误差: {abs(onnx_out - pt_out).max():.6f}")

思考:如果现在用 batch=4 运行 onnx_runtime,会报错吗?怎么修改导出以支持动态 batch?

练习 5:声明动态 shape

在练习 4 的基础上,加入动态 batch 声明:

1
2
3
4
5
6
7
8
9
10
from torch.export import Dim

batch_dim = Dim("batch", min=1, max=64)

torch.onnx.export(
model, (x,), "test_model_dynamic.onnx",
input_names=["x"], output_names=["y"],
dynamo=True,
dynamic_shapes={"x": {0: batch_dim}} # 第 0 维是动态的
)

思考:用不同 batch size(如 1、4、64)分别运行 ONNX Runtime,结果正确吗?


19. 本节要点与自检

  • 理解 eager / compile / export / ONNX 四者的角色和层次关系
  • 理解 compile 的加速原理(图化 + kernel fusion),以及 graph break 和 guard failure 意味着什么
  • 理解 export 为什么比 compile 更严格,ExportedProgram 包含什么
  • 知道动态 shape 默认不被支持,如何用 Dim 声明
  • 会写基本的 torch.compiletorch.exporttorch.onnx.export 代码
  • 明白为什么 eager 能跑不等于 compile/export 一定成功
  • 知道 example_inputs 的重要性,不能随意给

20. 小结

torch.compile() 是在 PyTorch 里让相同代码跑更快,torch.export 是把模型变成严格的规范图,ONNX 是把这个图带到更广泛的部署世界——三者是递进的工程链路,不是互相替代的关系。


系列导航