首页 > 教程攻略 > ai教程 >Transformer原理讲解

Transformer原理讲解

来源:互联网 时间:2026-06-19 07:27:07

1. 架构图

先看架构图,有个整体印象。图中展示了Transformer的基本结构——输入经过嵌入和位置编码,再通过编码器-解码器堆叠,最后输出预测。下面逐层拆解,把每个模块的来龙去脉讲清楚。

1.1 输入的嵌入

Input Embedding 的作用是把离散的文本转换成连续的向量表示,这样神经网络才能处理文本信息。如果直接用 one-hot 编码,维度会非常高,模型算不动。解决方案就是引入一个词嵌入矩阵,把每个词映射到低维稠密向量。

1.2 位置编码

经过 word embedding 后,我们得到了词与词之间关系的表示,但词在句子中的位置信息还没体现出来。Transformer 是并行处理所有词的,不会天然知道词的先后顺序,所以必须把位置信息加进去。加上位置编码后的嵌入向量就是 Position Embedding。比如图中love这个词,位置为 pos=1,会通过正余弦函数计算出对应的位置向量。

2. q、k、v 是什么?

在自注意力计算中,第一步是把编码器的每个输入向量(词的特征表示)通过线性变换映射成三个新向量:查询向量(Query)、键向量(Key)和值向量(Value),简称 q、k、v。

那这三个东西到底代表什么?我们可以用查询的思路来理解:

  • 如果当前要查询“吃”这个词,那么“吃”就是 Query(查询),它旁边的其他词就是 Key(键)。
  • 计算相似度时,用 Query 点乘 Key 的转置,得到相关性分数。为什么要算这个?因为我们需要知道“吃”和哪些词关系最密切——比如“面条”就和“吃”密切相关,所以会给“面条”更高的权重。
  • Value 就是具体的值,比如“我”“今天”这些词本身的向量表示。

3. 注意力计算过程

具体分几步走:

  1. 计算相关性分数:第 i 个位置的 Query 和每个位置(包括自己)的 Key 做点积,得到一组分数。
  2. 把分数除以 8(论文中 Query 向量维度的平方根,即根号下 64)。这一步类似归一化,目的是让训练时的梯度更稳定。
  3. 经过 Softmax 得到权重因子,让所有位置的权重之和为 1。
  4. 把每个 Value 向量乘以对应的 Softmax 分数,然后加权求和,得到当前词的自注意力输出。

4. Self-Attention 公式

实际编码时不会逐个向量算,而是把输入打包成矩阵,用矩阵乘法一次性完成所有计算,效率更高。

5. 多头注意力机制

单一注意力有时候会“偏听偏信”,所以 Transformer 用了多个注意力头——每个头独立计算注意力,然后拼接在一起,再经过一个全连接层融合。这样模型可以从不同子空间关注不同角度的信息。

6. 解码层

解码层和编码层的结构类似,但多了一个交叉注意力模块。解码层内部顺序是:Masked 自注意力 → 交叉注意力(与编码器输出交互) → 前馈神经网络。交叉注意力让解码器能够“看到”源序列的上下文,从而生成目标序列。

7. Mask 多头注意力

在解码器的 self-attention 里,需要用 mask 把未来的词遮挡住——因为在生成第 t 个词时,模型只能看到前 t 个词,不能提前看到后面的词。这个 mask 是一个上三角矩阵,对角线上方的位置被设为负无穷,Softmax 之后权重就变成了 0。

8. 代码实现

以下是基于 PyTorch 的完整 Transformer 实现,包含自注意力、多头注意力、前馈网络、位置编码以及编码器和解码器。关键部分都加了注释,方便理解。

import torch
import torch.nn as nn
import math

# 定义自注意力
class SelfAttentionn(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, Q, K, V, mask=None):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.softmax(scores)
        attn = self.dropout(attn)
        out = torch.matmul(attn, V)
        return out, attn

# 定义多头注意力
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)
        self.attention = SelfAttentionn(dropout)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        Q = self.W_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(v).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        out, attn = self.attention(Q, K, V, mask)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_k)
        out = self.fc(out)
        out = self.dropout(out)
        return self.norm(out + q), attn

# 定义前馈网络
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
    def forward(self, x):
        out = self.fc2(self.dropout(torch.relu(self.fc1(x))))
        return self.norm(out + x)

# 编码器层
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
    def forward(self, src, src_mask=None):
        out, _ = self.self_attn(src, src, src, src_mask)
        out = self.ffn(out)
        return out

# 解码器层
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        out, _ = self.self_attn(tgt, tgt, tgt, tgt_mask)
        out, _ = self.cross_attn(out, memory, memory, memory_mask)
        out = self.ffn(out)
        return out

# 位置编码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

# 编码器
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, num_layers, d_ff, dropout=0.1, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
    def forward(self, src, src_mask=None):
        out = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)
        out = self.pos_encoding(out)
        for layer in self.layers:
            out = layer(out, src_mask)
        return out

# 解码器
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, num_layers, d_ff, dropout=0.1, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        out = self.embedding(tgt) * math.sqrt(self.embedding.embedding_dim)
        out = self.pos_encoding(out)
        for layer in self.layers:
            out = layer(out, memory, tgt_mask, memory_mask)
        return self.fc_out(out)

# 完整Transformer
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, n_heads=8,
                 num_encoder_layers=6, num_decoder_layers=6, d_ff=2048, dropout=0.1, max_len=5000):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, n_heads, num_encoder_layers, d_ff, dropout, max_len)
        self.decoder = Decoder(tgt_vocab, d_model, n_heads, num_decoder_layers, d_ff, dropout, max_len)
    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        memory = self.encoder(src, src_mask)
        out = self.decoder(tgt, memory, tgt_mask, memory_mask)
        return out

# 生成mask(防止看到未来token)
def generate_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return mask == 0

# 测试
src_vocab = 10000
tgt_vocab = 10000
model = Transformer(src_vocab, tgt_vocab)
src = torch.randint(0, src_vocab, (32, 10))
tgt = torch.randint(0, tgt_vocab, (32, 20))
tgt_mask = generate_mask(tgt.size(1)).to(tgt.device)
out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([32, 20, 10000])