KV-Runahead Scalable Causal LLM Inference by Parallel Key-Value Cache Generation
Skimming Author Info Background Challenges Insights Approaches 看了好几遍都没看懂,我大概的理解是 利用了 casual mask 的特性以链式的方式在不同设备之间传递 KV,避免了传统 TSP 的大量重复计算和冗余传输 为了平衡整个流水线采用了 context-level load balancing,靠前的设备多算一些 KV, 靠后的设备少算一些,因为靠后的设备注意力计算会更长 这里的关键点是:每个设备不仅传递 KV 缓存,也要利用收到的缓存,完成自己那部分词元的注意力计算。 在 D1 上: 计算 T1-T4 的Q_0, K_0, V_0。 立刻进行自己部分的注意力计算:用 Q_0 与 K_0 计算一个 4x4 的注意力矩阵,得到输出A_0。 然后,它将 K_0, V_0(尺寸为 4 的缓存)发送给D2。 在 D2 上: 在等待 D1 数据的同时,它可以并行计算 T5-T7 的本地Q_1, K_1, V_1。 当它收到 D1 发来的 K_0, V_0 后,它将自己本地的 K_1, V_1 追加上去,形成一个包含 T1-T7 信息的、尺寸为 7 的 KV 缓存。 立刻进行自己部分的注意力计算:用自己的 Q_1(来自 T5-T7)与这个尺寸为 7 的完整缓存进行计算(一个 3x7 的注意力计算),得到输出 A_1。 然后,它将这个尺寸为 7 的 KV 缓存发送给 D3。 在 D3 上: 并行计算 T8-T9 的本地Q_2, K_2, V_2。 收到 D2 发来的尺寸为 7 的缓存后,追加自己的 K_2, V_2,形成包含全部 9 个词元信息的最终KV缓存。 它进行自己部分的注意力计算:用 Q_2 与这个尺寸为 9 的完整缓存进行计算(一个 2x9 的注意力计算),得到输出 A_2。 作为最后一个设备,它最终生成第一个令牌。 TSP ...