论文

Flash Decoding == FlashAttenion-V3

<aside> 💡

在序列维度上引入并行计算(注意力);

long-context inference**;**

</aside>

Method [1]

kbgmg49g.bmp

Flash Decoding主要包含以下三个步骤(可以结合上图来看):

  1. 将keys和values分成较小的block
  2. 使用FlashAttention并行计算query与每个block的注意力(这是和FlashAttention最大的区别)。对于每个block的每行(因为一行是一个特征维度),Flash Decoding会额外记录attention values的log-sum-exp(标量值,用于第3步进行rescale)
  3. 对所有output blocks进行reduction得到最终的output,需要用log-sum-exp值来重新调整每个块的贡献