KV-Runahead Scalable Causal LLM Inference by Parallel Key-Value Cache Generation
Skimming Author Info Background Challenges Insights Approaches 看了好几遍都没看懂,我大概的理解是 利用了 casual mask 的特性以链式的方式在不同设备之间传递 KV,避免了传统 TSP 的大量重复计算和冗余传输 为了平衡整个流水线采用了 context-level load balancing,靠前的设备多算一些 KV, 靠后的设备少算一些,因为靠后的设备注意力计算会更长 这里的关键点是:每个设备不仅传递KV缓存,也要利用收到的缓存,完成自己那部分词元的注意力计算。 在 D1 上: 计算T1-T4的Q_0, K_0, V_0。 立刻进行自己部分的注意力计算:用Q_0与K_0计算一个4x4的注意力矩阵,得到输出A_0。 然后,它将K_0, V_0(尺寸为4的缓存)发送给D2。 在 D2 上: 在等待D1数据的同时,它可以并行计算T5-T7的本地Q_1, K_1, V_1。 当它收到D1发来的K_0, V_0后,它将自己本地的K_1, V_1追加上去,形成一个包含T1-T7信息的、尺寸为7的KV缓存。 立刻进行自己部分的注意力计算:用自己的Q_1(来自T5-T7)与这个尺寸为7的完整缓存进行计算(一个3x7的注意力计算),得到输出A_1。 然后,它将这个尺寸为7的KV缓存发送给D3。 在 D3 上: 并行计算T8-T9的本地Q_2, K_2, V_2。 收到D2发来的尺寸为7的缓存后,追加自己的K_2, V_2,形成包含全部9个词元信息的最终KV缓存。 它进行自己部分的注意力计算:用Q_2与这个尺寸为9的完整缓存进行计算(一个2x9的注意力计算),得到输出A_2。 作为最后一个设备,它最终生成第一个令牌。 TSP TSP方案旨在并行化标准的单设备注意力计算流程。其标准工作流程如图4所示: 执行流程: 将输入上下文在多个处理器(进程)间均匀划分。 每个进程独立计算其对应部分的Q, K, V向量。 通过一次**all-gather**集体通信操作,所有进程交换并获得完整的K和V向量。 每个进程使用其本地的Q和全局的K, V来计算均等分配的注意力输出部分。 技术瓶颈: 计算冗余: TSP忠实地并行化了通用的注意力计算方法,即先计算一个完整的稠密$QK^T$矩阵,然后再通过掩码(mask)忽略掉上三角部分。这个过程没有利用LLM的因果性,导致了大量的无效计算,造成了计算资源的浪费。 通信瓶颈: all-gather操作是一个全局同步点,要求所有进程必须互相等待,这会成为性能瓶颈。同时,它导致了很高的网络流量;在论文的示例中,TSP需要交换36个(K, V)条目。 KVR ...