这是“Transformer 原理深讲系列”的第 8 篇。
前面 7 篇我们已经把 Transformer 的主要结构讲清楚了:
从为什么需要 Transformer,到文本如何变成向量;从位置编码、Q/K/V、Attention、Multi-Head、Mask、FFN、残差、LayerNorm,到训练与推理的区别,以及手算一次 Attention、手算一个完整 block。
现在,是时候把这些纸面上的结构真正写成代码了。
这一篇的目标很明确:
不用现成的 nn.Transformer 黑箱,而是用 PyTorch 自己搭一个最小可运行的 Decoder-only Transformer。
很多人学 Transformer 时,会经历两个阶段:
阶段一:只会看结构图
知道:
- 输入先 embedding
- 再加位置编码
- 再过 attention
- 再过 FFN
- 最后预测下一个 token
但这时候你对 Transformer 的理解仍然偏“概念图”。
阶段二:开始看源码
一打开成熟框架实现,就会看到:
- 很多工程优化
- 很多参数细节
- 很多和主逻辑缠在一起的兼容代码
- 很多缓存、并行、dtype、device 相关逻辑
这时又很容易迷失在实现细节里。
所以最好的过渡方法不是直接跳进大项目,而是:
先自己写一个最小版本,把“结构”和“代码”一一对应起来。
这篇文章就是为这个目的写的。
二、这一篇我们要实现什么
我们实现的是一个极简但完整的 Decoder-only Transformer,它具备这些组件:
- token embedding
- 可学习位置 embedding
- Causal Self-Attention
- multi-head attention
- FFN
- residual connection
- LayerNorm
- 多层 Transformer Block 堆叠
- 输出 logits 预测下一个 token
也就是说,我们要写出的不是单个模块,而是一套能真正跑 forward 的最小语言模型。
三、我们刻意不做哪些复杂工程
为了保证“原理和代码一一对应”,这篇只做最小实现,不追求工业级完整性。
不展开的内容包括
- KV Cache
- FlashAttention
- mixed precision 优化
- 分布式并行
- 量化
- RoPE
- RMSNorm
- GQA / MQA
- MoE
- checkpointing
- compile / CUDA graph
这些东西当然都重要,但它们属于“把 Transformer 做快、做大、做稳”的阶段。
而这一篇聚焦的是:
先把经典 Decoder-only Transformer 的最小骨架亲手搭起来。
四、先给你最终完整代码
下面是一份单文件、可直接运行的最小实现。
你可以把它保存为:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
| import math from dataclasses import dataclass from typing import Optional
import torch import torch.nn as nn import torch.nn.functional as F
@dataclass class ModelConfig: vocab_size: int = 100 max_seq_len: int = 64 d_model: int = 128 n_heads: int = 4 n_layers: int = 4 d_ff: int = 256 dropout: float = 0.1
class MultiHeadSelfAttention(nn.Module): def __init__(self, config: ModelConfig): super().__init__() assert config.d_model % config.n_heads == 0, 'd_model 必须能被 n_heads 整除'
self.d_model = config.d_model self.n_heads = config.n_heads self.head_dim = config.d_model // config.n_heads
self.q_proj = nn.Linear(config.d_model, config.d_model) self.k_proj = nn.Linear(config.d_model, config.d_model) self.v_proj = nn.Linear(config.d_model, config.d_model) self.out_proj = nn.Linear(config.d_model, config.d_model)
self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout)
mask = torch.triu(torch.ones(config.max_seq_len, config.max_seq_len), diagonal=1).bool() self.register_buffer('causal_mask', mask, persistent=False)
def forward(self, x: torch.Tensor) -> torch.Tensor: B, T, C = x.shape
q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x)
q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim) attn_scores = attn_scores.masked_fill(self.causal_mask[:T, :T], float('-inf'))
attn_weights = F.softmax(attn_scores, dim=-1) attn_weights = self.attn_dropout(attn_weights)
out = attn_weights @ v out = out.transpose(1, 2).contiguous().view(B, T, C) out = self.out_proj(out) out = self.resid_dropout(out) return out
class FeedForward(nn.Module): def __init__(self, config: ModelConfig): super().__init__() self.net = nn.Sequential( nn.Linear(config.d_model, config.d_ff), nn.GELU(), nn.Linear(config.d_ff, config.d_model), nn.Dropout(config.dropout), )
def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x)
class TransformerBlock(nn.Module): def __init__(self, config: ModelConfig): super().__init__() self.ln_1 = nn.LayerNorm(config.d_model) self.attn = MultiHeadSelfAttention(config) self.ln_2 = nn.LayerNorm(config.d_model) self.ffn = FeedForward(config)
def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.ln_1(x)) x = x + self.ffn(self.ln_2(x)) return x
class MiniTransformerLM(nn.Module): def __init__(self, config: ModelConfig): super().__init__() self.config = config
self.token_emb = nn.Embedding(config.vocab_size, config.d_model) self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model) self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) self.final_ln = nn.LayerNorm(config.d_model)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) self.lm_head.weight = self.token_emb.weight
self.apply(self._init_weights)
def _init_weights(self, module: nn.Module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None): B, T = idx.shape assert T <= self.config.max_seq_len, '输入长度超过 max_seq_len'
positions = torch.arange(0, T, device=idx.device).unsqueeze(0)
tok = self.token_emb(idx) pos = self.pos_emb(positions) x = tok + pos
for block in self.blocks: x = block(x)
x = self.final_ln(x) logits = self.lm_head(x)
loss = None if targets is not None: loss = F.cross_entropy( logits.view(B * T, self.config.vocab_size), targets.view(B * T) )
return logits, loss
@torch.no_grad() def generate(self, idx: torch.Tensor, max_new_tokens: int): for _ in range(max_new_tokens): idx_cond = idx[:, -self.config.max_seq_len:] logits, _ = self(idx_cond) next_token_logits = logits[:, -1, :] probs = F.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) idx = torch.cat([idx, next_token], dim=1) return idx
def make_toy_data(batch_size: int, seq_len: int, vocab_size: int, device: str): x = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device=device) inputs = x[:, :-1] targets = x[:, 1:] return inputs, targets
def main(): device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = ModelConfig( vocab_size=50, max_seq_len=32, d_model=64, n_heads=4, n_layers=2, d_ff=128, dropout=0.1, )
model = MiniTransformerLM(config).to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
print('Model parameters:', sum(p.numel() for p in model.parameters()))
model.train() for step in range(50): inputs, targets = make_toy_data(batch_size=16, seq_len=17, vocab_size=config.vocab_size, device=device) _, loss = model(inputs, targets)
optimizer.zero_grad() loss.backward() optimizer.step()
if step % 10 == 0: print(f'step={step:03d} loss={loss.item():.4f}')
model.eval() start = torch.tensor([[1, 2, 3]], device=device) out = model.generate(start, max_new_tokens=10) print('Generated token ids:', out.tolist())
if __name__ == '__main__': main()
|
五、先从整体看代码结构
虽然这份代码不长,但它已经完整对应了一个最小 Decoder-only Transformer 的主要骨架。
结构可以压缩成:
ModelConfig:定义超参数
MultiHeadSelfAttention:实现 Causal Self-Attention
FeedForward:实现 FFN
TransformerBlock:把 attention + FFN 组织成 block
MiniTransformerLM:把 embedding、多个 block、输出头堆起来
generate():做最朴素的自回归生成
main():给一个最小训练和生成示例
接下来,我们按模块重新拆开看。
六、Embedding 在代码里到底怎么体现
看这两行:
1 2
| self.token_emb = nn.Embedding(config.vocab_size, config.d_model) self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
|
它们分别对应:
- token embedding
- position embedding
在 forward 里:
1 2 3
| tok = self.token_emb(idx) pos = self.pos_emb(positions) x = tok + pos
|
这就是我们前面博客里讲过的:
hi=xi+pi
也就是说:
token_emb 负责把 token id 映射成内容向量
pos_emb 负责给每个位置分配位置向量
- 两者相加,得到真正送进 Transformer block 的输入表示
这里我们用的是可学习位置 embedding,没有用 RoPE。
原因很简单:
这一篇的目标是把骨架写清楚,而不是追求现代大模型全部细节。
七、Multi-Head Self-Attention 在代码里是怎么落地的
先看最核心的三行:
1 2 3
| q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x)
|
这正对应前面原理里的:
Q=HWQ,K=HWK,V=HWV
区别只在于:
- 数学里我们直接写矩阵乘法
- PyTorch 里用
nn.Linear 帮你管理参数和乘法
然后是多头 reshape:
1 2 3
| q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
|
这一步的意义是:
把原来的 [B, T, C]
拆成
[B, n_heads, T, head_dim]
也就是把总通道维度切分成多个头,让每个头在自己的子空间里做 attention。
八、Attention score 和 causal mask 在代码里怎么对应公式
核心代码是:
1 2 3
| attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim) attn_scores = attn_scores.masked_fill(self.causal_mask[:T, :T], float('-inf')) attn_weights = F.softmax(attn_scores, dim=-1)
|
它对应的数学过程就是:
S=dkQKT
然后加上 causal mask:
S^=S+M
再做:
A=softmax(S^)
这里你要特别注意 masked_fill 的作用。
我们之前讲过,causal mask 的本质是:
- 当前位置不能看未来位置
- 所以未来位置在 softmax 前必须被置成 (-\infty)
在代码里:
1
| self.causal_mask = torch.triu(torch.ones(...), diagonal=1).bool()
|
生成的是上三角 True 区域,表示:
所以这一段代码就是 Causal Self-Attention 的真正落地。
九、Value 聚合和输出投影在代码里怎么体现
接着看:
1 2 3
| out = attn_weights @ v out = out.transpose(1, 2).contiguous().view(B, T, C) out = self.out_proj(out)
|
这对应的数学过程是:
O=AV
然后把所有头拼接回来,再做一次输出投影:
MultiHead(H)=Concat(head1,…,headh)WO
也就是说,代码和公式是一一对应的:
attn_weights @ v:每个头内部做 value 聚合
transpose + view:把多头结果拼回主通道
out_proj:做最终输出融合
十、FFN 在代码里怎么对应前面的数学讲解
看这个模块:
1 2 3 4 5 6
| self.net = nn.Sequential( nn.Linear(config.d_model, config.d_ff), nn.GELU(), nn.Linear(config.d_ff, config.d_model), nn.Dropout(config.dropout), )
|
这正对应:
FFN(x)=W2σ(W1x+b1)+b2
只是这里激活函数用的是 GELU,而不是最简单的 ReLU。
这是现代 Transformer / LLM 里很常见的选择。
你可以从这段代码里直接看出我们前面讲的三个关键点:
- 先升维:
d_model -> d_ff
- 再非线性激活
- 再降维:
d_ff -> d_model
这说明 FFN 的本质,在代码里同样没有变化:
先把每个位置的表示展开到更高维空间,再做非线性加工,再压回主空间。
十一、为什么 block 用的是 Pre-LN
看 TransformerBlock:
1 2
| x = x + self.attn(self.ln_1(x)) x = x + self.ffn(self.ln_2(x))
|
这是非常标准的 Pre-LN Transformer Block。
对应数学写法:
h=x+Attention(LN(x))
y=h+FFN(LN(h))
为什么这样写?
因为现代很多大模型更偏好 Pre-LN,它在深层训练时通常更稳定。
而且这也和我们前面原理博客的解释完全一致:
- 先 LN 让子层看到稳定输入
- 子层输出作为修正量
- 再通过残差叠加回主表示
所以你现在应该能真正把:
连成一条线。
这个类本质上是在做三件事:
1. 处理输入
1 2 3
| tok = self.token_emb(idx) pos = self.pos_emb(positions) x = tok + pos
|
2. 通过多层 block 做深层加工
1 2
| for block in self.blocks: x = block(x)
|
3. 输出 logits
1 2
| x = self.final_ln(x) logits = self.lm_head(x)
|
下面就是一个最小 Decoder-only Transformer 的完整数据流:
token ids→embedding + pos→多个 Transformer blocks→final LN→logits
其中 lm_head 的作用,就是把隐藏状态重新投影回词表维度,得到:
zt∈RV
这正对应语言模型训练里每个位置的 next-token 分布。
十三、为什么这里用了权重共享
代码里有这句:
1
| self.lm_head.weight = self.token_emb.weight
|
这表示输出头和 token embedding 共享参数。
很多语言模型都会这样做。
为什么可行?
因为:
- 输入 embedding:把 token id 映射到向量空间
- 输出头:把向量空间映射回词表打分空间
它们在数学上有很强的对偶关系,所以共享参数往往既合理,又能减少参数量。
这一点不是理解 Transformer 骨架的必要条件,但它是一个很常见、也很有代表性的实现细节。
十四、训练时 forward 和 loss 是怎么对应前面理论的
看这一段:
1 2 3 4
| loss = F.cross_entropy( logits.view(B * T, self.config.vocab_size), targets.view(B * T) )
|
这对应前面原理里讲的:
- 每个位置输出一个词表大小维度 logits
- 把每个位置都看成一次分类
- 目标是预测“下一个 token”
这里之所以要 view(B * T, vocab_size),是因为交叉熵函数常常希望输入形状是:
而对语言模型来说:
- 每个 batch 中的每个位置,其实都是一个训练样本
所以把 [B, T, V] 展平成 [B*T, V] 是非常自然的。
这也正对应了 Teacher Forcing 训练的本质:
在完整已知序列上,并行优化所有位置的 next-token 预测。
十五、为什么 generate() 这么写,刚好对应前面讲过的推理本质
看生成函数:
1 2 3 4 5 6 7
| for _ in range(max_new_tokens): idx_cond = idx[:, -self.config.max_seq_len:] logits, _ = self(idx_cond) next_token_logits = logits[:, -1, :] probs = F.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) idx = torch.cat([idx, next_token], dim=1)
|
你应该已经能看懂,这几乎就是前面推理原理的代码翻译版:
- 拿当前前缀
- 跑一遍模型
- 取最后一个位置的 logits
- 得到下一个 token 的分布
- 采样一个 token
- 拼回输入
- 继续下一轮
这就是最朴素的自回归生成。
同时你也能一眼看出它的低效之处:
- 每一步都重新把整个前缀送进模型
- 没有 KV Cache
- 所以随着序列变长,会越来越慢
而这,恰恰对应了我们前面讲的:
为什么推理系统里必须引入 KV Cache。
所以这段代码虽然“笨”,但它非常有价值,因为它把原理和推理瓶颈暴露得非常清楚。
十六、这份最小实现和真实大模型实现相比,缺了什么
到这里你已经有了一个真正能跑的最小 Transformer。
但它距离工业级 LLM 当然还有很长的距离。
真实大模型一般还会进一步加入很多东西,例如:
结构层
- RoPE
- RMSNorm
- SwiGLU
- GQA / MQA
- MoE
推理层
- KV Cache
- FlashAttention
- PagedAttention
- prefix caching
系统层
- continuous batching
- tensor / pipeline / expert parallel
- PD 分离
- Distributed KV Cache
但这不影响这份代码的价值。
因为它让你先抓住了那个最重要的骨架:
如果骨架不清楚,后面所有优化都会变成“记工程名词”;如果骨架清楚,后面所有优化都能顺着结构自然长出来。
十七、你现在应该如何使用这份代码
这份代码的正确用法,不只是“复制运行”。
更重要的是你可以拿它做下面这些练习。
1. 改小模型尺寸,打印每层 shape
例如在 forward() 里加 print,看每个张量的维度变化。
2. 去掉 Multi-Head,只保留单头
观察代码如何退化成最基础的 attention。
3. 把 GELU 换成 ReLU
体会 FFN 激活函数的替换。
4. 把可学习位置 embedding 换成 sinusoidal position encoding
把这一篇和前面位置编码原理真正连接起来。
5. 给 generate() 加 KV Cache
这是从“原理代码”迈向“推理工程代码”的最好练习之一。
十八、从这篇代码里,你应该真正学到的不是“会抄实现”,而是“会对结构做映射”
如果你只是抄一遍这份代码,它的价值很有限。
这篇真正想让你建立的能力是:
看到数学公式时,能想到对应代码;
看到代码实现时,能反推它对应的结构功能。
例如:
- 看到
q_proj/k_proj/v_proj,你应该想到 Q/K/V 投影
- 看到
masked_fill(..., -inf),你应该想到 causal mask
- 看到
x + self.attn(...),你应该想到残差连接
- 看到
nn.LayerNorm(...),你应该想到 Pre-LN / Post-LN 的组织方式
- 看到
generate() 的循环,你应该想到自回归 decode 的串行本质
这才是真正从“看懂博客”走向“能自己写模型”的关键一步。
十九、本篇真正要记住的五条主线
- embedding
- multi-head Causal Self-Attention
- FFN
- Transformer Block 堆叠
- 输出 logits
第二条:代码和数学公式是一一对应的
q_proj/k_proj/v_proj 对应 Q/K/V
q @ k.transpose(...) / sqrt(d) 对应 score
softmax 对应注意力权重
attn_weights @ v 对应 value 聚合
第三条:block 的组织方式
- Pre-LN
- attention 残差
- FFN 残差
第四条:训练对应 Teacher Forcing
- 完整前缀已知
- 每个位置都参与 next-token loss
第五条:朴素 generate() 会暴露推理的真正瓶颈
- 每一步都重算前缀
- 没有 KV Cache
- 因此天然低效
二十、用一句话压缩本篇
用 PyTorch 从零实现一个最小 Decoder-only Transformer 的关键,不是记住某段现成代码,而是把 embedding、Q/K/V、causal attention、FFN、残差、LayerNorm 和自回归生成这些结构原理,一一映射成可运行的张量计算过程。
二十一、下一篇最自然的衔接
如果继续写下去,最自然的下一篇是:
因为到这里,你已经真正掌握了“经典骨架”。
下一步最值得做的,就是回答:
- 为什么今天的大模型已经不完全长得像 2017 年论文里的 Transformer
- RoPE 到底替代了什么
- RMSNorm 和 LayerNorm 有什么区别
- GQA / MQA 为什么会和 KV Cache 强相关
- SwiGLU 为什么比普通 FFN 更常见
- MoE 为什么会把问题从模型设计推向系统与通信