Star Attention: Efficient LLM Inference over Long Sequences
NVIDIA's approach for sharding attention
Authors: Shantanu Acharya, Fei Jia, Boris Ginsburg
Paper: https://arxiv.org/abs/2411.17116
Code: https://github.com/NVIDIA/Star-Attention
Attempts to optimize classical quadratic attention in transformers keep coming. This time, a block-sparse attention approximation from NVIDIA shards attention across multiple hosts while minimizing communication overhead, enabling efficient processing of very long contexts. This work gives plenty to think about.
The proposed Star Attention is based on the observation that LLM inference consists of two phases:
1) encoding the prompt and filling the KV cache for it,
2) autoregressive token generation that keeps updating this cache. With large contexts, the typical working pattern is "very long context + short query + short response" (a toy sketch of this pattern follows below).
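A toy NumPy sketch of the pattern (sizes and names are mine, purely for illustration, not the paper's code): the expensive parts are the one-time prefill of a huge KV cache and then attending over that cache at every decoding step.

```python
import numpy as np

# Toy single-head illustration of the two inference phases: prefill builds the
# KV cache for the long context once, then every generated token attends over
# the whole cache and appends its own K/V entry.

d = 8
rng = np.random.default_rng(0)

def attend(q, K, V):
    w = np.exp(K @ q / np.sqrt(d))
    return (w / w.sum()) @ V

# Phase 1: prefill. In a real model K/V come from projecting the context
# through every layer; here they are random stand-ins.
K_cache = rng.normal(size=(100_000, d))
V_cache = rng.normal(size=(100_000, d))

# Phase 2: short query + short response; each step is dominated by attending
# over the huge cached context.
for _ in range(16):
    q = rng.normal(size=d)                  # stand-in for the current token's query
    out = attend(q, K_cache, V_cache)       # would feed the LM head to pick the next token
    K_cache = np.vstack([K_cache, rng.normal(size=(1, d))])
    V_cache = np.vstack([V_cache, rng.normal(size=(1, d))])
```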
Star Attention itself is also two-phased:
1. (local block-wise attention) Context Encoding: the entire context is split into contiguous blocks of b tokens each and distributed across "context" hosts; every other host also receives a copy of the first block (the "anchor block"). Each host computes self-attention only within its own blocks (2b tokens, except the host holding the first block itself, which has only b), with no communication between hosts, thus achieving linear complexity. Each host also fills its local KV cache, excluding the anchor block (see the sketch after this list).
2. (global attention) Query Encoding and Token Generation: the query is replicated to all hosts, where it attends to the local KV cache; global attention is then obtained by aggregating these results on a designated "query" host. The query host also appends the KV entries of newly generated tokens to its cache.
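To make phase 1 concrete, here is a minimal single-layer, single-head NumPy sketch under my own simplifying assumptions (random projections, toy sizes, my own function names, not the paper's code): every host except the one holding the first block attends over [anchor block + its own block] and caches only its own block's K/V.

```python
import numpy as np

def causal_attention(Q, K, V):
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)
    offset = K.shape[0] - Q.shape[0]                 # the local block sits at the end of K
    mask = np.tril(np.ones((Q.shape[0], K.shape[0])), k=offset)
    scores = np.where(mask > 0, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def phase1_context_encoding(Q, K, V, b):
    """Each 'host' gets one context block of b tokens plus a copy of the first
    (anchor) block, attends locally, and caches only its own block's K/V."""
    anchor_K, anchor_V = K[:b], V[:b]
    outputs, kv_caches = [], []
    for start in range(0, K.shape[0], b):
        Qb, Kb, Vb = Q[start:start+b], K[start:start+b], V[start:start+b]
        if start == 0:
            K_in, V_in = Kb, Vb                      # first host: just the anchor block (b tokens)
        else:
            K_in = np.concatenate([anchor_K, Kb])    # anchor + local block (2b tokens)
            V_in = np.concatenate([anchor_V, Vb])
        outputs.append(causal_attention(Qb, K_in, V_in))
        kv_caches.append((Kb, Vb))                   # the anchor's K/V never enter the cache
    return np.concatenate(outputs), kv_caches

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(64, 8)) for _ in range(3))
outputs, kv_caches = phase1_context_encoding(Q, K, V, b=16)
```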
Interestingly, the first phase doesn't work without anchor blocks: the model fails to generate correct results. The authors attribute this to a poor approximation of attention in the second phase. If each block is processed independently, attention sinks (https://arxiv.org/abs/2309.17453) emerge at the beginning of every block, and after aggregation this prevents the model from focusing on the relevant parts of the context. Anchor blocks draw these sinks onto themselves, and since their entries never reach the KV cache, the attention distribution across blocks closely approximates global attention and the problem goes away (illustrated in Figure 3). One wonders whether a few arbitrary pseudo-tokens could have been added instead of a whole anchor block; StreamingLLM, from the original paper about sinks, seemed to do something similar. But, spoiler, read on to the ablations.
The second phase effectively computes a distributed softmax without exchanging KV caches between hosts. Each context host sends the query host only its locally computed attention output (the result of applying the local softmax weights to the local values) plus the scalar sum of local exponentials (the local softmax denominator), which is needed to renormalize into the global softmax. Thus, global attention is computed exactly.
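This is the standard online-softmax / log-sum-exp combination. A minimal sketch of the aggregation (my own version, not the paper's code; I pass the log of the local exponential sum for numerical stability), with a check that the combined result matches full attention over the concatenated caches:

```python
import numpy as np

def local_attention(q, K, V):
    """One context host: local softmax attention plus the log-sum-exp of its scores."""
    scores = K @ q / np.sqrt(q.shape[-1])
    m = scores.max()
    w = np.exp(scores - m)
    s = w.sum()
    return (w / s) @ V, m + np.log(s)                # local output A_h and lse_h

def query_host_combine(partials):
    """Renormalize each host's output by its share of the global softmax denominator."""
    outs, lses = zip(*partials)
    weights = np.exp(np.array(lses) - max(lses))
    weights /= weights.sum()
    return sum(w * o for w, o in zip(weights, outs))

# Sanity check: the combined result equals full attention over the concatenated caches.
rng = np.random.default_rng(0)
q = rng.normal(size=8)
blocks = [(rng.normal(size=(16, 8)), rng.normal(size=(16, 8))) for _ in range(4)]
star = query_host_combine([local_attention(q, K, V) for K, V in blocks])
K_all = np.concatenate([K for K, _ in blocks])
V_all = np.concatenate([V for _, V in blocks])
full, _ = local_attention(q, K_all, V_all)
assert np.allclose(star, full)
```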
Star Attention is dropped into pre-trained full-attention LLMs without any fine-tuning: various Llama-3.1 8B variants, plus Llama-3.1-70B-Instruct for scalability testing. The baseline is Ring Attention (https://arxiv.org/abs/2310.01889), which also splits the sequence into blocks across hosts and lets sequence length scale with the number of hosts, but makes no approximation to the attention mechanism - it is exact full attention, with hosts exchanging their KV caches in a ring pattern.
On the RULER benchmark with inputs from 16K to 128K tokens and a block size of ¼ of the sequence length, the paper states that accuracy drops insignificantly, by 0 to 3%, while inference speeds up by as much as 5x (and even more on the 70B model). The table, however, shows something slightly different: on the 8B model the speedup is up to 2.7x, and accuracy actually goes up.
I don't really understand, by the way, why accuracy would be higher: aren't we comparing against a non-approximated baseline? There are, of course, known stories where an approximation unexpectedly beats the original because of bugs in the original (Ash recently shared one such story). It still seems strange, but, spoiler, it will become clearer later. On the 70B model, accuracy drops but the speedup is higher, up to 4.7x. On BABILong, quality is likewise comparable.
At 128K length they also look at the tradeoff between block size and accuracy. As you'd expect, the larger the block (i.e., the more each host sees), the higher the quality.
Then they kept the block at 32K and increased the input to 1024K tokens for the 8B Llama: quality drops by more than 8%, but the speedup is almost 17x. For the 70B Llama, with a 16K block on 128K input, quality drops by 11.5% while speed goes up 8.7x.
These are quite noticeable quality drops, I think. It would be interesting to compare against a fundamentally different alternative - quantization: what speedup/degradation do you get there, and if they are comparable, quantization wins on another dimension that NVIDIA understandably doesn't dwell on - the hardware required and its cost.
But in any case, the ability to shard the computation is welcome, especially since these are all orthogonal stories: once it is combined with low-precision training (and something does seem to be happening there: https://arxiv.org/abs/2307.00331), it will be especially interesting to look at new model scales. I see Intel is also trying to do something in this space (https://github.com/intel/intel-extension-for-transformers, though that isn't about training), but honestly I don't know whether to expect anything from them at all.
Across the different RULER task types, Star Attention behaves differently.
Single Needle-in-a-Haystack (NIAH) is practically identical to full attention. Multi-NIAH and QA drop more noticeably, and Multi-Hop Tracing even more so, which is logical: it requires propagating information between blocks, i.e., the kind of inter-host communication that simply isn't there. Aggregation, unexpectedly, improves significantly in quality, apparently thanks to better summarization within individual blocks during the first phase; with full attention the model seems to get more distracted by the global context, and document-wide summarization suffers as a result.
They did some ablations on NIAH: Single-NIAH, Multi-key NIAH, Multi-query NIAH.
They tested two interesting hypotheses about anchor blocks:
1) The model has developed a bias toward the anchor block's absolute position.
2) The semantic content of this block is important for quality.
For 1) they varied the anchor block's position while keeping its content fixed. This had little effect.
For 2) they varied the block's content (filling it with a constant token, random tokens, or a shuffled version). This matters a lot: constant tokens are the worst, even worse than having no anchor block at all, while the original content is the best. In a separate experiment they tried using the previous block as the anchor block instead of the first one; that is still worse than the original but better than the other options. So it's not about position, it's about content.
As the anchor block size increases, quality also increases.
This is, of course, an interesting question: what is so special about these first blocks across different tasks that makes them so important? Surely there must be work evaluating how unevenly content significance is distributed across positions in real datasets, especially at long context lengths - share good examples if you know any. It smells like an inductive bias of the real tasks themselves (one we humans have adapted to, or co-evolved with), which makes sense overall.
Well, interesting stuff. Divide and conquer, basically. Or Map & Reduce.