Attention进阶

传统的Attention存在以下问题:

  • 存在上下文长度约束问题
  • 计算速度慢内存占用大

因此可以从提升上下文长度加速、减少内存占用两个方向优化Attention。

Attention的改进措施有:

1. 稀疏注意力

标准注意力机制中,每个输出元素都会考虑所有输入元素,这导致了计算复杂度和内存需求的平方级增长。对于非常长的序列,这种全注意力方式不太现实。

稀疏注意力的核心思想是,不是每个输出元素都依赖于所有输入元素,而是只依赖于输入序列的一个子集。这样,可以显著减少需要计算的注意力权重数量,从而降低计算复杂度和内存需求。

稀疏注意力的常见形式包括:

  • 固定窗口注意力,每个输出元素只关注其对应的窗口内的输入元素
  • 可学习的稀疏注意力,模型通过学习来确定哪些输入元素对输出元素是重要的
  • 多尺度注意力,结合了不同尺度的稀疏注意力模式,例如同时使用局部窗口和全局注意力

2. 线性注意力

线性注意力是按照$Q(K^T V)$的顺序计算注意力,如果$n$为序列长度,$d$为head_size,那么标准Attention每个头的计算量是$2n^2d$,线性Attention每个头的计算量是$2nd^2$。如果$n>d$,那么线性Attention是比标准Attention省计算量的。

但是线性Attention有着比标准Attention更严重的低秩瓶颈(可参考苏剑林博客),所以如果切换到线性Attention后还用同一个$d$,那么线性Attention的效果会明显下降。如果要保留大致相同的效果,那么线性Attention需要用更大的$d$(一般是原来的4倍左右)。

3. 改进多头机制

该系列研究探索了不同的替代多头注意力的机制,有Multi-Query Attention,Grouped-Query Attention和Multi-head Latent Attention。

MHA

Transformer详解中,我们已经介绍了Multi-Head Attention,这里不再赘述。

如果只考虑主流的自回归LLM所使用的Causal Attention,在token by token递归生成时,新预测出来的第$t+1$个token并不会影响已经计算好的$\mathbf{k}_{\leq t}$和$\mathbf{v}_{\leq t}$。因此这部分结果可以缓存下来供后续生成调用,避免不必要的重复计算,这就是所谓的KV Cache。在推理过程中,反复加载巨大的KV Cache会导致内存开销过大。

后面发展出来的MQA、GQA和MLA,都是围绕如何减少KV Cache同时尽可能地保证效果这个主题发展而来的产物。

MQA

MQA的全称是Multi-Query Attention,它是减少KV Cache的一次非常朴素的尝试。

MQA的思路非常简单,直接让所有的Attention Head共享同一个$K$和$V$,如果有$h$个Head,MQA直接将KV Cache减少到了原来的$1/h$。MQA减少了显存占用,提升了推理速度。

目前来看,大部分任务的效果损失比较有限,而且MQA的支持者相信这部分损失可以通过进一步训练来弥补。此外,由于MQA共享了$K$和$V$,这会导致Attention的参数量减少将近一半,而为了模型总参数量不变,通常会相应地增大FFN/GLU的规模,这也能弥补一部分效果损失。

GQA

有人担心MQA对KV Cache的压缩太严重,以至于会影响模型的学习效率以及最终效果,提出了MHA和MQA之间的过渡版本GQA(Grouped-Query Attention)。

GQA的思想也很朴素,它就是将所有的Head分为$g$个组,每组共享同一对$K$和$V$。

Meta开源的LLaMA2-70B,LLaMA3全系列,DeepSeek-V1、Yi、ChatGLM系列等知名大模型都使用了GQA。

LLaMA2-70B中,GQA的$g$设置为8,其它使用了GQA的同体量模型基本也保持了这个设置。我们知道,70B这个体量的模型,如果不进行极端的量化,那么不可能部署到单卡上。单卡不行,那么就只能单机了。一般情况下,一台机器可以装8张卡,而Attention的每个Head实际上是独立运算然后拼接起来的,当$g=8$时正好可以每张卡负责计算一组$K$和$V$对应的Attention Head,这样可以尽可能保证$K$和$V$多样性的同时最大程度减少卡间通信。


Flash Attention

此外,Flash Attention利用分块Softmax等价替代传统Softmax,可以节约HBM,高效利用SRAM,节省显存,提升速度。

Meta推出的LLaMA就使用了Flash Attention来加速计算和节省显存。

并行Transformer Block

用并行公式替换了串行,提升了15%的训练速度。

在8B参数量规模,会有轻微的模型效果损失;在62B参数量规模,就不会损失模型效果。

Falcon和PaLM都使用了该技术来加速训练。

【参考文献】

  1. 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA
  2. GQA vs MHA
  3. 线性Attention的探索:Attention必须有个Softmax吗?
  4. 线性Transformer应该不是你要等的那个模型