混合精度和显存计算

混合精度计算

在深度学习中,常用混合精度计算来加速模型训练。举例来说,线性层或者卷积层在使用float16计算时速度要比使用float32快得多,而且这些算子对精度要求没有那么高;减法操作等算子则对精度有较高要求,一般需要float32精度。PyTorch提供自动混合精度(torch.cuda.amp)方法,尝试将每个操作和适当的数值类型匹配,以此减少模型的内存占用和训练时长。

数据类型 符号占位 指数占位 尾数占位
fp16 1 5 10
bf16 1 8 7
fp32 1 8 23

一般来说有以上三种浮点数类型,其中指数占位代表了浮点数表示的范围,尾数占位代表了浮点数表示的精度

当然,使用混合精度计算也不是没有缺点,需要搭配一些优化策略:

  • 梯度下溢
    在训练过程中,特别是在使用低精度浮点数(如float16)时,梯度可能会变得非常小,导致数值不稳定性,这种现象称为梯度下溢
  • 梯度缩放
    对于使用float16梯度的网络,梯度缩放在反向传播之前放大梯度,减少梯度下溢的问题,从而帮助网络更好地收敛

混合精度计算的优缺点:

  • 优点
    • 混合精度训练可以提高训练速度,减少内存使用
    • 同时使用fp32进行关键的计算,以保证准确性
  • 缺点
    • 混合精度训练可能会导致数值不稳定,特别是在模型梯度非常小或非常大的时候,还需要额外的校准步骤来确保fp16计算的准确性

训练显存

在模型训练时,除了存储模型权重需要占用显存,存储优化器状态、计算梯度也需要占用显存。所以训练占用的显存要比推理时更多。

模型权重

如果全部使用fp32精度的浮点数类型,那么显存占用和模型参数存在这样的关系:

如果全部使用fp16精度的浮点数类型,那么显存占用和模型参数存在这样的关系:

而如果使用混合精度,显存占用和模型参数的关系如下:

还要再加上优化器状态中的fp32版本的模型参数,只不过不放在这里计算,放到优化器状态中计算。

梯度

梯度既可以使用fp32类型也可以使用fp16类型,不过一般和模型的数据类型一致。

所以如果是fp32类型,显存占用关系如下:

如果是fp16类型,显存占用关系如下:

优化器状态

在优化器更新模型参数时,会使用float32类型的优化器状态、float32类型的梯度和float32类型的模型参数来更新。

选择不同的优化器,对显存的占用是不一样的:

  • AdamW:$Memory_{optimizer} = (16 \, bytes / param) * (No.params)$

    • fp32版本的模型参数:$4 \, bytes / param$
    • fp32版本的梯度:$4 \, bytes / param$
    • 一阶动量:$4 \, bytes / param$
    • 二阶动量:$4 \, bytes / param$
  • 8-bit的优化器,如bitsandbytes:$Memory_{optimizer} = (10 \, bytes / param) * (No.params)$

    • fp32版本的模型参数:$4 \, bytes / param$
    • fp32版本的梯度:$4 \, bytes / param$
    • 一阶动量:$1 \, bytes / param$
    • 二阶动量:$1 \, bytes / param$
  • 带动量的SGD优化器:$Memory_{optimizer} = (12 \, bytes / param) * (No.params)$

    • fp32版本的模型参数:$4 \, bytes / param$
    • fp32版本的梯度:$4 \, bytes / param$
    • 动量:$4 \, bytes / param$

当然在实际训练时,还需要存储中间激活值以便在反向传播计算梯度时使用。这里的激活指的是前向传播过程中计算得到,并在反向传播过程中需要用到的所有张量。这里的激活不包含模型参数和优化器状态,但是包含了Dropout操作需要用到的mask矩阵。中间激活值和输入数据的大小(批次大小和序列长度)是成正相关的。

当我们训练神经网络遇到显存不足(OOM,Out Of Memory)问题时,通常会尝试减小批次大小来避免显存不足的问题,这种方式减少的其实是中间激活占用的显存,而不是模型参数、梯度和优化器的显存。

总得来说,训练时需要的显存来自多个方面:

对于大模型来说,需要的显存实在是太大,即使是一张A100(80G显存)也不可能存放得下所有参数。因此DeepSpeed提出了ZeRO的分布式训练方法:

  • ZeRO-1
  • ZeRO-2
  • ZeRO-3

推理显存

神经网络推理阶段,没有优化器状态和梯度,也不需要保存中间激活,此时占用的显存要远小于训练阶段。

在模型推理时,显存占用分为两部分:1)存放模型的权重和前向传播时需要的一小部分额外参数,如输入数据需要放在GPU上;2)存放模型的KV Cache,这部分不仅依赖模型的参数大小,还依赖模型的输入长度,会在推理过程中动态增长。当context长度足够长时,第二部分的显存大小就会占据主导地位,导致Out Of Memory问题。

典型的大模型推断包含了两个阶段:

  1. 预填充阶段:输入一个prompt,为每个transformer层生成key cache和value cache(KV Cache)
  2. 解码阶段:使用并更新KV cache,一个接一个地生成token

在GPU上部署模型的原则是:能一张卡部署,就不要跨多张卡;能一台机器部署,就不要跨多台机。这是因为卡内通信带宽 > 卡间通信带宽 > 机间通信带宽,模型部署时跨的设备越多,受设备通信带宽的拖累就越大。

所以,减少KV Cache可以实现在更少的设备上推理更长的context,或者在相同的context长度下让推理的batch size更大,从而实现更快的推理速度或者更大的吞吐量。FlashAttention就是优化了KV Cache,从而实现了更低的推理成本。

【参考文献】

  1. Transformer 101 Math
  2. Transformer Inference Arithmetic
  3. PyTorch Document
  4. 分析transformer模型的参数量、计算量、中间激活、KV cache