Transformer基础知识

Transformer 是构成现有大模型的基础架构,理解它对于理解大模型的工作原理有着重要的意义。
Transformer 是完全基于自注意力机制的序列建模模型。
其核心特点是:

  • 抛弃了RNN,LSTM的串行结构
  • 完全并行计算,训练更快
  • 通过自注意力机制捕捉长距离依赖关系

结构简析

论文中的Transformer模型由编码器和解码器两部分组成,每部分由多个相同的层堆叠而成。

Transformer结构图
图片源自:知乎文章

1
2
3
输入序列 → Embedding + 位置编码 → Encoder 堆叠 N 层

Encoder 输出 → Decoder 堆叠 N 层 → 线性层 + Softmax → 输出序列

编码器

每个编码器层包含:

  • 多头自注意力(Multi-Head Self-Attention)
  • 残差连接 + LayerNorm
  • 前馈网络(Feed Forward Network, FFN)
  • 残差连接 + LayerNorm
    Transformer结构图
1
2
LayerNorm(x + SelfAttention(x))
→ LayerNorm(x + FFN(x))

解码器

Decoder比Encoder多了一个交叉注意力。(Encoder-Decoder Attention)
整体包括:

  • 掩码多头自注意力(Masked Multi-Head Self-Attention)
  • 残差 + LayerNorm
  • 交叉注意力(Encoder-Decoder Attention)
  • 残差 + LayerNorm
  • FFN
  • 残差 + LayerNorm

自注意力机制

  1. 对输入的embedding进行线性变换得到Query, Key, Value。(Q, K, V)
  2. 计算注意力权重
1
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
  1. 除以 sqrt(d_k) 是为了防止维度太大,内积结果过大导致softmax函数梯度消失。
  2. 其作用是每个词都能关注序列中的所有其他词,计算权重
  3. 进一步叙述
    Query(查询),表示我这个位置需要找什么
    Key(键),表示这个位置有什么,是一个检索的标签
    Value(值),表示这个位置的具体内容
    所以计算注意力权重的过程就是:用Q去匹配所有的K,计算出相似度,按照相似度加权取出V。最后得到了融合全局信息的新特征。
  4. 优势
  • 直接捕捉长距离依赖关系
  • 并行计算,训练更快(RNN需要逐步计算)
  • 可解释性较强,注意力热力图可以展示模型关注的词
  • 泛化能力强。无论句子长短,结构变化
  • 灵活建模关联,适合各种任务的特征

多头注意力

  1. 什么是多头注意力
  • 将Q, K, V用不同的线性层映射h次,得到h组Q, K, V
  • 每组Q, K, V计算注意力,得到h个不同的输出
  • 将h个输出拼接起来,再通过线性层得到最终输出
  1. 为什么要多头注意力
  • 设计初衷:让模型在不同 “子空间” 里学习多种类型的注意力关系,捕捉更丰富的特征。
  • 在实现后,大家发现。多头注意力还能够在工程上带来一些好处。多个头可以并行计算,提升效率;多个头的梯度会更稳定;多头能够更加聚焦自己的子空间

其余模块

  1. 掩码
    在解码器中,为了保证生成的文本是自回归的(即每个位置只能关注之前的位置),使用 MASK 来防止模型在计算注意力时访问未来的信息。
  2. 前馈网络
1
x → Linear → GELU/ReLU → Linear → output

一般来说,是先把维度升高(比如512升到2048),再降回原来的维度。
如果说注意力机制是“信息交互”,FFN就是“特征提纯和非线性变换”。
3. 位置编码
引入位置信息。Transformer论文是使用的正弦余弦函数来编码位置的。后续的模型也有使用learnable position embedding(可学习的位置嵌入)的,例如BERT,GPT。
4. 残差连接 + LayerNorm
层归一化:对单个样本的特征维度进行归一化,稳定训练(拉到均值为0,方差为1的分布)
残差连接:跳过当前层的输入直接加到输出上,缓解梯度消失问题,促进信息流动。

其他问题

  1. GPT等现代大模型为什么使用Decoder-only架构?
    Decoder擅长“生成”任务,能够更好地建模文本的自回归特性。它的任务是“根据前面的文本生成下一个词”,非常适合语言模型的训练目标。

改进

  • 稀疏注意力(Longformer、Performer)
  • 相对位置编码
  • 预训练 + 微调范式
  • FlashAttention 优化速度