MInference Background

MInference

Million-Tokens Prompt Inference for Long-context LLMs

What is MInference?

MInference 1.0 leverages the dynamic sparse nature of LLMs’ attention, which exhibits some static patterns, to speed up pre-filling for long-context LLMs. It first determines offline which sparse pattern each head belongs to, then approximates the sparse indices online and dynamically computes attention with optimized custom kernels. This approach achieves up to a 10x speedup for pre-filling on an A100 while maintaining accuracy.
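
For orientation, here is a minimal usage sketch of how a Hugging Face model is patched before pre-filling. The class and argument names below follow the spirit of the project's quick-start; treat them as assumptions and check the repository for the exact signatures.

```python
# Minimal usage sketch (names/signatures are assumptions; see the repository's
# quick-start for the exact API). MInference patches the attention modules of
# a Hugging Face model so that pre-filling uses dynamic sparse attention.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference

model_name = "gradientai/Llama-3-8B-Instruct-262k"  # any supported LLaMA-style long-context model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)

# Patch the model: subsequent pre-filling runs with MInference's sparse kernels.
minference_patch = MInference("minference", model_name)
model = minference_patch(model)

long_prompt = "..."  # place your 100K-1M token prompt here
inputs = tokenizer(long_prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
```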

MInference 1.0 will be presented as a spotlight at NeurIPS’24, and was previously presented at the Microsoft booth and the ES-FoMo workshop at ICML’24.


MInference one-page slide: MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention

Insights

  1. Attention, especially in long-context scenarios, is sparse and dynamic, i.e., the sparse patterns vary substantially across inputs.
  2. This dynamic sparsity exhibits three unique spatial aggregation patterns that persist across all inputs: A-shape, Vertical-Slash, and Block-Sparse (sketched as boolean masks after this list).
  3. These dynamic sparse indices can be approximated online with minimal overhead, and attention inference can then be accelerated with custom optimized GPU kernels.
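
To make the three patterns concrete, the sketch below builds each one as a dense boolean mask over an N×N causal attention map. This is purely illustrative: the function names, default budgets, and block sizes are ours, not the kernel implementation, which never materializes these masks densely.

```python
# Illustrative masks for the three sparse patterns (not the project's kernel
# code; names and default budgets are assumptions for exposition only).
import torch

def a_shape_mask(n, sink=64, local=512):
    """A-shape: the first `sink` tokens plus a local sliding window."""
    i = torch.arange(n)[:, None]   # query positions
    j = torch.arange(n)[None, :]   # key positions
    causal = j <= i
    return causal & ((j < sink) | (i - j < local))

def vertical_slash_mask(n, vertical_idx, slash_offsets):
    """Vertical-Slash: selected key columns plus selected diagonals (slashes),
    where a slash is the set of positions with a fixed offset i - j."""
    i = torch.arange(n)[:, None]
    j = torch.arange(n)[None, :]
    causal = j <= i
    vertical = torch.isin(j, torch.as_tensor(vertical_idx))
    slash = torch.isin(i - j, torch.as_tensor(slash_offsets))
    return causal & (vertical | slash)

def block_sparse_mask(n, active_blocks, block=64):
    """Block-Sparse: only a subset of block x block tiles is computed."""
    mask = torch.zeros(n, n, dtype=torch.bool)
    for bi, bj in active_blocks:   # list of (query_block, key_block) pairs
        mask[bi * block:(bi + 1) * block, bj * block:(bj + 1) * block] = True
    i = torch.arange(n)[:, None]
    j = torch.arange(n)[None, :]
    return mask & (j <= i)
```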

Why MInference?

Long-context LLM inference faces two major challenges: 1) long pre-filling stage attention latency, and 2) high storage and transfer costs for the KV cache. Previous efficient methods for long-context LLMs have focused on KV-cache compression, static sparse attention (e.g., model compression, SSM, linear attention), or distributed serving. However, these methods struggle to achieve acceptable latency for million-token level prompts at low cost on a single A100 GPU.

Figure 3.1. (a) Visualization of attention weights from different attention heads. Across different prompts and tasks, the pattern of a given head is relatively consistent, but the sparse indices change dynamically. (b) Distance to the top-10 nearest non-zero elements in the attention matrix. (c) Attention recall distribution using our identified patterns, where FLOPs in the kernel refers to the actual FLOPs required for sparse attention computation on GPUs. Here, a 1×64 block size is used for the Vertical-Slash pattern, and a 64×64 block size for the others. All visualizations are based on LLaMA-3-8B-Instruct-262K.

To address these issues, we propose MInference, where the name reflects our ambition to enable million-token inference on a single A100 machine. MInference is a training-free, efficient method for the pre-filling stage of long-context LLMs based on dynamic sparse attention. Specifically, we leverage the static spatial aggregation patterns of dynamic sparse attention, as shown above in Figure 3.1, and classify the dynamic sparse patterns into three types: A-shape, Vertical-Slash, and Block-Sparse. MInference first determines the optimal dynamic sparse pattern for each head offline using the Kernel-Aware Sparse Pattern Search algorithm, as illustrated below in Algorithm 1. During inference, it approximates the sparse indices online based on each head’s pattern, as shown below in Algorithms 2 and 3. Finally, it performs efficient dynamic sparse attention computation with our optimized GPU kernels, significantly reducing pre-filling latency for long-context LLMs.

MInference: diagram showing three columns for Algorithm 1, Algorithm 2, and Algorithm 3
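
In a nutshell, the offline search (Algorithm 1) compares, for each head, candidate masks of the three patterns built under a matched kernel FLOPs budget on a reference prompt, and keeps the pattern that recalls the most attention weight. The snippet below is a simplified paraphrase with illustrative names, not the algorithm as implemented.

```python
# Simplified paraphrase of the Kernel-Aware Sparse Pattern Search idea
# (illustrative names; the real search works over kernel-level block budgets).
import torch

def attention_recall(attn, mask):
    """Fraction of total attention weight captured by a sparse mask."""
    return (attn * mask.float()).sum() / attn.sum()

def search_head_pattern(attn, candidate_masks):
    """
    attn:            [N, N] softmax attention of one head on a reference prompt.
    candidate_masks: dict pattern_name -> [N, N] bool mask, each built so that
                     its kernel performs roughly the same number of FLOPs.
    Returns the pattern whose mask recalls the most attention weight.
    """
    return max(candidate_masks, key=lambda name: attention_recall(attn, candidate_masks[name]))
```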

For example, with the Vertical-Slash pattern, we first use the attention calculation between the last Q and K to estimate the optimal indices of vertical lines and slash lines. Then, we utilize the dynamic sparse compiler PIT and Triton to construct the Vertical-Slash FlashAttention kernel, accelerating the attention computation. For the A-shape, Vertical-Slash, and Block-Sparse patterns, we first compute attention using mean-pooled Q and K. By leveraging the commutative property of mean pooling and MatMul, we estimate the block-sparse indices. Then, we use Triton to construct the Block-Sparse FlashAttention kernel, accelerating the attention computation. For detailed kernel implementation, please refer to Appendix C.4 and the code.
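
The snippet below paraphrases these two online estimation steps in plain PyTorch: the Vertical-Slash estimate scores key columns and diagonals from the attention of only the last few queries, while the Block-Sparse estimate scores blocks from mean-pooled Q and K. Shapes, default top-k budgets, and function names are illustrative; the actual implementation runs inside the Triton/PIT kernels on blocks rather than on dense matrices.

```python
# Illustrative paraphrase of the online sparse-index estimation (not the
# project's kernel code). Single head; q and k have shape [N, d].
import torch

def estimate_vertical_slash(q, k, last_q=64, n_vertical=1024, n_slash=1024):
    """Score key columns (vertical lines) and diagonal offsets (slash lines)
    using only the attention of the last `last_q` queries against all keys."""
    n, d = q.shape
    rows = torch.arange(n - last_q, n)[:, None]              # absolute query positions
    cols = torch.arange(n)[None, :]
    scores = (q[-last_q:] @ k.T) / d ** 0.5
    attn = torch.softmax(scores.masked_fill(cols > rows, float("-inf")), dim=-1)
    vertical_idx = attn.sum(dim=0).topk(min(n_vertical, n)).indices
    # Accumulate attention mass per diagonal offset i - j to score slash lines.
    offsets = (rows - cols).clamp(min=0).reshape(-1)         # non-causal entries carry zero weight
    slash_scores = torch.zeros(n, dtype=attn.dtype).scatter_add_(0, offsets, attn.reshape(-1))
    slash_idx = slash_scores.topk(min(n_slash, n)).indices   # each index is an offset i - j
    return vertical_idx, slash_idx

def estimate_block_sparse(q, k, block=64, topk_blocks=1024):
    """Mean-pool Q and K per block; because mean pooling commutes with MatMul
    (up to a constant factor), pooled QK^T approximates block-level attention."""
    n, d = q.shape
    q_blk = q[: n - n % block].reshape(-1, block, d).mean(dim=1)
    k_blk = k[: n - n % block].reshape(-1, block, d).mean(dim=1)
    block_attn = torch.softmax((q_blk @ k_blk.T) / d ** 0.5, dim=-1)
    nb = block_attn.shape[1]
    flat_idx = block_attn.flatten().topk(min(topk_blocks, block_attn.numel())).indices
    return torch.stack((flat_idx // nb, flat_idx % nb), dim=1)  # (query_block, key_block) pairs
```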

Our main contributions are four-fold:

  1. We propose a dynamic sparse attention method, MInference, to accelerate the pre-filling stage of long-context LLMs by up to 10x for 1M token prompts while maintaining the capabilities of LLMs, especially their retrieval abilities, as demonstrated in tasks like Needle in a Haystack.
  2. We classify dynamic sparse attention in LLMs into three patterns and design the Kernel-Aware Sparse Pattern Search algorithm to find the optimal head pattern offline.
  3. We introduce an online approximation method and optimized GPU kernels to accelerate LLM inference with minimal overhead. We also provide an optimized inference codebase that enables 1M-token pre-filling inference on a single A100 using LLaMA-style models.
  4. We evaluate MInference across four benchmarks: InfiniteBench, RULER, PG-19, and Needle in a Haystack, with token lengths ranging from 128K to 1M, to assess the actual context processing capabilities of LLMs. Experimental results reveal that MInference maintains or slightly improves actual context processing capabilities, while also delivering better cost efficiency and lower system latency.

Experimental results on long-context benchmarks

We tested MInference across a range of scenarios, including QA, coding, retrieval-based tasks, multi-hop QA, summarization, and math tasks. The RULER benchmark includes several complex multi-hop or multi-needle tasks, effectively reflecting the actual context window size of LLMs. As shown in Table 1, our method effectively preserves the actual context window processing capability of LLMs and even slightly extends the actual context window size to 32K.

Table 1. Performance (%) of different models and different methods on RULER evaluated at lengths from 4k to 128k.

We also tested MInference on a broader range of tasks using InfiniteBench, which has an average token length of 214K, as shown in Table 2. Compared to the SoTA baselines, MInference consistently maintains performance across all tasks. Notably, in more challenging retrieval tasks such as the KV retrieval task, all baselines fail to make accurate predictions, with accuracy rates below 1.2%. However, MInference successfully retains the ability to handle dynamic KV pair retrieval.

Table 2. Performance of different methods with different base models on InfiniteBench.

To further evaluate performance across different context lengths and positions of key information within prompts, we tested various models and methods using the Needle in a Haystack task. As shown in Figure 1, MInference performs well across different models, context windows, and positions within the prompt, maintaining or even slightly improving performance compared to the original models. In the case of LLaMA-3-8B and GLM-4-9B-1M, MInference achieves full green performance (i.e., near-perfect retrieval) for context windows up to 1M. In comparison, StreamingLLM and InfLLM drop below 20% in the middle segments of prompts, even with a 70K context window.

Figure 1. Needle In A Haystack results using LLaMA-3-8B-Instruct-1M, GLM-4-9B-1M, Yi-9B-200K, Phi-3-Mini-128K, and Qwen2-7B-128K.

We also tested MInference on the language modeling task using PG-19, with contexts of up to 100K tokens. As shown in Figure 2, MInference effectively maintains the perplexity of LLaMA-3-8B and Yi-9B-200K, while all baselines show varying degrees of perplexity degradation. Additionally, the dilated and strided variants of StreamingLLM maintain perplexity better than the standard StreamingLLM.

Figure 2. Perplexity results on PG-19 using different models and methods.

Latency breakdown and sparsity pattern in the kernel

Figure 3.2 shows the micro-benchmark results of the three attention patterns proposed in this paper, as well as FlashAttention. Vertical-Slash is the slowest of the three patterns, yet it still achieves a 13x speedup over FlashAttention at a 1M-token context window.

Figure 4 shows the sparse indices in the kernel of the Vertical-Slash head. The vertical lines are computed using 1×64 blocks through PIT FlashAttention, while the slash lines are computed using 64×64 blocks through Block-level FlashAttention.
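
As a rough illustration of that coverage (helper names are ours, not the kernel code), the snippet below enumerates the cells each index type touches: each selected vertical column is gathered as a 1-column strip within every 64-row query tile, and each slash offset is covered by the 64×64 tiles its diagonal crosses.

```python
# Rough illustration of the kernel's coverage pattern (an assumption for
# exposition; the real Triton/PIT kernels build this inside the GPU kernel).
import torch

def vertical_cells(n, vertical_idx, block=64):
    """1x64 coverage: within each 64-row query tile, the selected key columns
    are gathered individually (a PIT-style gather of sparse columns)."""
    q_tiles = torch.arange(0, n, block) // block              # query-tile indices
    cols = torch.as_tensor(vertical_idx)
    cells = torch.cartesian_prod(q_tiles, cols)               # (query_tile, key_column) pairs
    # Keep only causal cells: the column must start before the tile ends.
    return cells[cells[:, 1] < (cells[:, 0] + 1) * block]

def slash_tiles(n, offset, block=64):
    """64x64 coverage: the tiles crossed by the diagonal i - j = offset."""
    i = torch.arange(offset, n)                               # causal part of the diagonal
    j = i - offset
    return torch.unique(torch.stack((i // block, j // block), dim=1), dim=0)
```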

Figure 3.2. The latency breakdown of a single attention kernel for the three patterns and FlashAttention across different context windows on a single A100, including the index time for dynamic sparse approximation and for building the dynamic sparsity. At 10K tokens, the latencies of the four kernels are very close, all under 1 ms. At 1M tokens, the latency for A-shape is 164 ms.
Figure 4. The dynamic sparse mask in the kernel of the vertical-slash pattern schematic using LLaMA-3-8B in the summarization task, where the yellow areas indicate the parts actually involved in computation. The slash lines are covered using 64×64 block sizes, while the vertical lines are covered using 1×64 block sizes.

BibTeX

If you find this project helpful, please cite the following paper:

@article{jiang2024minference,
    title={MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention},
    author={Jiang, Huiqiang and Li, Yucheng and Zhang, Chengruidong and Wu, Qianhui and Luo, Xufang and Ahn, Surin and Han, Zhenhua and Abdi, Amir H and Li, Dongsheng and Lin, Chin-Yew and Yang, Yuqing and Qiu, Lili},
    journal={arXiv preprint arXiv:2407.02490},
    year={2024}
}