AI Infra学习之旅-用PyTorch从零实现一个最小Transformer

这是“Transformer 原理深讲系列”的第 8 篇。
前面 7 篇我们已经把 Transformer 的主要结构讲清楚了:
从为什么需要 Transformer,到文本如何变成向量;从位置编码、Q/K/V、Attention、Multi-Head、Mask、FFN、残差、LayerNorm,到训练与推理的区别,以及手算一次 Attention、手算一个完整 block。
现在,是时候把这些纸面上的结构真正写成代码了。
这一篇的目标很明确:
不用现成的 nn.Transformer 黑箱,而是用 PyTorch 自己搭一个最小可运行的 Decoder-only Transformer。


一、为什么要自己写一个最小 Transformer

很多人学 Transformer 时,会经历两个阶段:

阶段一:只会看结构图

知道:

  • 输入先 embedding
  • 再加位置编码
  • 再过 attention
  • 再过 FFN
  • 最后预测下一个 token

但这时候你对 Transformer 的理解仍然偏“概念图”。

阶段二:开始看源码

一打开成熟框架实现,就会看到:

  • 很多工程优化
  • 很多参数细节
  • 很多和主逻辑缠在一起的兼容代码
  • 很多缓存、并行、dtype、device 相关逻辑

这时又很容易迷失在实现细节里。

所以最好的过渡方法不是直接跳进大项目,而是:

先自己写一个最小版本,把“结构”和“代码”一一对应起来。

这篇文章就是为这个目的写的。


二、这一篇我们要实现什么

我们实现的是一个极简但完整的 Decoder-only Transformer,它具备这些组件:

  1. token embedding
  2. 可学习位置 embedding
  3. Causal Self-Attention
  4. multi-head attention
  5. FFN
  6. residual connection
  7. LayerNorm
  8. 多层 Transformer Block 堆叠
  9. 输出 logits 预测下一个 token

也就是说,我们要写出的不是单个模块,而是一套能真正跑 forward 的最小语言模型。


三、我们刻意不做哪些复杂工程

为了保证“原理和代码一一对应”,这篇只做最小实现,不追求工业级完整性。

不展开的内容包括

  • KV Cache
  • FlashAttention
  • mixed precision 优化
  • 分布式并行
  • 量化
  • RoPE
  • RMSNorm
  • GQA / MQA
  • MoE
  • checkpointing
  • compile / CUDA graph

这些东西当然都重要,但它们属于“把 Transformer 做快、做大、做稳”的阶段。
而这一篇聚焦的是:

先把经典 Decoder-only Transformer 的最小骨架亲手搭起来。


四、先给你最终完整代码

下面是一份单文件、可直接运行的最小实现。
你可以把它保存为:

1
mini_transformer.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


# =========================
# 1. 配置
# =========================
@dataclass
class ModelConfig:
vocab_size: int = 100
max_seq_len: int = 64
d_model: int = 128
n_heads: int = 4
n_layers: int = 4
d_ff: int = 256
dropout: float = 0.1


# =========================
# 2. 单个 Multi-Head Causal Self-Attention
# =========================
class MultiHeadSelfAttention(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
assert config.d_model % config.n_heads == 0, 'd_model 必须能被 n_heads 整除'

self.d_model = config.d_model
self.n_heads = config.n_heads
self.head_dim = config.d_model // config.n_heads

self.q_proj = nn.Linear(config.d_model, config.d_model)
self.k_proj = nn.Linear(config.d_model, config.d_model)
self.v_proj = nn.Linear(config.d_model, config.d_model)
self.out_proj = nn.Linear(config.d_model, config.d_model)

self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)

mask = torch.triu(torch.ones(config.max_seq_len, config.max_seq_len), diagonal=1).bool()
self.register_buffer('causal_mask', mask, persistent=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.shape

q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)

q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn_scores = attn_scores.masked_fill(self.causal_mask[:T, :T], float('-inf'))

attn_weights = F.softmax(attn_scores, dim=-1)
attn_weights = self.attn_dropout(attn_weights)

out = attn_weights @ v
out = out.transpose(1, 2).contiguous().view(B, T, C)
out = self.out_proj(out)
out = self.resid_dropout(out)
return out


# =========================
# 3. FFN / MLP
# =========================
class FeedForward(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.net = nn.Sequential(
nn.Linear(config.d_model, config.d_ff),
nn.GELU(),
nn.Linear(config.d_ff, config.d_model),
nn.Dropout(config.dropout),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)


# =========================
# 4. 单个 Transformer Block(Pre-LN)
# =========================
class TransformerBlock(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.ln_1 = nn.LayerNorm(config.d_model)
self.attn = MultiHeadSelfAttention(config)
self.ln_2 = nn.LayerNorm(config.d_model)
self.ffn = FeedForward(config)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.ln_1(x))
x = x + self.ffn(self.ln_2(x))
return x


# =========================
# 5. 最小 Decoder-only Transformer
# =========================
class MiniTransformerLM(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config

self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
self.final_ln = nn.LayerNorm(config.d_model)

self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.lm_head.weight = self.token_emb.weight

self.apply(self._init_weights)

def _init_weights(self, module: nn.Module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)

def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
B, T = idx.shape
assert T <= self.config.max_seq_len, '输入长度超过 max_seq_len'

positions = torch.arange(0, T, device=idx.device).unsqueeze(0)

tok = self.token_emb(idx)
pos = self.pos_emb(positions)
x = tok + pos

for block in self.blocks:
x = block(x)

x = self.final_ln(x)
logits = self.lm_head(x)

loss = None
if targets is not None:
loss = F.cross_entropy(
logits.view(B * T, self.config.vocab_size),
targets.view(B * T)
)

return logits, loss

@torch.no_grad()
def generate(self, idx: torch.Tensor, max_new_tokens: int):
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.config.max_seq_len:]
logits, _ = self(idx_cond)
next_token_logits = logits[:, -1, :]
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
idx = torch.cat([idx, next_token], dim=1)
return idx


# =========================
# 6. 一个最小训练示例
# =========================
def make_toy_data(batch_size: int, seq_len: int, vocab_size: int, device: str):
x = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device=device)
inputs = x[:, :-1]
targets = x[:, 1:]
return inputs, targets


def main():
device = 'cuda' if torch.cuda.is_available() else 'cpu'

config = ModelConfig(
vocab_size=50,
max_seq_len=32,
d_model=64,
n_heads=4,
n_layers=2,
d_ff=128,
dropout=0.1,
)

model = MiniTransformerLM(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

print('Model parameters:', sum(p.numel() for p in model.parameters()))

model.train()
for step in range(50):
inputs, targets = make_toy_data(batch_size=16, seq_len=17, vocab_size=config.vocab_size, device=device)
_, loss = model(inputs, targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()

if step % 10 == 0:
print(f'step={step:03d} loss={loss.item():.4f}')

model.eval()
start = torch.tensor([[1, 2, 3]], device=device)
out = model.generate(start, max_new_tokens=10)
print('Generated token ids:', out.tolist())


if __name__ == '__main__':
main()

五、先从整体看代码结构

虽然这份代码不长,但它已经完整对应了一个最小 Decoder-only Transformer 的主要骨架。

结构可以压缩成:

  1. ModelConfig:定义超参数
  2. MultiHeadSelfAttention:实现 Causal Self-Attention
  3. FeedForward:实现 FFN
  4. TransformerBlock:把 attention + FFN 组织成 block
  5. MiniTransformerLM:把 embedding、多个 block、输出头堆起来
  6. generate():做最朴素的自回归生成
  7. main():给一个最小训练和生成示例

接下来,我们按模块重新拆开看。


六、Embedding 在代码里到底怎么体现

看这两行:

1
2
self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)

它们分别对应:

  • token embedding
  • position embedding

在 forward 里:

1
2
3
tok = self.token_emb(idx)            
pos = self.pos_emb(positions)
x = tok + pos

这就是我们前面博客里讲过的:

hi=xi+pih_i = x_i + p_i

也就是说:

  • token_emb 负责把 token id 映射成内容向量
  • pos_emb 负责给每个位置分配位置向量
  • 两者相加,得到真正送进 Transformer block 的输入表示

这里我们用的是可学习位置 embedding,没有用 RoPE。
原因很简单:
这一篇的目标是把骨架写清楚,而不是追求现代大模型全部细节。


七、Multi-Head Self-Attention 在代码里是怎么落地的

先看最核心的三行:

1
2
3
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)

这正对应前面原理里的:

Q=HWQ,K=HWK,V=HWVQ = HW^Q,\quad K = HW^K,\quad V = HW^V

区别只在于:

  • 数学里我们直接写矩阵乘法
  • PyTorch 里用 nn.Linear 帮你管理参数和乘法

然后是多头 reshape:

1
2
3
q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

这一步的意义是:

把原来的 [B, T, C]
拆成
[B, n_heads, T, head_dim]

也就是把总通道维度切分成多个头,让每个头在自己的子空间里做 attention。


八、Attention score 和 causal mask 在代码里怎么对应公式

核心代码是:

1
2
3
attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn_scores = attn_scores.masked_fill(self.causal_mask[:T, :T], float('-inf'))
attn_weights = F.softmax(attn_scores, dim=-1)

它对应的数学过程就是:

S=QKTdkS = \frac{QK^T}{\sqrt{d_k}}

然后加上 causal mask:

S^=S+M\hat{S} = S + M

再做:

A=softmax(S^)A = \text{softmax}(\hat{S})

这里你要特别注意 masked_fill 的作用。
我们之前讲过,causal mask 的本质是:

  • 当前位置不能看未来位置
  • 所以未来位置在 softmax 前必须被置成 (-\infty)

在代码里:

1
self.causal_mask = torch.triu(torch.ones(...), diagonal=1).bool()

生成的是上三角 True 区域,表示:

  • 这些位置属于“未来”
  • 必须屏蔽

所以这一段代码就是 Causal Self-Attention 的真正落地。


九、Value 聚合和输出投影在代码里怎么体现

接着看:

1
2
3
out = attn_weights @ v
out = out.transpose(1, 2).contiguous().view(B, T, C)
out = self.out_proj(out)

这对应的数学过程是:

O=AVO = AV

然后把所有头拼接回来,再做一次输出投影:

MultiHead(H)=Concat(head1,,headh)WO\text{MultiHead}(H)=\text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O

也就是说,代码和公式是一一对应的:

  • attn_weights @ v:每个头内部做 value 聚合
  • transpose + view:把多头结果拼回主通道
  • out_proj:做最终输出融合

十、FFN 在代码里怎么对应前面的数学讲解

看这个模块:

1
2
3
4
5
6
self.net = nn.Sequential(
nn.Linear(config.d_model, config.d_ff),
nn.GELU(),
nn.Linear(config.d_ff, config.d_model),
nn.Dropout(config.dropout),
)

这正对应:

FFN(x)=W2σ(W1x+b1)+b2\text{FFN}(x)=W_2 \sigma(W_1x+b_1)+b_2

只是这里激活函数用的是 GELU,而不是最简单的 ReLU。
这是现代 Transformer / LLM 里很常见的选择。

你可以从这段代码里直接看出我们前面讲的三个关键点:

  1. 先升维:d_model -> d_ff
  2. 再非线性激活
  3. 再降维:d_ff -> d_model

这说明 FFN 的本质,在代码里同样没有变化:

先把每个位置的表示展开到更高维空间,再做非线性加工,再压回主空间。


十一、为什么 block 用的是 Pre-LN

TransformerBlock

1
2
x = x + self.attn(self.ln_1(x))
x = x + self.ffn(self.ln_2(x))

这是非常标准的 Pre-LN Transformer Block

对应数学写法:

h=x+Attention(LN(x))h = x + \text{Attention}(\text{LN}(x))

y=h+FFN(LN(h))y = h + \text{FFN}(\text{LN}(h))

为什么这样写?

因为现代很多大模型更偏好 Pre-LN,它在深层训练时通常更稳定。
而且这也和我们前面原理博客的解释完全一致:

  • 先 LN 让子层看到稳定输入
  • 子层输出作为修正量
  • 再通过残差叠加回主表示

所以你现在应该能真正把:

  • 公式
  • 结构图
  • 代码

连成一条线。


十二、MiniTransformerLM 这一层壳子到底做了什么

这个类本质上是在做三件事:

1. 处理输入

1
2
3
tok = self.token_emb(idx)
pos = self.pos_emb(positions)
x = tok + pos

2. 通过多层 block 做深层加工

1
2
for block in self.blocks:
x = block(x)

3. 输出 logits

1
2
x = self.final_ln(x)
logits = self.lm_head(x)

下面就是一个最小 Decoder-only Transformer 的完整数据流:

token idsembedding + pos多个 Transformer blocksfinal LNlogits\text{token ids} \rightarrow \text{embedding + pos} \rightarrow \text{多个 Transformer blocks} \rightarrow \text{final LN} \rightarrow \text{logits}

其中 lm_head 的作用,就是把隐藏状态重新投影回词表维度,得到:

ztRVz_t \in \mathbb{R}^{V}

这正对应语言模型训练里每个位置的 next-token 分布。


十三、为什么这里用了权重共享

代码里有这句:

1
self.lm_head.weight = self.token_emb.weight

这表示输出头和 token embedding 共享参数。
很多语言模型都会这样做。

为什么可行?

因为:

  • 输入 embedding:把 token id 映射到向量空间
  • 输出头:把向量空间映射回词表打分空间

它们在数学上有很强的对偶关系,所以共享参数往往既合理,又能减少参数量。

这一点不是理解 Transformer 骨架的必要条件,但它是一个很常见、也很有代表性的实现细节。


十四、训练时 forward 和 loss 是怎么对应前面理论的

看这一段:

1
2
3
4
loss = F.cross_entropy(
logits.view(B * T, self.config.vocab_size),
targets.view(B * T)
)

这对应前面原理里讲的:

  • 每个位置输出一个词表大小维度 logits
  • 把每个位置都看成一次分类
  • 目标是预测“下一个 token”

这里之所以要 view(B * T, vocab_size),是因为交叉熵函数常常希望输入形状是:

  • [样本数, 类别数]

而对语言模型来说:

  • 每个 batch 中的每个位置,其实都是一个训练样本

所以把 [B, T, V] 展平成 [B*T, V] 是非常自然的。

这也正对应了 Teacher Forcing 训练的本质:

在完整已知序列上,并行优化所有位置的 next-token 预测。


十五、为什么 generate() 这么写,刚好对应前面讲过的推理本质

看生成函数:

1
2
3
4
5
6
7
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.config.max_seq_len:]
logits, _ = self(idx_cond)
next_token_logits = logits[:, -1, :]
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
idx = torch.cat([idx, next_token], dim=1)

你应该已经能看懂,这几乎就是前面推理原理的代码翻译版:

  1. 拿当前前缀
  2. 跑一遍模型
  3. 取最后一个位置的 logits
  4. 得到下一个 token 的分布
  5. 采样一个 token
  6. 拼回输入
  7. 继续下一轮

这就是最朴素的自回归生成。

同时你也能一眼看出它的低效之处:

  • 每一步都重新把整个前缀送进模型
  • 没有 KV Cache
  • 所以随着序列变长,会越来越慢

而这,恰恰对应了我们前面讲的:

为什么推理系统里必须引入 KV Cache。

所以这段代码虽然“笨”,但它非常有价值,因为它把原理和推理瓶颈暴露得非常清楚。


十六、这份最小实现和真实大模型实现相比,缺了什么

到这里你已经有了一个真正能跑的最小 Transformer。
但它距离工业级 LLM 当然还有很长的距离。

真实大模型一般还会进一步加入很多东西,例如:

结构层

  • RoPE
  • RMSNorm
  • SwiGLU
  • GQA / MQA
  • MoE

推理层

  • KV Cache
  • FlashAttention
  • PagedAttention
  • prefix caching

系统层

  • continuous batching
  • tensor / pipeline / expert parallel
  • PD 分离
  • Distributed KV Cache

但这不影响这份代码的价值。
因为它让你先抓住了那个最重要的骨架:

如果骨架不清楚,后面所有优化都会变成“记工程名词”;如果骨架清楚,后面所有优化都能顺着结构自然长出来。


十七、你现在应该如何使用这份代码

这份代码的正确用法,不只是“复制运行”。
更重要的是你可以拿它做下面这些练习。

1. 改小模型尺寸,打印每层 shape

例如在 forward() 里加 print,看每个张量的维度变化。

2. 去掉 Multi-Head,只保留单头

观察代码如何退化成最基础的 attention。

3. 把 GELU 换成 ReLU

体会 FFN 激活函数的替换。

4. 把可学习位置 embedding 换成 sinusoidal position encoding

把这一篇和前面位置编码原理真正连接起来。

5. 给 generate() 加 KV Cache

这是从“原理代码”迈向“推理工程代码”的最好练习之一。


十八、从这篇代码里,你应该真正学到的不是“会抄实现”,而是“会对结构做映射”

如果你只是抄一遍这份代码,它的价值很有限。
这篇真正想让你建立的能力是:

看到数学公式时,能想到对应代码;
看到代码实现时,能反推它对应的结构功能。

例如:

  • 看到 q_proj/k_proj/v_proj,你应该想到 Q/K/V 投影
  • 看到 masked_fill(..., -inf),你应该想到 causal mask
  • 看到 x + self.attn(...),你应该想到残差连接
  • 看到 nn.LayerNorm(...),你应该想到 Pre-LN / Post-LN 的组织方式
  • 看到 generate() 的循环,你应该想到自回归 decode 的串行本质

这才是真正从“看懂博客”走向“能自己写模型”的关键一步。


十九、本篇真正要记住的五条主线

第一条:最小 Transformer 代码骨架

  • embedding
  • multi-head Causal Self-Attention
  • FFN
  • Transformer Block 堆叠
  • 输出 logits

第二条:代码和数学公式是一一对应的

  • q_proj/k_proj/v_proj 对应 Q/K/V
  • q @ k.transpose(...) / sqrt(d) 对应 score
  • softmax 对应注意力权重
  • attn_weights @ v 对应 value 聚合

第三条:block 的组织方式

  • Pre-LN
  • attention 残差
  • FFN 残差

第四条:训练对应 Teacher Forcing

  • 完整前缀已知
  • 每个位置都参与 next-token loss

第五条:朴素 generate() 会暴露推理的真正瓶颈

  • 每一步都重算前缀
  • 没有 KV Cache
  • 因此天然低效

二十、用一句话压缩本篇

用 PyTorch 从零实现一个最小 Decoder-only Transformer 的关键,不是记住某段现成代码,而是把 embedding、Q/K/V、causal attention、FFN、残差、LayerNorm 和自回归生成这些结构原理,一一映射成可运行的张量计算过程。


二十一、下一篇最自然的衔接

如果继续写下去,最自然的下一篇是:

从经典 Transformer 到现代大模型变体:RoPE、RMSNorm、GQA/MQA、SwiGLU 与 MoE

因为到这里,你已经真正掌握了“经典骨架”。
下一步最值得做的,就是回答:

  • 为什么今天的大模型已经不完全长得像 2017 年论文里的 Transformer
  • RoPE 到底替代了什么
  • RMSNorm 和 LayerNorm 有什么区别
  • GQA / MQA 为什么会和 KV Cache 强相关
  • SwiGLU 为什么比普通 FFN 更常见
  • MoE 为什么会把问题从模型设计推向系统与通信