PyTorch推理工程:01 Tensor、dtype、device 与推理底层基础
PyTorch 推理工程(01):Tensor、dtype、device 与推理底层基础
1. 本节定位
在实现层面,推理路径可描述为 Tensor 上的算子序列:输入经若干运算得到输出。各步骤均涉及 device、dtype、shape 以及内存布局(含 stride、是否 contiguous)。后续关于混合精度、CUDA 与图优化的讨论均依赖上述属性。
下列问题应能依据张量语义作答,否则容易在现象与根因之间脱节:
- 张量位于 CPU 还是 GPU?
- dtype 取 float32 / float16 / bfloat16 的依据与取舍;
view报错、transpose后性能变化等常见情况的语义来源。
2. 什么是 Tensor
由标量、向量到张量
若已熟悉下列对象,可将张量视为其推广:
| 名称 | 描述 | 例子 |
|---|---|---|
| 标量 | 一个数 | 3.14 |
| 向量 | 一维数组 | [1, 2, 3] |
| 矩阵 | 二维数组 | [[1,2],[3,4]] |
张量为上述对象的推广:维度任意。
- 0 维 Tensor = 标量(一个数)
- 1 维 Tensor = 向量
- 2 维 Tensor = 矩阵
- 3 维及以上 = 更高维数据
Tensor 在深度学习里长什么样
典型 shape 约定示例:
1 | 一张灰度图:(H, W) |
其中常用字母含义:
| 字母 | 含义 |
|---|---|
| N / B | batch size(一次处理多少条数据) |
| C | channel(图像通道数) |
| H | height(图像高度) |
| W | width(图像宽度) |
| T | token 数 / 时间步数 |
| D | hidden size / feature size |
推理代码中常见
(B, T, D)等记号;上表字母与之一一对应。
3. 创建 Tensor:最基本操作
最常用的 5 种方式
1 | import torch |
什么时候用哪个
| 函数 | 什么时候用 |
|---|---|
torch.tensor(data) |
有现成数据(列表、numpy)需要转成 Tensor |
torch.zeros(...) |
初始化权重、mask、填充缓冲区 |
torch.ones(...) |
构造全 1 的 mask 或初始值 |
torch.randn(...) |
测试、随机权重初始化、调试形状 |
torch.empty(...) |
性能优化时,提前分配内存,之后再填值 |
4. shape:张量的"外形"
shape 是什么
shape 给出各维长度。
1 | import torch |
可以把 shape 想象成"这个盒子的尺寸":
1 | x.shape = (2, 3, 4) |
获取某一维的大小
1 | x = torch.randn(2, 3, 4) |
x.shape[-1]取最后一维,在推理代码里非常常见。比如hidden_dim = x.shape[-1]。
养成一个关键习惯
看到一个 Tensor,先问:
- 它有几维?(
x.ndim) - 每一维是多少?(
x.shape) - 每一维代表什么含义?(这个最重要,靠上下文判断)
5. dtype:数据类型
dtype 是什么
dtype 决定 Tensor 里的每个数字用多少 bit 来存储、能表示多大的范围。
1 | import torch |
常见 dtype 对比
| dtype | 位数 | 每个数占内存 | 适用场景 |
|---|---|---|---|
float32 |
32 bit | 4 字节 | 默认精度,训练基线,精度要求高的推理 |
float16 |
16 bit | 2 字节 | GPU 推理加速,省显存 |
bfloat16 |
16 bit | 2 字节 | 动态范围更大,训练更稳,很多 LLM 用这个 |
int8 |
8 bit | 1 字节 | 量化推理,极致省显存 |
int64 |
64 bit | 8 字节 | token id、索引 |
bool |
— | — | mask(哪些位置是有效的) |
为什么 dtype 对推理极其重要
一个模型里的参数量固定,但用不同 dtype 存储,显存占用会差 2~4 倍:
1 | GPT-2(1.5B 参数)用 float32 存储:约 6 GB |
所以推理岗位会大量涉及"量化"(quantization)和"混合精度"(mixed precision),而这一切的核心就是 dtype 的选择。
转换 dtype
1 | x = torch.randn(2, 3) # float32 |
6. device:张量在哪个设备上
device 是什么
PyTorch 中的 Tensor 必须"活"在某个设备上:
- CPU:普通内存(RAM)
- CUDA GPU:显卡显存(VRAM)
1 | import torch |
把 Tensor 移到 GPU
1 | import torch |
最常见的坑:设备不一致
若参与运算的两个张量不在同一设备,将触发运行时错误:
1 | import torch |
解决方法:确保所有参与计算的 Tensor 都在同一设备上。
推理代码里的标准写法
1 | device = "cuda" if torch.cuda.is_available() else "cpu" |
核心原则:参与同一次计算的 Tensor,必须在同一 device 上。
这包括:输入、模型权重、KV cache、mask、position ids。
7. shape 的语义:不要只看数字
同样是 (32, 128),可能含义完全不同:
(32, 128):32 个样本,每个 128 维特征(32, 128):batch=32,seq_len=128(每个样本是 128 个 token)(32, 128):32 个 token,每个 128 维 hidden state
光看 shape 数字是不够的,必须结合上下文理解每一维的语义。
三种常见模型输入的 shape
1 | # MLP 输入 |
维度语义与上下文
同一组 shape 数字在不同模型中的语义不同;阅读实现时应结合变量名、注释与上游调用约定,为各维赋予可解释标签:
1 | # 看到这段代码 |
该习惯有助于对照 Attention 实现与 KV Cache 相关 shape 推导。
8. 广播(broadcasting):形状不同的 Tensor 如何运算
什么是广播
广播让不同形状的 Tensor 也能做运算,PyTorch 会自动"扩展"较小的那个。
1 | import torch |
理解广播的关键:从最后一维开始,向左对齐。
1 | x.shape: (2, 3) |
对每一对维度:如果相等 → 可以运算;如果一个是 1 → 可以扩展;如果不匹配 → 报错。
几个例子
可广播:
1 | # (2, 3) + (3,) → b 扩展成 (2, 3) |
不可广播:
1 | # (2, 3) + (2, 4) → 最后一维 3 和 4 不匹配,也不是 1 |
推理代码里广播无处不在
1 | # Bias 加法(线性层偏置) |
不熟悉广播规则时,shape 相关报错较难从信息中反推原因。
9. 形状变换:reshape、view、transpose、permute
这几个操作是推理代码里最常见的,也是最容易弄混的。
要点
形状变换不一定等于复制数据。
很多时候,PyTorch 只是改变了"如何解读这块内存",数据本身没有移动。
reshape()
最常用的形状变换,把元素总数不变的情况下重排成新形状。
1 | x = torch.arange(12) # tensor([0, 1, 2, ..., 11]),shape=(12,) |
- 总元素数必须一致:
12 == 3×4 - 可以用
-1让 PyTorch 自动推算:x.reshape(3, -1)→(3, 4)
view()
和 reshape() 类似,但有限制:要求 Tensor 是内存连续的(后面详细解释),否则报错。
1 | x = torch.arange(12) |
选哪个? 一般优先用
reshape(),不确定内存连续性时更安全。
但理解view()为什么有时失败,是必须掌握的知识点(见第 10、11 节)。
transpose()
交换两个维度。
1 | x = torch.randn(2, 3) # shape: (2, 3) |
直觉上:就是把矩阵"转置"——行变列、列变行。
permute()
按任意顺序重排多个维度(transpose 只能交换两个,permute 可以重排全部)。
1 | x = torch.randn(2, 3, 4) # shape: (2, 3, 4),维度顺序是 0, 1, 2 |
实际例子:把 (B, H, W, C) 格式的图像转成 (B, C, H, W):
1 | img = torch.randn(8, 224, 224, 3) # (B, H, W, C) 格式 |
10. contiguous:内存连续性
先理解内存里数据是怎么存的
在内存里,Tensor 的数据本质是一段连续的数字序列。
以 shape (2, 3) 的 Tensor 为例:
1 | 内存里实际存的:[a00, a01, a02, a10, a11, a12] |
按行优先(row-major)排列,这是"内存连续"的正常状态:访问第 i 行第 j 列的元素,地址 = 起始地址 + i*3 + j。
transpose 之后发生了什么
1 | x = torch.randn(2, 3) # 连续的,shape=(2, 3) |
transpose 没有真的把数据在内存里重新排列,它只是改变了"用什么规则去读这块内存"。
结果是:y 的 shape 是 (3, 2),但内存布局依然是 (2, 3) 的样子 → 非连续(not contiguous)。
1 | print(x.is_contiguous()) # True |
非连续会导致什么问题
view() 要求 Tensor 内存连续,所以对 transpose 的结果直接 view() 会报错:
1 | x = torch.randn(2, 3) |
解决方法:先用 .contiguous() 真正在内存里重新排列一份,再 view():
1 | z = y.contiguous().view(-1) # 示例: 先让内存变连续,再变形 |
怎么理解这两个操作
| 操作 | 做了什么 | 数据是否复制 |
|---|---|---|
transpose/permute |
改变"读内存的规则" | 否;不复制数据 |
contiguous() |
按新规则把数据真正搬到新地方 | 是;分配新内存 |
结论:
- 在可以的时候,PyTorch 尽量不复制数据(用 view 的方式)
- 但当操作要求内存连续时,就必须显式
.contiguous()触发一次复制- 性能优化时,要注意这种隐式 copy 带来的额外开销
11. stride:理解 contiguous 的底层原理
这一节比较底层,初学时可以先理解大意,后面学 kernel 优化时再深入。
stride 是什么
stride 描述:在某个维度上前进一步,需要在内存里跳过多少个元素。
1 | import torch |
解读 stride = (3, 1):
- 沿第 0 维(行)前进一步 → 要跳过 3 个元素(因为每行有 3 个元素)
- 沿第 1 维(列)前进一步 → 要跳过 1 个元素(紧挨着的下一个)
用内存图来理解
1 | 内存地址: [0] [1] [2] [3] [4] [5] |
transpose 之后 stride 如何变化
1 | x = torch.randn(2, 3) |
transpose 只是把 stride 里的数字对换了顺序,内存本身没变化。
因此 transpose 后的张量常为非 contiguous:stride 与按行主序展开的预期不再一致。
什么叫"正常的"stride
一个 shape 为 (d0, d1, d2) 的 contiguous Tensor,其 stride 应该是:
1 | stride = (d1 * d2, d2, 1) |
如果不满足这个规律,就不是 contiguous。
stride 在推理优化里为什么重要
- 内存访问模式:stride 大意味着跳跃式访问内存,缓存命中率低,性能差
- 算子限制:很多 CUDA kernel 要求 contiguous 输入
- 是否需要 copy:stride 异常 → 不得不插入
.contiguous()→ 额外显存分配和时间
12. 索引与切片:子张量与视图
基本用法
1 | import torch |
重要提醒:切片返回的是 view,不是拷贝
1 | x = torch.arange(6).reshape(2, 3) |
这说明 y 和 x 共享同一块内存,修改 y 会影响 x。
如果不想影响原始数据,用 .clone():
1 | y = x[:, :2].clone() # 真正的拷贝一份 |
13. 常见报错与解读
推理代码几乎所有初级报错,都出在 shape / device / dtype 三件事上。
报错 1:设备不一致
1 | RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! |
含义:参与计算的 Tensor 不在同一个设备(有的在 CPU,有的在 GPU)。
排查:print(x.device, y.device),找到不一致的 Tensor,.to(device) 统一。
报错 2:shape 不匹配
1 | RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1 |
含义:两个 Tensor 某一维大小不同,又不满足广播条件(没有一个是 1)。
排查:print(a.shape, b.shape),看哪一维对不上。
报错 3:view 失败(内存不连续)
1 | RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. |
含义:Tensor 内存不连续,不能直接 view()。
排查:print(x.is_contiguous()),如果 False,先 .contiguous() 或改用 .reshape()。
报错 4:dtype 不匹配
1 | RuntimeError: expected scalar type Float but found Half |
含义:算子要求 float32 输入,但传入了 float16(或者反过来)。
排查:print(x.dtype),用 .to(torch.float32) 转换。
14. 推理岗位必须形成的 6 个习惯
习惯 1:看到任何 Tensor,先打印三件事
1 | print(x.shape, x.dtype, x.device) |
可优先核对下列三项以缩小问题范围。
习惯 2:遇到 reshape / view 报错,立刻查内存连续性
1 | print(x.is_contiguous()) |
习惯 3:不要把广播当"自动的魔法"
广播有规则,遇到 shape 报错,手动推一遍:从最后一维开始对齐,哪一维出问题了?
习惯 4:阅读代码时给每一维加语义标签
1 | # 不要只看 (8, 128, 768) |
习惯 5:模型和数据要一起移到 device
1 | device = "cuda" if torch.cuda.is_available() else "cpu" |
不要忘记任何一个参与计算的 Tensor(mask、position ids、cache……)。
习惯 6:切片/索引后,考虑是否需要 clone
若需独立副本以免共享存储,应使用 .clone()。
15. 串联示例:一个最小推理片段
下面用一段完整的代码把这一部分的核心概念串起来。
场景:把一个 batch 的 token 向量通过一个线性变换,输出新的向量。
1 | import torch |
逐行解释:
| 问题 | 解释 |
|---|---|
为什么 x @ W 合法? |
矩阵乘要求内层维度一致:(4, 8) @ (8, 16) → 8 == 8,合法 |
为什么 + b 合法? |
广播:b.shape=(16,) 自动扩展到 (4, 16) |
为什么输出是 float32? |
所有输入都是 float32,PyTorch 保持精度 |
| 为什么结果是连续的? | @ 运算输出的是全新分配的 Tensor,必然是连续的 |
16. 本节要点与自检
下列能力可作为掌握程度的参照:
- 看懂任意 PyTorch 张量创建代码,知道每种方式的适用场景
- 一眼看出一个 Tensor 的 shape / dtype / device
- 手动推断两个 Tensor 能否广播,以及结果的 shape
- 知道
view()为什么会失败,如何修复 - 理解
contiguous和stride的含义,以及transpose后为什么内存不连续 - 能定位最常见的 shape / device / dtype 报错
- 读推理代码时,能说出每个 Tensor 每一维的语义
17. 思考题
建议在本地运行下列片段,并对照 print 输出与上文定义。
练习 1:创建不同 dtype 的 Tensor
1 | # 创建 float32、float16、int64 的 Tensor,打印 dtype 和内存占用 |
思考:float32 的内存是 float16 的几倍?
练习 2:检查 device
1 | x = torch.randn(3, 4) |
思考:如果忘记
.to(device),后面运算时会报什么错?
练习 3:验证广播
1 | x = torch.randn(2, 3) |
思考:报错信息中哪一句标明了不匹配的维度?
练习 4:验证 contiguous 和 view 报错
1 | x = torch.randn(2, 3) |
思考:
y.contiguous()之后y.stride()变成什么了?
练习 5:给一个 Tensor 加语义标签
1 | # 假设有一个 Transformer 模型的中间激活值 |
进阶思考:如果这是 8 头 Attention 的 query,重新排列后 shape 会变成什么?
18. 小结
推理性能与正确性问题多数可追溯到张量的 shape、dtype、device、stride 与内存连续性;后续精度、设备与图相关章节均以此为基础。
