PyTorch 推理工程(02):nn.Module、forward 与 state_dict

1. 本节定位

前一篇已说明 Tensor 的 dtype、device、shape 与布局。工程代码中通常不直接堆叠裸算子,而是通过 nn.Module 封装参数、buffer 与前向过程,例如:

1
2
3
4
5
6
model = MyModel(...)
model.load_state_dict(torch.load("weights.pt"))
model.eval()

with torch.inference_mode():
y = model(x)

阅读与修改推理代码时,一般需要:根据模块结构定位 forward 路径;区分参数与 buffer;正确加载 state_dict;理解 train() / eval() 对子模块行为的影响(与 autograd 开关配合见下一篇)。


2. 什么是 nn.Module

定义

nn.Module 负责:注册参数与 buffer、实现 forward、提供 to(device)state_dict / load_state_dict 等设备与序列化接口。可将其理解为带命名层次与生命周期管理的可调用计算单元。

它能保存:

  • 参数(Parameter):需要学习的权重(比如矩阵 W、偏置 b)
  • 缓冲区(Buffer):影响计算但不需要学习的状态(比如统计量)
  • 子模块(Submodule):其他 nn.Module,可以一层层嵌套

3. 一个最小的 Module

先看最简单的例子,感受结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn

class AddOne(nn.Module):
def __init__(self):
super().__init__() # 必须调用父类初始化

def forward(self, x):
return x + 1 # 定义"输入怎么变成输出"

m = AddOne()
x = torch.tensor([1.0, 2.0, 3.0])
y = m(x) # 调用模块,会触发 forward()

print(y)
# 输出:tensor([2., 3., 4.])

三个核心规则,记住就行:

规则 说明
继承 nn.Module 自定义模块必须这样做
__init__() 里定义状态 子模块、参数、buffer 都在这里注册
forward() 里定义计算 输入如何一步步变成输出

4. 为什么写 model(x),而不是 model.forward(x)

这是初学者非常常见的困惑,这里解释清楚。

直接调用 forward() 有什么问题

调用 model(x) 时实际进入 Module.__call__(x):其在调用用户实现的 forward() 前后插入钩子与通用逻辑。

  • 调用注册的 forward hooks(比如调试工具、profiler 插件都依赖这个)
  • 处理 forward pre-hooks
  • 处理某些 梯度追踪相关的状态

若直接调用 model.forward(x),上述 __call__ 路径与已注册 hook 不会执行;依赖 hook 的工具(如 torch.profiler)可能无法按预期记录。

结论

永远写 model(x),不要把 model.forward(x) 当成常规写法。


5. 带参数的模块:参数(Parameter)是什么

先看一个有参数的模块

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn as nn

class MyLinear(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
# 用 nn.Parameter 包裹,这才是"模块参数"
self.weight = nn.Parameter(torch.randn(in_features, out_features))
self.bias = nn.Parameter(torch.randn(out_features))

def forward(self, x):
return x @ self.weight + self.bias

m = MyLinear(4, 3)
x = torch.randn(2, 4)
y = m(x)

print(y.shape) # torch.Size([2, 3])

nn.Parameter 和普通 Tensor 有什么区别

关键区别在于注册

1
2
3
4
5
# 反例: 这是普通 Tensor,模块不认识它
self.weight = torch.randn(4, 3)

# 示例: 这是模块参数,PyTorch 会自动追踪它
self.weight = nn.Parameter(torch.randn(4, 3))

使用 nn.Parameter 包装后,张量被注册为模块参数,因而:

  • 它会被 model.parameters() 遍历到
  • 它会出现在 state_dict()
  • 调用 model.to(device) 时会跟着移动
  • 训练时优化器会更新它

怎么查看模块有哪些参数

1
2
3
4
5
6
7
8
m = MyLinear(4, 3)

for name, param in m.named_parameters():
print(name, param.shape)

# 输出:
# weight torch.Size([4, 3])
# bias torch.Size([3])

6. 子模块:模块里套模块

真实模型几乎不会只有一层。PyTorch 的强大之处在于:模块可以包含子模块,子模块还可以继续嵌套

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn

class MLP(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.fc1 = nn.Linear(in_dim, hidden_dim) # 子模块
self.act = nn.ReLU() # 子模块
self.fc2 = nn.Linear(hidden_dim, out_dim) # 子模块

def forward(self, x):
x = self.fc1(x) # (B, in) → (B, hidden)
x = self.act(x) # (B, hidden),逐元素激活
x = self.fc2(x) # (B, hidden) → (B, out)
return x

model = MLP(8, 16, 4)
x = torch.randn(2, 8)
y = model(x)

print(y.shape) # torch.Size([2, 4])

遍历子模块

PyTorch 提供了几个方法来遍历模块树:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 只看直接子模块
for name, module in model.named_children():
print(name, "→", module)

# 输出:
# fc1 → Linear(in_features=8, out_features=16, bias=True)
# act → ReLU()
# fc2 → Linear(in_features=16, out_features=4, bias=True)

# 递归遍历所有模块(包括自身和所有子孙模块)
for name, module in model.named_modules():
print(name or "(root)", "→", type(module).__name__)

# 输出:
# (root) → MLP
# fc1 → Linear
# act → ReLU
# fc2 → Linear

子模块的参数会自动被追踪

1
2
3
4
5
6
7
8
for name, param in model.named_parameters():
print(name, param.shape)

# 输出:
# fc1.weight torch.Size([16, 8])
# fc1.bias torch.Size([16])
# fc2.weight torch.Size([4, 16])
# fc2.bias torch.Size([4])

键名格式为 子模块名.参数名,即与模块层次一致的层级式命名,state_dict 中沿用同一规则。

model.to(device) 会沿子模块树递归迁移参数与 buffer,故无需逐张量手写迁移。


7. forward() 到底在干什么

forward() 定义从输入到输出的计算步骤。

forward() 可以写任意逻辑

它不要求是简单的串行连接,可以:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class FlexibleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(8, 8)
self.gate = nn.Linear(8, 1)

def forward(self, x, use_residual=True):
h = self.fc(x)

if use_residual:
h = h + x # 残差连接(跳跃连接)

gate = torch.sigmoid(self.gate(x)) # 门控机制
return h * gate # 逐元素乘

这里:

  • 接收了第二个参数 use_residual(forward 可以有多个入参)
  • 有条件分支(if/else
  • 有多个子模块参与计算

推理代码里阅读 forward() 的方式

阅读陌生模型时可沿 forward() 关注:

  1. 输入是什么def forward(self, x, attention_mask=None, ...):
  2. 经过哪些层x = self.norm(x) / x = self.attn(x, mask) / …
  3. shape 在哪里变 → 每次 reshape / transpose / view 的地方
  4. cache / mask / position 在哪里参与 → 推理优化的关键点
  5. 输出是什么return logits / return hidden_states, past_kv / …

8. nn.Sequential:方便但有局限

nn.Sequential 可以快速串起若干层:

1
2
3
4
5
6
7
8
9
model = nn.Sequential(
nn.Linear(8, 16),
nn.ReLU(),
nn.Linear(16, 4)
)

x = torch.randn(2, 8)
y = model(x)
print(y.shape) # torch.Size([2, 4])

简洁,但有它的边界:

适合 Sequential 不适合 Sequential
单输入 / 单输出 多输入 / 多输出
纯串行,无分支 残差连接、门控分支
快速实验原型 推理代码里有 mask / cache
简单 demo attention / LLM / 复杂 CV

Sequential 仅适用于线性堆叠;多数模型需自定义 forward() 以表达分支与中间状态。


9. 参数(Parameter)和缓冲区(Buffer)的区别

这是推理岗位容易踩坑的地方。

类比理解

把模块想象成一个计算器

  • 参数(Parameter):计算器里的算法规则(每次训练都会更新)
    例如:线性层的权重矩阵 W、偏置 b

  • 缓冲区(Buffer):计算器里的"记忆存储"(记录运行时的统计量,不直接被优化器更新)
    例如:BatchNorm 层里记录的"这批数据的均值和方差的历史"

最典型的例子:BatchNorm 的 running mean

BatchNorm 层训练时会持续更新两个统计量:

  • running_mean:这不是通过反向传播更新的,而是每个 batch 过来后直接计算更新
  • running_var:同上

这两个就是缓冲区(Buffer)

1
2
3
4
5
6
7
8
9
10
11
12
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4) # 对 4 维特征做归一化

# 看参数(可学习的 gamma 和 beta)
for name, param in bn.named_parameters():
print("参数:", name, param.shape)

# 看 buffer(running_mean、running_var 等)
for name, buf in bn.named_buffers():
print("buffer:", name, buf.shape)

输出:

1
2
3
4
5
参数: weight torch.Size([4])   ← gamma(可学习的缩放参数)
参数: bias torch.Size([4]) ← beta(可学习的偏移参数)
buffer: running_mean torch.Size([4]) ← 不可学习的统计量
buffer: running_var torch.Size([4]) ← 不可学习的统计量
buffer: num_batches_tracked torch.Size([])

如何在自定义模块里注册 buffer

1
2
3
4
5
6
7
8
9
10
11
12
class MyModule(nn.Module):
def __init__(self):
super().__init__()
# register_buffer 注册一个 buffer,名叫 "scale"
self.register_buffer("scale", torch.ones(1))

def forward(self, x):
return x * self.scale # 用 buffer 参与计算

m = MyModule()
print(m.scale) # tensor([1.])
print(m.state_dict()) # OrderedDict([('scale', tensor([1.]))])

三种状态的对比

类型 能学习 在 state_dict 里 .to() 移动 典型例子
nn.Parameter Linear.weight, Embedding.weight
Persistent Buffer BN.running_mean, 预计算位置编码
普通成员 Tensor 临时中间变量(不推荐存成属性)

为什么推理时要懂 buffer?
因为推理常见的坑很多是这一类:

  • 模型搬到 GPU 后,某个 buffer 还留在 CPU → 运算设备不一致报错
  • 加载权重后 BatchNorm 行为异常 → running_mean 没有正确恢复
  • state_dict 对不上 → buffer 名字或形状不匹配

10. state_dict:模型状态的标准导出格式

什么是 state_dict

state_dict 就是一个有序字典,把模块里所有参数和 buffer 的名字(key)和值(Tensor)对应起来。

1
2
3
4
5
6
7
8
9
10
import torch
import torch.nn as nn

model = nn.Linear(4, 3)
sd = model.state_dict()

print(type(sd)) # <class 'collections.OrderedDict'>
print(sd.keys()) # odict_keys(['weight', 'bias'])
print(sd['weight'].shape) # torch.Size([3, 4])
print(sd['bias'].shape) # torch.Size([3])

多层模型的 state_dict 长什么样

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(8, 16)
self.fc2 = nn.Linear(16, 4)

def forward(self, x):
return self.fc2(torch.relu(self.fc1(x)))

model = MLP()
sd = model.state_dict()

for key, val in sd.items():
print(key, val.shape)

# 输出:
# fc1.weight torch.Size([16, 8])
# fc1.bias torch.Size([16])
# fc2.weight torch.Size([4, 16])
# fc2.bias torch.Size([4])

注意 key 的格式:子模块名.参数名。如果有更深的嵌套,key 会更长,比如 encoder.layer.0.fc.weight

state_dict 的键名与模块层次一一对应;加载报错时宜先比对键名与形状。


11. 正确的保存与加载方式

保存

1
torch.save(model.state_dict(), "model.pt")

加载

1
2
3
4
5
6
7
8
# 第一步:先构造一个同结构的模型
model = MLP()

# 第二步:加载权重字典
state_dict = torch.load("model.pt", map_location="cpu")

# 第三步:把权重填入模型
model.load_state_dict(state_dict)

注意 map_location="cpu":如果保存时在 GPU 上,加载时可以先指定加载到 CPU,之后再 .to(device) 按需移动。

为什么不直接 torch.save(model) 保存整个模型

很多人这样写:

1
torch.save(model, "model_full.pt")   # 不推荐

这样要求加载时 Python 环境、模块类定义路径完全一致,换了机器或者重构了代码,很容易加载失败。

推荐方式:保存 state_dict,加载时在代码里重新定义模型结构,再填入权重。这样模型结构和权重解耦,工程上更稳。


12. load_state_dict() 报错时怎么排查

这是实战里极常见的场景。报错信息通常有两种:

情况 1:多余的 key(权重文件里有,但模型里没有)

1
RuntimeError: Unexpected key(s) in state_dict: "old_head.weight", "old_head.bias"

说明:权重文件是旧版本模型训练的,现在的代码里把那层改名了或删了。

情况 2:缺少的 key(模型需要,但权重文件没有)

1
RuntimeError: Missing key(s) in state_dict: "classifier.weight", "classifier.bias"

说明:模型代码里新增了某层,但权重文件是旧版本的,没有这一层的权重。

诊断步骤

1
2
3
4
5
6
# 对比两边的 key
model_keys = set(model.state_dict().keys())
file_keys = set(torch.load("model.pt").keys())

print("模型有但文件没有:", model_keys - file_keys)
print("文件有但模型没有:", file_keys - model_keys)

常见原因

现象 可能原因
key 多了 module. 前缀 权重是 DataParallel 训练的,有额外封装层
key 名字不一样 代码里子模块改名了(fchead
shape 不匹配 输出维度改了(比如分类数从 1000 → 10)
buffer 丢失 权重文件保存时用的是旧版本,buffer 名字变了

13. train()eval():行为模式切换

这一节很多人会踩坑,搞清楚后面学混合精度和部署会顺很多。

先弄清楚一个误解

model.eval() 不是关闭梯度, 它只是告诉模型:“现在是推理时间,不是训练时间。”

而梯度的关闭,是靠 torch.no_grad()torch.inference_mode() 完成的(见下一篇)。

这两件事独立,必须都做:

1
2
3
model.eval()                        # 切换行为模式
with torch.inference_mode(): # 关闭梯度
y = model(x)

eval() 具体影响哪些模块

并不是所有模块都受影响,只有那些训练和推理行为本来就不同的模块

Dropout

1
2
3
4
5
6
7
8
9
dropout = nn.Dropout(p=0.5)   # 训练时随机丢弃 50% 的值

x = torch.ones(1, 10)

dropout.train()
print(dropout(x)) # 某些位置会变成 0(随机的)

dropout.eval()
print(dropout(x)) # 全部是 1.0,因为推理时不丢弃

输出示例:

1
2
tensor([[0., 0., 2., 2., 0., 2., 0., 0., 2., 0.]])   ← train(),有随机 0
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]) ← eval(),全部保留

BatchNorm

  • 训练时:用当前 batch 的均值和方差做归一化,同时更新 running_mean / running_var
  • 推理时:用保存下来的 running_mean / running_var 做归一化,不更新

若推理前未调用 model.eval()BatchNorm 等层可能仍使用 batch 统计量,导致同输入输出不稳定或与训练分布不一致。

标准推理模板

1
2
3
4
model.eval()                       # 1. 切行为模式

with torch.inference_mode(): # 2. 关梯度
y = model(x) # 3. 推理

这三行缺一不可。


14. model.to():整体迁移设备和精度

Tensor 的 .to() 与 Module 的 .to() 语义一致,后者会递归遍历子模块树:

1
2
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

这一句话会把:

  • 所有直接参数
  • 所有 buffer
  • 所有子模块的参数和 buffer

全部移动到指定设备。

也可以转精度:

1
model = model.to(dtype=torch.float16)   # 把所有参数转成 float16

实际工程中常用 .to() 统一迁移参数与 buffer,而无需手写遍历各子模块参数。


15. 一个完整的推理代码骨架

下面用一个完整例子,把这一部分的所有核心概念串起来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn as nn

# ① 定义模型
class SimpleClassifier(nn.Module):
def __init__(self, in_dim, hidden_dim, num_classes):
super().__init__()
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.relu = nn.ReLU()
self.drop = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(hidden_dim, num_classes)

def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.drop(x) # ← eval() 下不起作用,train() 下随机丢弃
x = self.fc2(x)
return x

# ② 初始化并迁移设备
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleClassifier(in_dim=8, hidden_dim=16, num_classes=3)
model = model.to(device)

# ③ 保存权重
torch.save(model.state_dict(), "/tmp/classifier.pt")

# ④ 重新加载(模拟部署场景)
model2 = SimpleClassifier(in_dim=8, hidden_dim=16, num_classes=3)
state_dict = torch.load("/tmp/classifier.pt", map_location="cpu")
model2.load_state_dict(state_dict)
model2 = model2.to(device)

# ⑤ 推理
model2.eval()

x = torch.randn(4, 8).to(device)

with torch.inference_mode():
y = model2(x)

print("output shape:", y.shape) # torch.Size([4, 3])
print("output device:", y.device) # cuda:0 (或 cpu)

# ⑥ 查看 state_dict 结构
for key, val in model2.state_dict().items():
print(f"{key:20s}{val.shape}")

# 输出:
# fc1.weight → torch.Size([16, 8])
# fc1.bias → torch.Size([16])
# fc2.weight → torch.Size([3, 16])
# fc2.bias → torch.Size([3])

自检表

问题 要点
为什么 fc1relu 会成为子模块? 它们作为属性赋值在 __init__ 里,PyTorch 自动注册
为什么 model(x) 会走 forward() 调用模块会经过 __call__,它内部触发 forward()
model.eval() 改变了什么? Dropout 不再丢弃,BatchNorm 用 running stats
model.to(device) 移动了什么? 所有参数和 buffer,递归处理整个模块树
为什么 key 是 fc1.weight 这种格式? 子模块名 . 参数名,层级拼接
为什么要 model2 = SimpleClassifier(...)load_state_dict PyTorch 推荐:结构和权重解耦,保存和恢复更稳定

16. 阅读推理代码的标准顺序

拿到任何一个 PyTorch 推理项目,按这个顺序看:

第一步:找模型定义

1
2
class XXX(nn.Module):
...

第二步:看 __init__()

  • 有哪些子模块?
  • 有哪些参数?
  • 有哪些 buffer?
  • 有没有 KV cache、位置编码相关状态?

第三步:顺着 forward() 走一遍

  • 输入是什么?多少个参数?
  • 每一步 shape 怎么变?
  • mask / position_ids / cache 在哪里出现?
  • 输出是什么?

第四步:看权重加载

  • load_state_dict 在哪里?
  • 有没有 strict=False?(这意味着部分权重不匹配,是刻意的还是 bug?)
  • 有没有 key 重命名操作?

第五步:看推理入口

  • model.eval() 有没有做?
  • torch.no_grad()torch.inference_mode() 有没有做?
  • 输入有没有 .to(device)

17. 常见误区

误区 1:普通 Tensor 成员属性算参数

1
2
self.scale = torch.ones(1)      # 反例: 不是参数,model.state_dict() 里没有它
self.scale = nn.Parameter(torch.ones(1)) # 示例: 这才是

误区 2:model.eval() 等于关闭梯度

1
2
3
4
5
model.eval()   # 反例: 只切换行为模式,梯度还在

model.eval()
with torch.inference_mode(): # 示例: 行为模式 + 梯度都处理了
y = model(x)

误区 3:直接 torch.save(model) 就行

工程上不推荐。加载时要求 Python 环境和代码路径完全一致,容易出问题。

误区 4:Sequential 足够用于推理

对于简单的 demo 够用,但真实推理代码(Attention、LLM、多流输入)几乎都需要自定义 forward()

误区 5:buffer 不重要

实际推理项目里很多奇怪 bug(BatchNorm 行为异常、权重加载 key 不匹配)都和 buffer 有关。


18. 面试常见问题

下列对照可用于自检:

Q:nn.Module 是什么?
A:PyTorch 中组织模型的基础类,能统一管理参数、buffer、子模块和前向计算逻辑。

Q:为什么写 model(x) 而不是 model.forward(x)
A:前者走 __call__,会触发 hooks 和框架级处理;后者直接调方法,会绕过这些机制。

Q:Parameter 和 Buffer 有什么区别?
A:Parameter 是可学习的状态(优化器会更新),Buffer 是影响计算但不需要学习的状态(比如 BN 的 running mean)。两者都出现在 state_dict 中,都随 .to() 移动设备。

Q:为什么保存 state_dict 而不是整个模型?
A:解耦结构和权重,可移植性更强,是官方推荐方式。

Q:model.eval() 具体影响了什么?
A:关闭了 Dropout 的随机丢弃;让 BatchNorm 使用 running stats 而不是当前 batch 的统计量。不负责关梯度。


19. 思考题

建议自己跑一遍,先猜输出,再验证。

练习 1:写一个最简单的 Module

定义一个 ScaleLayer:初始化时接受一个 float scaleforward() 返回 x * scale
然后打印 named_parameters(),看有没有参数。

思考:改成用 nn.Parameter 存 scale 后,named_parameters() 输出有什么变化?

练习 2:子模块观察

1
2
3
4
5
6
7
8
9
class Net(nn.Module):
def __init__(self):
super().__init__()
self.encoder = nn.Linear(8, 16)
self.decoder = nn.Linear(16, 4)
def forward(self, x):
return self.decoder(torch.relu(self.encoder(x)))

model = Net()

打印 model.named_parameters() 的所有 key,观察格式。

思考:将子模块 encoder 重命名为 layer1 后,state_dict 键名如何变化?

练习 3:观察 Dropout 在 train/eval 下的差异

1
2
3
4
5
6
7
8
9
10
model = nn.Sequential(nn.Dropout(p=0.5))
x = torch.ones(1, 10)

model.train()
for _ in range(3):
print(model(x))

model.eval()
for _ in range(3):
print(model(x))

思考:train 模式下每次输出一样吗?eval 模式下呢?为什么?

练习 4:保存 / 加载 / 对比 key

自己写一个两层 MLP,保存 state_dict 到文件,再新建同结构模型加载,打印两边的 key 确认一致。

进阶挑战:把第二个模型的第二层 Linear 改成 out_features=10(原来是 4),再 load_state_dict,观察报错信息。

练习 5:buffer 注册与检查

写一个 ScaleNorm 模块:

  • 注册一个 buffer running_scale,初始值为 torch.ones(1)
  • forward() 里返回 x * self.running_scale
  • 打印 named_parameters()named_buffers()state_dict()

思考:running_scalenamed_parameters() 里能看到吗?在 named_buffers() 里呢?


20. 本节要点与自检

  • 能看懂任意 PyTorch 模型类的基本结构
  • 理解 __init__()forward() 各自负责什么
  • 区分 Parameter、Buffer、普通成员变量,知道它们的行为差异
  • 能用 state_dict() 正确保存和加载权重
  • 知道 model(x) vs model.forward(x) 的区别,以及为什么要用前者
  • 知道 model.eval() 影响哪些层,以及它和梯度关闭的关系
  • 能沿着子模块结构读懂简单推理项目

21. 小结

nn.Module 的本质,是把"状态"和"前向计算"统一组织起来;推理代码的阅读、改写、保存、加载,几乎都围绕它展开。


系列导航