You Only Cache Once: Decoder-Decoder Architectures for Language Models
Authors: Yutao Sun, Li Dong, Yi Zhu, Shaohan Huang, Wenhui Wang, Shuming Ma, Quanlu Zhang, Jianyong Wang, Furu Wei
Paper: https://arxiv.org/abs/2405.05254
Code: https://github.com/microsoft/unilm/tree/master/YOCO
The authors have proposed an architecture for LLMs called decoder-decoder.
Encoders/decoders recap
Let’s recall that the original transformer (and models like T5) was built on a full encoder-decoder architecture. Most modern LLMs (like GPT) use only the decoder, while another popular branch from the recent past (the BERT family) consists only of the encoder. For more details, see the great post by Sebastian Raschka.
The encoder has always been bidirectional, and models with such a bidirectional component (i.e., encoder-only and encoder-decoder) had trouble with autoregressive generation: to produce each new token, you first had to encode the entire sequence, both the input and the already generated part of the output. You could, of course, use only the decoder part for generation, but then the generated tokens wouldn’t fully utilize the encoder’s parameters.
Decoders, on the other hand, can use KV-cache (to cache key and value vectors inside attention blocks) during autoregressive generation and reuse them for generating new tokens, eliminating the need to re-encode the entire history.
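To make the caching idea concrete, here is a minimal single-head sketch of autoregressive decoding with a KV cache (my own illustration, not the paper’s code; the projections and everything around them are simplified placeholders):

```python
import torch

def attention(q, k, v):
    # q: [1, d]; k, v: [t, d] — one new query attends over all cached positions
    scores = q @ k.T / k.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v

def decode_step(x_t, cache, w_q, w_k, w_v):
    """One autoregressive step: K/V are computed only for the new token,
    everything earlier is reused from the cache instead of being re-encoded."""
    q, k, v = x_t @ w_q, x_t @ w_k, x_t @ w_v
    cache["k"] = torch.cat([cache["k"], k])   # cache grows by one row per step
    cache["v"] = torch.cat([cache["v"], v])
    return attention(q, cache["k"], cache["v"])

d = 16
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
cache = {"k": torch.empty(0, d), "v": torch.empty(0, d)}
for step in range(5):
    x_t = torch.randn(1, d)                   # embedding of the newly generated token
    out = decode_step(x_t, cache, w_q, w_k, w_v)
print(cache["k"].shape)                       # torch.Size([5, 16]) — grows with length
```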
However, as the tale of Savitri goes, “there is one drawback.” The KV cache swells as the generated sequence grows, consuming a lot of GPU memory and making LLM inference memory-bound. For instance, a 65B model (with grouped-query attention and 8-bit KV quantization) needs 86GB of memory just for the cache at 512K tokens, surpassing the 80GB limit of an H100. Moreover, the prefill phase (see the NVIDIA blog for a good description, or Pierre Lienhart’s great overview for more details on the phases), where all input prompt tokens must be processed to compute their KV values, can take hundreds of seconds for very long inputs like 1M tokens (it would be interesting to know what Google came up with for Gemini 1.5).
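The 86GB figure is easy to sanity-check with back-of-the-envelope arithmetic. A rough sketch, assuming a LLaMA-65B-like layout (80 layers, 8 KV heads of dimension 128 under GQA) and 1 byte per cached value after 8-bit quantization; the exact config is my assumption, not stated in the paper summary:

```python
layers, kv_heads, head_dim = 80, 8, 128     # assumed LLaMA-65B-like config with GQA
seq_len = 512 * 1024                        # 512K tokens
bytes_per_value = 1                         # 8-bit KV quantization

# K and V are stored for every layer and every cached position
kv_cache_bytes = 2 * layers * kv_heads * head_dim * seq_len * bytes_per_value
print(f"{kv_cache_bytes / 1e9:.1f} GB")     # ~85.9 GB, i.e. roughly the quoted 86GB
```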
Enter YOCO
The solution is decoder-decoder!
The entire transformer of L layers is split equally, with the first L/2 layers implementing a self-decoder through efficient self-attention. The size of the KV cache for this part is constant, i.e., O(1).
The output of the last self-decoder layer provides a global KV cache, which the second half, the cross-decoder built from the remaining L/2 layers, attends to. Each cross-decoder block computes only its queries and uses cross-attention over this shared global KV cache. Here, standard multi-head attention with a full window is employed (OK, almost standard: it uses GQA; for details see “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints”, https://arxiv.org/abs/2305.13245).
By efficient self-attention in the self-decoder, the authors mean sliding-window attention (SWA), as in the old sparse transformer by Ilya Sutskever and co (“Generating Long Sequences with Sparse Transformers”, https://arxiv.org/abs/1904.10509). Alternatively, gated retention (called gRet, aka gRetNet or RetNet-3), a RetNet variant with data-dependent gating, can be used in the self-decoder; it seems essentially the same as what was described in the original RetNet paper (“Retentive Network: A Successor to Transformer for Large Language Models”, https://arxiv.org/abs/2307.08621).
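The reason the self-decoder’s cache is O(1) is easiest to see with SWA: only the keys and values of the last W tokens are ever attended to, so older entries can be dropped. A toy sketch (the window size and the deque-based cache are illustrative choices, not the paper’s implementation; gRet achieves the same effect with a fixed-size recurrent state):

```python
from collections import deque
import torch

W, d = 4, 16                           # sliding-window size and head dim (illustrative)
k_cache = deque(maxlen=W)              # entries older than W steps are evicted,
v_cache = deque(maxlen=W)              # so memory stays constant in sequence length

for step in range(10):                 # decode 10 tokens
    k_t, v_t = torch.randn(d), torch.randn(d)
    k_cache.append(k_t)
    v_cache.append(v_t)
    k = torch.stack(list(k_cache))     # at most [W, d], no matter how long we decode
    # ... the current query attends over k (and v) as usual ...
print(len(k_cache))                    # 4, not 10
```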
Otherwise, the blocks in these layers are generally standard, alternating attention and FFN, using pre-RMSNorm, SwiGLU, GQA.
This resulting architecture is called YOCO (You Only Cache Once, implying caching at the L/2 layer). It resembles an encoder-decoder but looks like a decoder from the outside, with both parts using causal masking.
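Putting the pieces together, a schematic forward pass looks roughly like this. It’s a minimal sketch with simplified blocks: the class and layer names are mine, the self-decoder uses plain causal attention as a stand-in for SWA/gRet, and RMSNorm, SwiGLU, GQA and multi-head splitting are omitted; see the official repo for the real thing.

```python
import torch
import torch.nn as nn

def causal_attn(q, k, v):
    # q: [t, d]; k, v: [t, d]; lower-triangular mask = causal attention
    scores = q @ k.T / q.shape[-1] ** 0.5
    mask = torch.ones(q.shape[0], k.shape[0], dtype=torch.bool).tril()
    return scores.masked_fill(~mask, float("-inf")).softmax(-1) @ v

class SelfDecoderLayer(nn.Module):
    # Stand-in for an efficient-attention block (SWA or gated retention) + FFN
    def __init__(self, d):
        super().__init__()
        self.qkv, self.ffn = nn.Linear(d, 3 * d), nn.Linear(d, d)
    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, -1)
        x = x + causal_attn(q, k, v)            # real YOCO: constant-size cache here
        return x + self.ffn(x)

class CrossDecoderLayer(nn.Module):
    # Queries come from x; keys/values come from the shared global cache
    def __init__(self, d):
        super().__init__()
        self.q, self.ffn = nn.Linear(d, d), nn.Linear(d, d)
    def forward(self, x, k, v):
        x = x + causal_attn(self.q(x), k, v)
        return x + self.ffn(x)

class YOCOSketch(nn.Module):
    def __init__(self, num_layers=8, d=64):
        super().__init__()
        half = num_layers // 2
        self.self_decoder = nn.ModuleList(SelfDecoderLayer(d) for _ in range(half))
        self.to_kv = nn.Linear(d, 2 * d)        # the K/V that get cached only once
        self.cross_decoder = nn.ModuleList(CrossDecoderLayer(d) for _ in range(half))
    def forward(self, x):
        for layer in self.self_decoder:
            x = layer(x)                        # first L/2 layers: efficient attention
        k, v = self.to_kv(x).chunk(2, -1)       # global KV, shared by all cross layers
        for layer in self.cross_decoder:
            x = layer(x, k, v)                  # last L/2 layers: cross-attention
        return x

print(YOCOSketch()(torch.randn(10, 64)).shape)  # torch.Size([10, 64])
```

The key point is that to_kv is computed and cached only once, while in a regular transformer every one of the L layers keeps its own K/V.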
YOCO is more efficient than a regular transformer due to its lower memory requirements: the cache for long sequences scales as O(N) instead of O(NL), where N is the input sequence length and L is the number of layers. This allows for faster inference and/or larger batches (increasing throughput).
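Concretely, reusing the rough per-token numbers from the earlier sketch (still my assumed config, and ignoring the small constant-size cache of the self-decoder):

```python
layers, kv_heads, head_dim, seq_len = 80, 8, 128, 512 * 1024  # same assumptions as above
per_token_per_layer = 2 * kv_heads * head_dim                 # bytes at 8-bit, K and V
transformer_cache = per_token_per_layer * layers * seq_len    # O(N * L)
yoco_cache = per_token_per_layer * seq_len                    # O(N): one global K/V set
print(transformer_cache / yoco_cache)                         # 80.0 — a factor of ~L
```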
Another interesting feature of YOCO is that during the prefill stage you can perform an early exit and skip the cross-decoder, speeding up this phase. Since the self-decoder contains half the layers, this alone cuts the computation and time roughly in half. On top of that, the efficient attention implementation in the self-decoder is itself fast. The authors give an example with a 512K context, where prefill latency drops from 180 seconds (for a far-from-worst transformer baseline with Flash-Decoding and kernel fusion) to less than 6 seconds. Even at 32K length, YOCO is still three times faster (in this phase, not end-to-end).
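In code terms, the early exit just means the prompt flows only through the self-decoder and the single KV projection, while the cross-decoder is saved for decoding. A sketch on top of the hypothetical YOCOSketch above:

```python
import torch

@torch.no_grad()
def prefill(model, prompt_x):
    # Early exit: run only the first half of the network and cache the global K/V.
    x = prompt_x
    for layer in model.self_decoder:
        x = layer(x)
    k, v = model.to_kv(x).chunk(2, -1)
    # The cross-decoder is skipped for the prompt; during decoding it only runs for
    # new positions (plus the last prompt position, to produce the first new token).
    return k, v

k, v = prefill(YOCOSketch(), torch.randn(1024, 64))
print(k.shape)   # torch.Size([1024, 64]) — one K (and V) set for all cross layers
```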
Results
In tests, they used StableLM-3B-4E1T as a baseline and trained a comparable YOCO-3B, which yields results similar to other well-tuned models of the same size.
The model loss scales similarly to Llama-optimized transformers. YOCO with gRet performs slightly better than with sliding-window attention (SWA) and a regular transformer.
Expanding YOCO-3B’s context to 1M tokens (hello, Gemini!) through continued training with a length schedule of 64K, 256K, and 1M yields nearly perfect results on the Needle-In-A-Haystack test.
The appendix also has a comparison with Mamba, RetNet, Hybrid H3, gRetNet, and the transformer; YOCO and the transformer lead (in perplexity).
The most exciting results are in performance.
Memory improvement is significant, and the longer the sequence length, the greater the improvement.
At 1M length, YOCO consumes 9.38x less memory than a transformer with GQA, Flash-Decoding, and kernel fusion. This is mainly due to the KV cache, though gRet also seems to shave a bit off activation storage.
Prefill latency improves roughly tenfold.
Throughput (tokens per second) on long inputs increases almost tenfold (mainly thanks to the faster prefill and the larger batches that the memory savings allow).
A good piece of engineering work, I like it. Combinations like YOCO + BitNet + Groq might have a cumulative effect, making it a bombshell.