PyTorch 推理工程(01):Tensor、dtype、device 与推理底层基础

1. 本节定位

在实现层面,推理路径可描述为 Tensor 上的算子序列:输入经若干运算得到输出。各步骤均涉及 device、dtype、shape 以及内存布局(含 stride、是否 contiguous)。后续关于混合精度、CUDA 与图优化的讨论均依赖上述属性。

下列问题应能依据张量语义作答,否则容易在现象与根因之间脱节:

  • 张量位于 CPU 还是 GPU?
  • dtype 取 float32 / float16 / bfloat16 的依据与取舍;
  • view 报错、transpose 后性能变化等常见情况的语义来源。

2. 什么是 Tensor

由标量、向量到张量

若已熟悉下列对象,可将张量视为其推广:

名称 描述 例子
标量 一个数 3.14
向量 一维数组 [1, 2, 3]
矩阵 二维数组 [[1,2],[3,4]]

张量为上述对象的推广:维度任意。

  • 0 维 Tensor = 标量(一个数)
  • 1 维 Tensor = 向量
  • 2 维 Tensor = 矩阵
  • 3 维及以上 = 更高维数据

Tensor 在深度学习里长什么样

典型 shape 约定示例:

1
2
3
4
5
6
7
8
9
10
11
一张灰度图:(H, W)
例:(224, 224) → 224 行 × 224 列

一张彩色图:(C, H, W)
例:(3, 224, 224) → 3 个通道(R/G/B),每通道 224×224

一批图片: (N, C, H, W)
例:(32, 3, 224, 224) → 32 张图片

一批文本: (B, T, D)
例:(8, 512, 768) → 8 条句子,每条 512 个 token,每个 token 768 维

其中常用字母含义:

字母 含义
N / B batch size(一次处理多少条数据)
C channel(图像通道数)
H height(图像高度)
W width(图像宽度)
T token 数 / 时间步数
D hidden size / feature size

推理代码中常见 (B, T, D) 等记号;上表字母与之一一对应。


3. 创建 Tensor:最基本操作

最常用的 5 种方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch

# 方式 1:从已有数据创建
a = torch.tensor([1, 2, 3])
print(a)
# 输出:tensor([1, 2, 3])

# 方式 2:全 0
b = torch.zeros(2, 3)
print(b)
# 输出:tensor([[0., 0., 0.],
# [0., 0., 0.]])

# 方式 3:全 1
c = torch.ones(2, 3)
print(c)
# 输出:tensor([[1., 1., 1.],
# [1., 1., 1.]])

# 方式 4:标准正态分布随机值
d = torch.randn(2, 3)
print(d)
# 输出:形如 tensor([[ 0.32, -1.21, 0.87], ...]),每次不同

# 方式 5:未初始化(速度快,但值不可靠,不要直接用里面的值)
e = torch.empty(2, 3)
print(e)
# 输出:里面的值是内存里的随机残留,不要依赖它

什么时候用哪个

函数 什么时候用
torch.tensor(data) 有现成数据(列表、numpy)需要转成 Tensor
torch.zeros(...) 初始化权重、mask、填充缓冲区
torch.ones(...) 构造全 1 的 mask 或初始值
torch.randn(...) 测试、随机权重初始化、调试形状
torch.empty(...) 性能优化时,提前分配内存,之后再填值

4. shape:张量的"外形"

shape 是什么

shape 给出各维长度。

1
2
3
4
5
6
7
import torch

x = torch.randn(2, 3, 4)

print(x.shape) # torch.Size([2, 3, 4])
print(x.ndim) # 3 (有几个维度)
print(x.numel()) # 24 (一共有多少个元素:2×3×4)

可以把 shape 想象成"这个盒子的尺寸":

1
2
3
4
5
x.shape = (2, 3, 4)
↑ ↑ ↑
│ │ └── 最内层:每个 "行" 里有 4 个数
│ └───── 中间层:每个 "面" 里有 3 行
└──────── 最外层:一共有 2 个 "面"

获取某一维的大小

1
2
3
4
5
6
x = torch.randn(2, 3, 4)

print(x.shape[0]) # 2(第 0 维)
print(x.shape[1]) # 3(第 1 维)
print(x.shape[2]) # 4(第 2 维)
print(x.shape[-1]) # 4(最后一维,常用)

x.shape[-1] 取最后一维,在推理代码里非常常见。比如 hidden_dim = x.shape[-1]

养成一个关键习惯

看到一个 Tensor,先问:

  1. 它有几维?(x.ndim
  2. 每一维是多少?(x.shape
  3. 每一维代表什么含义?(这个最重要,靠上下文判断)

5. dtype:数据类型

dtype 是什么

dtype 决定 Tensor 里的每个数字用多少 bit 来存储、能表示多大的范围。

1
2
3
4
5
6
7
import torch

x = torch.tensor([1.0, 2.0, 3.0]) # 默认 float32
print(x.dtype) # torch.float32

y = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
print(y.dtype) # torch.float16

常见 dtype 对比

dtype 位数 每个数占内存 适用场景
float32 32 bit 4 字节 默认精度,训练基线,精度要求高的推理
float16 16 bit 2 字节 GPU 推理加速,省显存
bfloat16 16 bit 2 字节 动态范围更大,训练更稳,很多 LLM 用这个
int8 8 bit 1 字节 量化推理,极致省显存
int64 64 bit 8 字节 token id、索引
bool mask(哪些位置是有效的)

为什么 dtype 对推理极其重要

一个模型里的参数量固定,但用不同 dtype 存储,显存占用会差 2~4 倍

1
2
3
GPT-2(1.5B 参数)用 float32 存储:约 6 GB
GPT-2(1.5B 参数)用 float16 存储:约 3 GB
GPT-2(1.5B 参数)用 int8 存储: 约 1.5 GB

所以推理岗位会大量涉及"量化"(quantization)和"混合精度"(mixed precision),而这一切的核心就是 dtype 的选择。

转换 dtype

1
2
3
4
5
6
7
x = torch.randn(2, 3)                   # float32
y = x.to(torch.float16) # 转 float16
z = x.half() # 同样是转 float16(简写)

print(x.dtype) # torch.float32
print(y.dtype) # torch.float16
print(z.dtype) # torch.float16

6. device:张量在哪个设备上

device 是什么

PyTorch 中的 Tensor 必须"活"在某个设备上:

  • CPU:普通内存(RAM)
  • CUDA GPU:显卡显存(VRAM)
1
2
3
4
import torch

x = torch.randn(2, 3)
print(x.device) # cpu ← 默认在 CPU 上

把 Tensor 移到 GPU

1
2
3
4
5
6
7
8
9
import torch

x = torch.randn(2, 3)
print(x.device) # cpu

if torch.cuda.is_available():
y = x.to("cuda")
print(y.device) # cuda:0
# cuda:0 表示第 0 块 GPU(有多块 GPU 时,可以有 cuda:1, cuda:2...)

最常见的坑:设备不一致

若参与运算的两个张量不在同一设备,将触发运行时错误:

1
2
3
4
5
6
7
import torch

x = torch.randn(2, 3) # 在 CPU 上
w = torch.randn(3, 4).to("cuda") # 在 GPU 上

y = x @ w # 反例: 报错!
# RuntimeError: Expected all tensors to be on the same device

解决方法:确保所有参与计算的 Tensor 都在同一设备上。

推理代码里的标准写法

1
2
3
4
5
6
7
8
device = "cuda" if torch.cuda.is_available() else "cpu"

# 模型和输入都搬到同一设备
model = model.to(device)
x = x.to(device)

# 推理
output = model(x)

核心原则:参与同一次计算的 Tensor,必须在同一 device 上。
这包括:输入、模型权重、KV cache、mask、position ids。


7. shape 的语义:不要只看数字

同样是 (32, 128),可能含义完全不同:

  • (32, 128):32 个样本,每个 128 维特征
  • (32, 128):batch=32,seq_len=128(每个样本是 128 个 token)
  • (32, 128):32 个 token,每个 128 维 hidden state

光看 shape 数字是不够的,必须结合上下文理解每一维的语义。

三种常见模型输入的 shape

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# MLP 输入
x.shape == (batch_size, hidden_dim)
# ↑ ↑
# 样本数 每个样本的特征维度

# Transformer / LLM 输入
x.shape == (batch_size, seq_len, hidden_dim)
# ↑ ↑ ↑
# 样本数 token数 每个token的维度

# 图像模型输入(CNN)
x.shape == (batch_size, channels, height, width)
# ↑ ↑ ↑ ↑
# 样本数 图像通道 高 宽

维度语义与上下文

同一组 shape 数字在不同模型中的语义不同;阅读实现时应结合变量名、注释与上游调用约定,为各维赋予可解释标签:

1
2
3
4
# 看到这段代码
q = q.view(batch_size, seq_len, num_heads, head_dim)
# ↑ ↑
# 多头注意力头数 每头的维度

该习惯有助于对照 Attention 实现与 KV Cache 相关 shape 推导。


8. 广播(broadcasting):形状不同的 Tensor 如何运算

什么是广播

广播让不同形状的 Tensor 也能做运算,PyTorch 会自动"扩展"较小的那个。

1
2
3
4
5
6
7
import torch

x = torch.ones(2, 3) # shape: (2, 3)
b = torch.ones(3) # shape: (3,)

y = x + b # ← b 被自动"扩展"成 (2, 3) 后再相加
print(y.shape) # torch.Size([2, 3])

理解广播的关键:从最后一维开始,向左对齐

1
2
x.shape:  (2, 3)
b.shape: (3,) ← 从右对齐

对每一对维度:如果相等 → 可以运算;如果一个是 1 → 可以扩展;如果不匹配 → 报错。

几个例子

可广播

1
2
3
4
5
6
7
8
9
# (2, 3) + (3,)  → b 扩展成 (2, 3)
x = torch.randn(2, 3)
b = torch.randn(3)
print((x + b).shape) # (2, 3)

# (4, 1, 8) + (1, 6, 8) → 结果是 (4, 6, 8)
x = torch.randn(4, 1, 8)
y = torch.randn(1, 6, 8)
print((x + y).shape) # (4, 6, 8)

不可广播

1
2
3
4
# (2, 3) + (2, 4) → 最后一维 3 和 4 不匹配,也不是 1
x = torch.randn(2, 3)
y = torch.randn(2, 4)
# z = x + y # 会报错

推理代码里广播无处不在

1
2
3
4
5
6
7
8
# Bias 加法(线性层偏置)
logits = x @ W + b # b.shape = (out_features,) 自动广播

# Attention mask 加法
scores = scores + mask # mask 通常小于 scores,需要广播

# Scale 操作
scores = scores * scale # scale 是一个标量,广播到整个 Tensor

不熟悉广播规则时,shape 相关报错较难从信息中反推原因。


9. 形状变换:reshape、view、transpose、permute

这几个操作是推理代码里最常见的,也是最容易弄混的。

要点

形状变换不一定等于复制数据。
很多时候,PyTorch 只是改变了"如何解读这块内存",数据本身没有移动。

reshape()

最常用的形状变换,把元素总数不变的情况下重排成新形状。

1
2
3
4
5
6
7
x = torch.arange(12)   # tensor([0, 1, 2, ..., 11]),shape=(12,)
y = x.reshape(3, 4) # shape=(3, 4)

print(y)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
  • 总元素数必须一致:12 == 3×4
  • 可以用 -1 让 PyTorch 自动推算:x.reshape(3, -1)(3, 4)

view()

reshape() 类似,但有限制:要求 Tensor 是内存连续的(后面详细解释),否则报错。

1
2
3
4
x = torch.arange(12)
y = x.view(3, 4) # 正常情况下和 reshape 结果一样

print(y.shape) # torch.Size([3, 4])

选哪个? 一般优先用 reshape(),不确定内存连续性时更安全。
但理解 view() 为什么有时失败,是必须掌握的知识点(见第 10、11 节)。

transpose()

交换两个维度。

1
2
3
4
x = torch.randn(2, 3)        # shape: (2, 3)
y = x.transpose(0, 1) # 交换第 0 维和第 1 维

print(y.shape) # torch.Size([3, 2])

直觉上:就是把矩阵"转置"——行变列、列变行。

permute()

按任意顺序重排多个维度(transpose 只能交换两个,permute 可以重排全部)。

1
2
3
4
x = torch.randn(2, 3, 4)    # shape: (2, 3, 4),维度顺序是 0, 1, 2
y = x.permute(2, 0, 1) # 重排为 2, 0, 1

print(y.shape) # torch.Size([4, 2, 3])

实际例子:把 (B, H, W, C) 格式的图像转成 (B, C, H, W)

1
2
3
img = torch.randn(8, 224, 224, 3)   # (B, H, W, C) 格式
img = img.permute(0, 3, 1, 2) # → (B, C, H, W)
print(img.shape) # torch.Size([8, 3, 224, 224])

10. contiguous:内存连续性

先理解内存里数据是怎么存的

在内存里,Tensor 的数据本质是一段连续的数字序列

以 shape (2, 3) 的 Tensor 为例:

1
2
3
内存里实际存的:[a00, a01, a02, a10, a11, a12]
↑ ↑
第 0 行 第 1 行

按行优先(row-major)排列,这是"内存连续"的正常状态:访问第 i 行第 j 列的元素,地址 = 起始地址 + i*3 + j

transpose 之后发生了什么

1
2
x = torch.randn(2, 3)        # 连续的,shape=(2, 3)
y = x.transpose(0, 1) # shape=(3, 2),但内存没动!

transpose 没有真的把数据在内存里重新排列,它只是改变了"用什么规则去读这块内存"。

结果是:y 的 shape 是 (3, 2),但内存布局依然是 (2, 3) 的样子 → 非连续(not contiguous)。

1
2
print(x.is_contiguous())   # True
print(y.is_contiguous()) # False ← 虽然形状对,但内存顺序"断了"

非连续会导致什么问题

view() 要求 Tensor 内存连续,所以对 transpose 的结果直接 view() 会报错:

1
2
3
4
5
x = torch.randn(2, 3)
y = x.transpose(0, 1) # shape=(3, 2),非连续

z = y.view(-1) # 反例: 报错:
# RuntimeError: view size is not compatible with input tensor's size and stride

解决方法:先用 .contiguous() 真正在内存里重新排列一份,再 view()

1
z = y.contiguous().view(-1)   # 示例: 先让内存变连续,再变形

怎么理解这两个操作

操作 做了什么 数据是否复制
transpose/permute 改变"读内存的规则" 否;不复制数据
contiguous() 按新规则把数据真正搬到新地方 是;分配新内存

结论

  • 在可以的时候,PyTorch 尽量不复制数据(用 view 的方式)
  • 但当操作要求内存连续时,就必须显式 .contiguous() 触发一次复制
  • 性能优化时,要注意这种隐式 copy 带来的额外开销

11. stride:理解 contiguous 的底层原理

这一节比较底层,初学时可以先理解大意,后面学 kernel 优化时再深入。

stride 是什么

stride 描述:在某个维度上前进一步,需要在内存里跳过多少个元素

1
2
3
4
5
import torch

x = torch.randn(2, 3)
print(x.shape) # torch.Size([2, 3])
print(x.stride()) # (3, 1)

解读 stride = (3, 1)

  • 沿第 0 维(行)前进一步 → 要跳过 3 个元素(因为每行有 3 个元素)
  • 沿第 1 维(列)前进一步 → 要跳过 1 个元素(紧挨着的下一个)

用内存图来理解

1
2
3
4
5
6
7
8
9
10
内存地址:  [0]   [1]   [2]   [3]   [4]   [5]
实际数据: a00 a01 a02 a10 a11 a12
↑ ↑
第一行结束 第二行结束

x[0][0] = 地址 0 x[0][1] = 地址 1 x[0][2] = 地址 2
x[1][0] = 地址 3 x[1][1] = 地址 4 x[1][2] = 地址 5

行内相邻元素跳 1 → stride[1] = 1 ✓
行间相邻元素跳 3 → stride[0] = 3 ✓

transpose 之后 stride 如何变化

1
2
3
4
5
x = torch.randn(2, 3)
print(x.stride()) # (3, 1)

y = x.transpose(0, 1)
print(y.stride()) # (1, 3) ← stride 被交换了

transpose 只是把 stride 里的数字对换了顺序,内存本身没变化。

因此 transpose 后的张量常为非 contiguous:stride 与按行主序展开的预期不再一致

什么叫"正常的"stride

一个 shape 为 (d0, d1, d2) 的 contiguous Tensor,其 stride 应该是:

1
stride = (d1 * d2,  d2,  1)

如果不满足这个规律,就不是 contiguous。

stride 在推理优化里为什么重要

  • 内存访问模式:stride 大意味着跳跃式访问内存,缓存命中率低,性能差
  • 算子限制:很多 CUDA kernel 要求 contiguous 输入
  • 是否需要 copy:stride 异常 → 不得不插入 .contiguous() → 额外显存分配和时间

12. 索引与切片:子张量与视图

基本用法

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch

x = torch.arange(12).reshape(3, 4)
print(x)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])

print(x[0]) # 第 0 行:tensor([0, 1, 2, 3])
print(x[:, 1]) # 第 1 列:tensor([1, 5, 9])
print(x[1:3, 2:4]) # 子块:
# tensor([[ 6, 7],
# [10, 11]])

重要提醒:切片返回的是 view,不是拷贝

1
2
3
4
5
6
7
8
9
10
11
x = torch.arange(6).reshape(2, 3)
print(x)
# tensor([[0, 1, 2],
# [3, 4, 5]])

y = x[:, :2] # 取前两列
y[0, 0] = 100 # 修改 y

print(x) # x 也被改了!
# tensor([[100, 1, 2],
# [ 3, 4, 5]])

这说明 yx 共享同一块内存,修改 y 会影响 x

如果不想影响原始数据,用 .clone()

1
y = x[:, :2].clone()  # 真正的拷贝一份

13. 常见报错与解读

推理代码几乎所有初级报错,都出在 shape / device / dtype 三件事上。

报错 1:设备不一致

1
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

含义:参与计算的 Tensor 不在同一个设备(有的在 CPU,有的在 GPU)。
排查print(x.device, y.device),找到不一致的 Tensor,.to(device) 统一。

报错 2:shape 不匹配

1
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1

含义:两个 Tensor 某一维大小不同,又不满足广播条件(没有一个是 1)。
排查print(a.shape, b.shape),看哪一维对不上。

报错 3:view 失败(内存不连续)

1
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

含义:Tensor 内存不连续,不能直接 view()
排查print(x.is_contiguous()),如果 False,先 .contiguous() 或改用 .reshape()

报错 4:dtype 不匹配

1
RuntimeError: expected scalar type Float but found Half

含义:算子要求 float32 输入,但传入了 float16(或者反过来)。
排查print(x.dtype),用 .to(torch.float32) 转换。


14. 推理岗位必须形成的 6 个习惯

习惯 1:看到任何 Tensor,先打印三件事

1
print(x.shape, x.dtype, x.device)

可优先核对下列三项以缩小问题范围。

习惯 2:遇到 reshape / view 报错,立刻查内存连续性

1
2
print(x.is_contiguous())
print(x.stride())

习惯 3:不要把广播当"自动的魔法"

广播有规则,遇到 shape 报错,手动推一遍:从最后一维开始对齐,哪一维出问题了?

习惯 4:阅读代码时给每一维加语义标签

1
2
3
# 不要只看 (8, 128, 768)
# 要在脑子里翻译成:
# (batch=8, seq_len=128, hidden=768)

习惯 5:模型和数据要一起移到 device

1
2
3
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
inputs = inputs.to(device)

不要忘记任何一个参与计算的 Tensor(mask、position ids、cache……)。

习惯 6:切片/索引后,考虑是否需要 clone

若需独立副本以免共享存储,应使用 .clone()


15. 串联示例:一个最小推理片段

下面用一段完整的代码把这一部分的核心概念串起来。

场景:把一个 batch 的 token 向量通过一个线性变换,输出新的向量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch

# 选择设备
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备:{device}")

# 构造输入:4 个样本,每个 8 维
x = torch.randn(4, 8, dtype=torch.float32)

# 构造权重和偏置
W = torch.randn(8, 16, dtype=torch.float32) # 线性层权重
b = torch.randn(16, dtype=torch.float32) # 偏置,shape=(16,)

# 全部移到同一 device
x = x.to(device)
W = W.to(device)
b = b.to(device)

# 计算:y = x @ W + b
y = x @ W + b
# x @ W: (4, 8) @ (8, 16) → (4, 16)
# + b: (4, 16) + (16,) → (4, 16) ← 这里发生了广播

print("y.shape =", y.shape) # torch.Size([4, 16])
print("y.dtype =", y.dtype) # torch.float32
print("y.device =", y.device) # cuda:0 (或 cpu)
print("y.is_contiguous =", y.is_contiguous()) # True
print("y.stride =", y.stride()) # (16, 1) ← 正常的连续 stride

逐行解释

问题 解释
为什么 x @ W 合法? 矩阵乘要求内层维度一致:(4, 8) @ (8, 16) → 8 == 8,合法
为什么 + b 合法? 广播:b.shape=(16,) 自动扩展到 (4, 16)
为什么输出是 float32 所有输入都是 float32,PyTorch 保持精度
为什么结果是连续的? @ 运算输出的是全新分配的 Tensor,必然是连续的

16. 本节要点与自检

下列能力可作为掌握程度的参照:

  • 看懂任意 PyTorch 张量创建代码,知道每种方式的适用场景
  • 一眼看出一个 Tensor 的 shape / dtype / device
  • 手动推断两个 Tensor 能否广播,以及结果的 shape
  • 知道 view() 为什么会失败,如何修复
  • 理解 contiguousstride 的含义,以及 transpose 后为什么内存不连续
  • 能定位最常见的 shape / device / dtype 报错
  • 读推理代码时,能说出每个 Tensor 每一维的语义

17. 思考题

建议在本地运行下列片段,并对照 print 输出与上文定义。

练习 1:创建不同 dtype 的 Tensor

1
2
3
4
5
6
7
# 创建 float32、float16、int64 的 Tensor,打印 dtype 和内存占用
x32 = torch.ones(1000, 1000, dtype=torch.float32)
x16 = torch.ones(1000, 1000, dtype=torch.float16)
xi = torch.ones(1000, 1000, dtype=torch.int64)

# 提示:用 x.element_size() 查看每个元素多少字节
# 用 x.element_size() * x.numel() 算总字节数

思考:float32 的内存是 float16 的几倍?

练习 2:检查 device

1
2
3
4
5
6
7
8
x = torch.randn(3, 4)
print(x.device) # 应该是 cpu

if torch.cuda.is_available():
x = x.to("cuda")
print(x.device) # 应该是 cuda:0
else:
print("当前环境没有 GPU")

思考:如果忘记 .to(device),后面运算时会报什么错?

练习 3:验证广播

1
2
3
4
5
6
7
8
9
x = torch.randn(2, 3)
b = torch.randn(3)
print((x + b).shape) # 应该是 (2, 3)

# 尝试一个不能广播的例子
x2 = torch.randn(2, 3)
y2 = torch.randn(2, 4)
# 取消注释下面这行,看报错信息
# print((x2 + y2).shape)

思考:报错信息中哪一句标明了不匹配的维度?

练习 4:验证 contiguous 和 view 报错

1
2
3
4
5
6
7
8
9
10
11
12
13
14
x = torch.randn(2, 3)
y = x.transpose(0, 1)

print(x.is_contiguous()) # True
print(y.is_contiguous()) # False
print(x.stride()) # (3, 1)
print(y.stride()) # (1, 3)

# 取消注释下面这行,观察报错
# z = y.view(-1)

# 正确做法:
z = y.contiguous().view(-1)
print(z.shape) # (6,)

思考:y.contiguous() 之后 y.stride() 变成什么了?

练习 5:给一个 Tensor 加语义标签

1
2
3
4
5
# 假设有一个 Transformer 模型的中间激活值
x = torch.randn(8, 128, 768)

# 问:三个维度分别最可能代表什么?
# 提示:B=batch, T=seq_len, D=hidden_dim

进阶思考:如果这是 8 头 Attention 的 query,重新排列后 shape 会变成什么?


18. 小结

推理性能与正确性问题多数可追溯到张量的 shape、dtype、device、stride 与内存连续性;后续精度、设备与图相关章节均以此为基础。


系列导航