PyTorch推理工程:02 nn.Module、forward 与 state_dict
PyTorch 推理工程(02):nn.Module、forward 与 state_dict
1. 本节定位
前一篇已说明 Tensor 的 dtype、device、shape 与布局。工程代码中通常不直接堆叠裸算子,而是通过 nn.Module 封装参数、buffer 与前向过程,例如:
1 | model = MyModel(...) |
阅读与修改推理代码时,一般需要:根据模块结构定位 forward 路径;区分参数与 buffer;正确加载 state_dict;理解 train() / eval() 对子模块行为的影响(与 autograd 开关配合见下一篇)。
2. 什么是 nn.Module
定义
nn.Module 负责:注册参数与 buffer、实现 forward、提供 to(device)、state_dict / load_state_dict 等设备与序列化接口。可将其理解为带命名层次与生命周期管理的可调用计算单元。
它能保存:
- 参数(Parameter):需要学习的权重(比如矩阵 W、偏置 b)
- 缓冲区(Buffer):影响计算但不需要学习的状态(比如统计量)
- 子模块(Submodule):其他
nn.Module,可以一层层嵌套
3. 一个最小的 Module
先看最简单的例子,感受结构:
1 | import torch |
三个核心规则,记住就行:
| 规则 | 说明 |
|---|---|
继承 nn.Module |
自定义模块必须这样做 |
__init__() 里定义状态 |
子模块、参数、buffer 都在这里注册 |
forward() 里定义计算 |
输入如何一步步变成输出 |
4. 为什么写 model(x),而不是 model.forward(x)
这是初学者非常常见的困惑,这里解释清楚。
直接调用 forward() 有什么问题
调用 model(x) 时实际进入 Module.__call__(x):其在调用用户实现的 forward() 前后插入钩子与通用逻辑。
- 调用注册的 forward hooks(比如调试工具、profiler 插件都依赖这个)
- 处理 forward pre-hooks
- 处理某些 梯度追踪相关的状态
若直接调用 model.forward(x),上述 __call__ 路径与已注册 hook 不会执行;依赖 hook 的工具(如 torch.profiler)可能无法按预期记录。
结论:
永远写
model(x),不要把model.forward(x)当成常规写法。
5. 带参数的模块:参数(Parameter)是什么
先看一个有参数的模块
1 | import torch |
nn.Parameter 和普通 Tensor 有什么区别
关键区别在于注册:
1 | # 反例: 这是普通 Tensor,模块不认识它 |
使用 nn.Parameter 包装后,张量被注册为模块参数,因而:
- 它会被
model.parameters()遍历到 - 它会出现在
state_dict()里 - 调用
model.to(device)时会跟着移动 - 训练时优化器会更新它
怎么查看模块有哪些参数
1 | m = MyLinear(4, 3) |
6. 子模块:模块里套模块
真实模型几乎不会只有一层。PyTorch 的强大之处在于:模块可以包含子模块,子模块还可以继续嵌套。
1 | import torch |
遍历子模块
PyTorch 提供了几个方法来遍历模块树:
1 | # 只看直接子模块 |
子模块的参数会自动被追踪
1 | for name, param in model.named_parameters(): |
键名格式为 子模块名.参数名,即与模块层次一致的层级式命名,state_dict 中沿用同一规则。
model.to(device)会沿子模块树递归迁移参数与 buffer,故无需逐张量手写迁移。
7. forward() 到底在干什么
forward() 定义从输入到输出的计算步骤。
forward() 可以写任意逻辑
它不要求是简单的串行连接,可以:
1 | class FlexibleModel(nn.Module): |
这里:
- 接收了第二个参数
use_residual(forward 可以有多个入参) - 有条件分支(
if/else) - 有多个子模块参与计算
推理代码里阅读 forward() 的方式
阅读陌生模型时可沿 forward() 关注:
- 输入是什么 →
def forward(self, x, attention_mask=None, ...): - 经过哪些层 →
x = self.norm(x)/x = self.attn(x, mask)/ … - shape 在哪里变 → 每次
reshape/transpose/view的地方 - cache / mask / position 在哪里参与 → 推理优化的关键点
- 输出是什么 →
return logits/return hidden_states, past_kv/ …
8. nn.Sequential:方便但有局限
nn.Sequential 可以快速串起若干层:
1 | model = nn.Sequential( |
简洁,但有它的边界:
| 适合 Sequential | 不适合 Sequential |
|---|---|
| 单输入 / 单输出 | 多输入 / 多输出 |
| 纯串行,无分支 | 残差连接、门控分支 |
| 快速实验原型 | 推理代码里有 mask / cache |
| 简单 demo | attention / LLM / 复杂 CV |
Sequential仅适用于线性堆叠;多数模型需自定义forward()以表达分支与中间状态。
9. 参数(Parameter)和缓冲区(Buffer)的区别
这是推理岗位容易踩坑的地方。
类比理解
把模块想象成一个计算器:
-
参数(Parameter):计算器里的算法规则(每次训练都会更新)
例如:线性层的权重矩阵 W、偏置 b -
缓冲区(Buffer):计算器里的"记忆存储"(记录运行时的统计量,不直接被优化器更新)
例如:BatchNorm 层里记录的"这批数据的均值和方差的历史"
最典型的例子:BatchNorm 的 running mean
BatchNorm 层训练时会持续更新两个统计量:
running_mean:这不是通过反向传播更新的,而是每个 batch 过来后直接计算更新running_var:同上
这两个就是缓冲区(Buffer)。
1 | import torch |
输出:
1 | 参数: weight torch.Size([4]) ← gamma(可学习的缩放参数) |
如何在自定义模块里注册 buffer
1 | class MyModule(nn.Module): |
三种状态的对比
| 类型 | 能学习 | 在 state_dict 里 | 随 .to() 移动 |
典型例子 |
|---|---|---|---|---|
nn.Parameter |
是 | 是 | 是 | Linear.weight, Embedding.weight |
| Persistent Buffer | 否 | 是 | 是 | BN.running_mean, 预计算位置编码 |
| 普通成员 Tensor | 否 | 否 | 否 | 临时中间变量(不推荐存成属性) |
为什么推理时要懂 buffer?
因为推理常见的坑很多是这一类:
- 模型搬到 GPU 后,某个 buffer 还留在 CPU → 运算设备不一致报错
- 加载权重后 BatchNorm 行为异常 → running_mean 没有正确恢复
state_dict对不上 → buffer 名字或形状不匹配
10. state_dict:模型状态的标准导出格式
什么是 state_dict
state_dict 就是一个有序字典,把模块里所有参数和 buffer 的名字(key)和值(Tensor)对应起来。
1 | import torch |
多层模型的 state_dict 长什么样
1 | class MLP(nn.Module): |
注意 key 的格式:子模块名.参数名。如果有更深的嵌套,key 会更长,比如 encoder.layer.0.fc.weight。
state_dict的键名与模块层次一一对应;加载报错时宜先比对键名与形状。
11. 正确的保存与加载方式
保存
1 | torch.save(model.state_dict(), "model.pt") |
加载
1 | # 第一步:先构造一个同结构的模型 |
注意 map_location="cpu":如果保存时在 GPU 上,加载时可以先指定加载到 CPU,之后再 .to(device) 按需移动。
为什么不直接 torch.save(model) 保存整个模型
很多人这样写:
1 | torch.save(model, "model_full.pt") # 不推荐 |
这样要求加载时 Python 环境、模块类定义路径完全一致,换了机器或者重构了代码,很容易加载失败。
推荐方式:保存 state_dict,加载时在代码里重新定义模型结构,再填入权重。这样模型结构和权重解耦,工程上更稳。
12. load_state_dict() 报错时怎么排查
这是实战里极常见的场景。报错信息通常有两种:
情况 1:多余的 key(权重文件里有,但模型里没有)
1 | RuntimeError: Unexpected key(s) in state_dict: "old_head.weight", "old_head.bias" |
说明:权重文件是旧版本模型训练的,现在的代码里把那层改名了或删了。
情况 2:缺少的 key(模型需要,但权重文件没有)
1 | RuntimeError: Missing key(s) in state_dict: "classifier.weight", "classifier.bias" |
说明:模型代码里新增了某层,但权重文件是旧版本的,没有这一层的权重。
诊断步骤
1 | # 对比两边的 key |
常见原因
| 现象 | 可能原因 |
|---|---|
key 多了 module. 前缀 |
权重是 DataParallel 训练的,有额外封装层 |
| key 名字不一样 | 代码里子模块改名了(fc → head) |
| shape 不匹配 | 输出维度改了(比如分类数从 1000 → 10) |
| buffer 丢失 | 权重文件保存时用的是旧版本,buffer 名字变了 |
13. train() 与 eval():行为模式切换
这一节很多人会踩坑,搞清楚后面学混合精度和部署会顺很多。
先弄清楚一个误解
model.eval() 不是关闭梯度, 它只是告诉模型:“现在是推理时间,不是训练时间。”
而梯度的关闭,是靠 torch.no_grad() 或 torch.inference_mode() 完成的(见下一篇)。
这两件事独立,必须都做:
1 | model.eval() # 切换行为模式 |
eval() 具体影响哪些模块
并不是所有模块都受影响,只有那些训练和推理行为本来就不同的模块:
Dropout
1 | dropout = nn.Dropout(p=0.5) # 训练时随机丢弃 50% 的值 |
输出示例:
1 | tensor([[0., 0., 2., 2., 0., 2., 0., 0., 2., 0.]]) ← train(),有随机 0 |
BatchNorm
- 训练时:用当前 batch 的均值和方差做归一化,同时更新
running_mean/running_var - 推理时:用保存下来的
running_mean/running_var做归一化,不更新
若推理前未调用 model.eval(),BatchNorm 等层可能仍使用 batch 统计量,导致同输入输出不稳定或与训练分布不一致。
标准推理模板
1 | model.eval() # 1. 切行为模式 |
这三行缺一不可。
14. model.to():整体迁移设备和精度
Tensor 的 .to() 与 Module 的 .to() 语义一致,后者会递归遍历子模块树:
1 | device = "cuda" if torch.cuda.is_available() else "cpu" |
这一句话会把:
- 所有直接参数
- 所有 buffer
- 所有子模块的参数和 buffer
全部移动到指定设备。
也可以转精度:
1 | model = model.to(dtype=torch.float16) # 把所有参数转成 float16 |
实际工程中常用
.to()统一迁移参数与 buffer,而无需手写遍历各子模块参数。
15. 一个完整的推理代码骨架
下面用一个完整例子,把这一部分的所有核心概念串起来。
1 | import torch |
自检表
| 问题 | 要点 |
|---|---|
为什么 fc1、relu 会成为子模块? |
它们作为属性赋值在 __init__ 里,PyTorch 自动注册 |
为什么 model(x) 会走 forward()? |
调用模块会经过 __call__,它内部触发 forward() |
model.eval() 改变了什么? |
Dropout 不再丢弃,BatchNorm 用 running stats |
model.to(device) 移动了什么? |
所有参数和 buffer,递归处理整个模块树 |
为什么 key 是 fc1.weight 这种格式? |
子模块名 . 参数名,层级拼接 |
为什么要 model2 = SimpleClassifier(...) 再 load_state_dict? |
PyTorch 推荐:结构和权重解耦,保存和恢复更稳定 |
16. 阅读推理代码的标准顺序
拿到任何一个 PyTorch 推理项目,按这个顺序看:
第一步:找模型定义
1 | class XXX(nn.Module): |
第二步:看 __init__()
- 有哪些子模块?
- 有哪些参数?
- 有哪些 buffer?
- 有没有 KV cache、位置编码相关状态?
第三步:顺着 forward() 走一遍
- 输入是什么?多少个参数?
- 每一步 shape 怎么变?
- mask / position_ids / cache 在哪里出现?
- 输出是什么?
第四步:看权重加载
load_state_dict在哪里?- 有没有
strict=False?(这意味着部分权重不匹配,是刻意的还是 bug?) - 有没有 key 重命名操作?
第五步:看推理入口
model.eval()有没有做?torch.no_grad()或torch.inference_mode()有没有做?- 输入有没有
.to(device)?
17. 常见误区
误区 1:普通 Tensor 成员属性算参数
1 | self.scale = torch.ones(1) # 反例: 不是参数,model.state_dict() 里没有它 |
误区 2:model.eval() 等于关闭梯度
1 | model.eval() # 反例: 只切换行为模式,梯度还在 |
误区 3:直接 torch.save(model) 就行
工程上不推荐。加载时要求 Python 环境和代码路径完全一致,容易出问题。
误区 4:Sequential 足够用于推理
对于简单的 demo 够用,但真实推理代码(Attention、LLM、多流输入)几乎都需要自定义 forward()。
误区 5:buffer 不重要
实际推理项目里很多奇怪 bug(BatchNorm 行为异常、权重加载 key 不匹配)都和 buffer 有关。
18. 面试常见问题
下列对照可用于自检:
Q:nn.Module 是什么?
A:PyTorch 中组织模型的基础类,能统一管理参数、buffer、子模块和前向计算逻辑。
Q:为什么写 model(x) 而不是 model.forward(x)?
A:前者走 __call__,会触发 hooks 和框架级处理;后者直接调方法,会绕过这些机制。
Q:Parameter 和 Buffer 有什么区别?
A:Parameter 是可学习的状态(优化器会更新),Buffer 是影响计算但不需要学习的状态(比如 BN 的 running mean)。两者都出现在 state_dict 中,都随 .to() 移动设备。
Q:为什么保存 state_dict 而不是整个模型?
A:解耦结构和权重,可移植性更强,是官方推荐方式。
Q:model.eval() 具体影响了什么?
A:关闭了 Dropout 的随机丢弃;让 BatchNorm 使用 running stats 而不是当前 batch 的统计量。不负责关梯度。
19. 思考题
建议自己跑一遍,先猜输出,再验证。
练习 1:写一个最简单的 Module
定义一个 ScaleLayer:初始化时接受一个 float scale,forward() 返回 x * scale。
然后打印 named_parameters(),看有没有参数。
思考:改成用
nn.Parameter存 scale 后,named_parameters()输出有什么变化?
练习 2:子模块观察
1 | class Net(nn.Module): |
打印 model.named_parameters() 的所有 key,观察格式。
思考:将子模块
encoder重命名为layer1后,state_dict键名如何变化?
练习 3:观察 Dropout 在 train/eval 下的差异
1 | model = nn.Sequential(nn.Dropout(p=0.5)) |
思考:train 模式下每次输出一样吗?eval 模式下呢?为什么?
练习 4:保存 / 加载 / 对比 key
自己写一个两层 MLP,保存 state_dict 到文件,再新建同结构模型加载,打印两边的 key 确认一致。
进阶挑战:把第二个模型的第二层 Linear 改成
out_features=10(原来是 4),再load_state_dict,观察报错信息。
练习 5:buffer 注册与检查
写一个 ScaleNorm 模块:
- 注册一个 buffer
running_scale,初始值为torch.ones(1) forward()里返回x * self.running_scale- 打印
named_parameters()、named_buffers()、state_dict()
思考:
running_scale在named_parameters()里能看到吗?在named_buffers()里呢?
20. 本节要点与自检
- 能看懂任意 PyTorch 模型类的基本结构
- 理解
__init__()和forward()各自负责什么 - 区分 Parameter、Buffer、普通成员变量,知道它们的行为差异
- 能用
state_dict()正确保存和加载权重 - 知道
model(x)vsmodel.forward(x)的区别,以及为什么要用前者 - 知道
model.eval()影响哪些层,以及它和梯度关闭的关系 - 能沿着子模块结构读懂简单推理项目
21. 小结
nn.Module的本质,是把"状态"和"前向计算"统一组织起来;推理代码的阅读、改写、保存、加载,几乎都围绕它展开。
