传统的Attention存在以下问题:
- 存在上下文长度约束问题
- 计算速度慢,内存占用大
因此可以从提升上下文长度和加速、减少内存占用两个方向优化Attention。
Attention的改进措施有:
1. 稀疏注意力
标准注意力机制中,每个输出元素都会考虑所有输入元素,这导致了计算复杂度和内存需求的平方级增长。对于非常长的序列,这种全注意力方式不太现实。
稀疏注意力的核心思想是,不是每个输出元素都依赖于所有输入元素,而是只依赖于输入序列的一个子集。这样,可以显著减少需要计算的注意力权重数量,从而降低计算复杂度和内存需求。
稀疏注意力的常见形式包括:
- 固定窗口注意力,每个输出元素只关注其对应的窗口内的输入元素
- 可学习的稀疏注意力,模型通过学习来确定哪些输入元素对输出元素是重要的
- 多尺度注意力,结合了不同尺度的稀疏注意力模式,例如同时使用局部窗口和全局注意力
2. 线性注意力
线性注意力是按照$Q(K^T V)$的顺序计算注意力,如果$n$为序列长度,$d$为head_size
,那么标准Attention每个头的计算量是$2n^2d$,线性Attention每个头的计算量是$2nd^2$。如果$n>d$,那么线性Attention是比标准Attention省计算量的。
但是线性Attention有着比标准Attention更严重的低秩瓶颈(可参考苏剑林博客),所以如果切换到线性Attention后还用同一个$d$,那么线性Attention的效果会明显下降。如果要保留大致相同的效果,那么线性Attention需要用更大的$d$(一般是原来的4倍左右)。
3. 改进多头机制
该系列研究探索了不同的替代多头注意力的机制,有Multi-Query Attention,Grouped-Query Attention和Multi-head Latent Attention。
MHA
在Transformer详解中,我们已经介绍了Multi-Head Attention,这里不再赘述。
如果只考虑主流的自回归LLM所使用的Causal Attention,在token by token递归生成时,新预测出来的第$t+1$个token并不会影响已经计算好的$\mathbf{k}_{\leq t}$和$\mathbf{v}_{\leq t}$。因此这部分结果可以缓存下来供后续生成调用,避免不必要的重复计算,这就是所谓的KV Cache。在推理过程中,反复加载巨大的KV Cache会导致内存开销过大。
后面发展出来的MQA、GQA和MLA,都是围绕如何减少KV Cache同时尽可能地保证效果这个主题发展而来的产物。
MQA
MQA的全称是Multi-Query Attention,它是减少KV Cache的一次非常朴素的尝试。
MQA的思路非常简单,直接让所有的Attention Head共享同一个$K$和$V$,如果有$h$个Head,MQA直接将KV Cache减少到了原来的$1/h$。MQA减少了显存占用,提升了推理速度。
目前来看,大部分任务的效果损失比较有限,而且MQA的支持者相信这部分损失可以通过进一步训练来弥补。此外,由于MQA共享了$K$和$V$,这会导致Attention的参数量减少将近一半,而为了模型总参数量不变,通常会相应地增大FFN/GLU的规模,这也能弥补一部分效果损失。
GQA
有人担心MQA对KV Cache的压缩太严重,以至于会影响模型的学习效率以及最终效果,提出了MHA和MQA之间的过渡版本GQA(Grouped-Query Attention)。
GQA的思想也很朴素,它就是将所有的Head分为$g$个组,每组共享同一对$K$和$V$。
Meta开源的LLaMA2-70B,LLaMA3全系列,DeepSeek-V1、Yi、ChatGLM系列等知名大模型都使用了GQA。
LLaMA2-70B中,GQA的$g$设置为8,其它使用了GQA的同体量模型基本也保持了这个设置。我们知道,70B这个体量的模型,如果不进行极端的量化,那么不可能部署到单卡上。单卡不行,那么就只能单机了。一般情况下,一台机器可以装8张卡,而Attention的每个Head实际上是独立运算然后拼接起来的,当$g=8$时正好可以每张卡负责计算一组$K$和$V$对应的Attention Head,这样可以尽可能保证$K$和$V$多样性的同时最大程度减少卡间通信。

Flash Attention
此外,Flash Attention利用分块Softmax等价替代传统Softmax,可以节约HBM,高效利用SRAM,节省显存,提升速度。
Meta推出的LLaMA就使用了Flash Attention来加速计算和节省显存。
并行Transformer Block
用并行公式替换了串行,提升了15%的训练速度。
在8B参数量规模,会有轻微的模型效果损失;在62B参数量规模,就不会损失模型效果。
Falcon和PaLM都使用了该技术来加速训练。
【参考文献】