Title: Walking Down the Memory Maze: Beyond Context Limit through Interactive Reading
Authors: Howard Chen, Ramakanth Pasunuru, Jason Weston, Asli Celikyilmaz
Paper: https://arxiv.org/abs/2310.05029
The Eternal Challenge of Transformers: Limited Context and Handling Long Inputs
Transformers have revolutionized many areas of machine learning, but they come with their own set of challenges. A perennial issue is their limited ability to handle long input sequences due to fixed-length context windows. Various strategies have been employed to address this limitation, ranging from simply expanding the context window—often coupled with modifications to the attention mechanism—to more complex solutions.
Several approaches involve tinkering with the attention mechanism itself, such as sparse attention, linear attention, or attention with at least sub-quadratic complexity. Examples of these models include the Reformer, Longformer, Linformer, and Big Bird, each representing a unique attempt to handle longer contexts more efficiently.
Another method involves extrapolating positional embeddings, a technique somewhat adjacent to the primary attention mechanism modifications.
Some solutions introduce an element of recurrence into the primarily feed-forward nature of transformers. This lineage can be traced back to innovations like the Transformer-XL, Compressive Transformer, Feedback Memory, and Recurrent Memory Transformer, or RMT (I even participated in work on its predecessor, Memory Transformer).
Close to this line of work are models that take a retrieval-augmented approach, like RETRO, Unlimiformer, and a myriad of others.
Yet another approach involves agents that can interact with text segments and perform specific actions, akin to the strategy employed by WebGPT or various iterative prompting methods.
Despite these advancements, the quest for an efficient and comprehensive solution continues.
Enter MemWalker
MemWalker is a novel solution proposed in the current study. It operates in two stages:
Memory Tree Construction: The first phase involves building a "memory tree"—a hierarchical structure summarizing chunks of input data. The long input sequence is sliced into portions that fit within the model's context window. Each piece is then condensed into a summary, and these summaries are further summarized at the next level, creating a tree-like hierarchical structure. Notably, this tree is query-agnostic, meaning it can be computed in advance and utilized for various inquiries.
Essentially, two prompts are used for generating summaries: one for the leaves (summaries of text chunks) and another for the nodes (summaries of summaries). For the nodes, as many previous-level summaries as fit into the context window are summarized together, and the process is then repeated for the remainder.
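To make the construction step concrete, here is a minimal Python sketch of how such a memory tree could be built, assuming a generic llm(prompt) -> str callable. The prompt wording, the whitespace-based chunking, and the fixed branching factor are simplifications of mine, not the authors' code (the paper packs as many summaries as fit into the context window).

```python
# Minimal sketch of the memory tree construction stage.
# Assumes a generic `llm(prompt) -> str` callable; prompts, chunking, and
# branching factor are illustrative guesses, not the paper's exact setup.
from dataclasses import dataclass, field
from typing import List, Optional

CHUNK_WORDS = 1000  # segment size of ~1000 tokens, as used for QuALITY/SummScreenFD
BRANCHING = 5       # fixed branching for simplicity; the paper fills the context window

LEAF_PROMPT = "Summarize the following passage.\n\n{text}\n\nSummary:"
NODE_PROMPT = "Combine the following summaries into a single summary.\n\n{text}\n\nSummary:"

@dataclass
class Node:
    summary: str
    text: Optional[str] = None               # raw chunk, present only for leaves
    children: List["Node"] = field(default_factory=list)

def build_memory_tree(document: str, llm) -> Node:
    # 1. Split the long input into segments that fit the context window.
    words = document.split()
    chunks = [" ".join(words[i:i + CHUNK_WORDS])
              for i in range(0, len(words), CHUNK_WORDS)]

    # 2. Leaf level: summarize each raw segment.
    level = [Node(summary=llm(LEAF_PROMPT.format(text=c)), text=c) for c in chunks]

    # 3. Higher levels: summarize groups of summaries until one root remains.
    while len(level) > 1:
        next_level = []
        for i in range(0, len(level), BRANCHING):
            group = level[i:i + BRANCHING]
            joined = "\n".join(f"Summary {j}: {n.summary}" for j, n in enumerate(group))
            next_level.append(Node(summary=llm(NODE_PROMPT.format(text=joined)),
                                   children=group))
        level = next_level
    return level[0]   # root node; query-agnostic, so it can be cached and reused
```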
Navigation: Upon receiving a query, MemWalker navigates the memory tree, searching for relevant information, starting from the root. Once sufficient information is gathered, it generates a response.
This stage also uses two prompts: a "leaf prompt" for the leaves and a "triage prompt" for the nodes. At each node, the language model (LLM) receives the summaries of all child nodes, and the prompt asks it to choose which passage is most likely to contain the answer to the query, with Chain-of-Thought (CoT) reasoning elicited by the instruction "First provide reasoning to compare the summaries before you make the decision". The study indicates that if no relevant information is found, the model should revert to the parent node, although this isn't explicitly stated in the prompt (or at least I haven't noticed). If the model reaches a leaf, it either accepts it and responds or reverts to the parent node.
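For illustration, a triage prompt might look something like the sketch below. Only the CoT sentence is quoted from the paper; the rest of the wording, the reply format, and the explicit go-back option are my own reconstruction and may differ from the actual templates.

```python
# Hypothetical triage prompt for an internal node. Only the CoT instruction
# sentence is quoted from the paper; everything else, including the
# "Action: -1" go-back option, is my own reconstruction.
TRIAGE_PROMPT = """\
Question: {query}

{numbered_child_summaries}

Which summary is most likely to contain the answer to the question?
First provide reasoning to compare the summaries before you make the decision.
Then reply with "Action: <summary number>" to open that passage, or
"Action: -1" to go back to the parent node.
"""
```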
Responses are required in a specific format. If the LLM fails, it's asked to regenerate the response. After three consecutive failures, navigation ceases with a "no answer." During navigation, a form of working memory is maintained and added to the leaf prompt, seemingly containing the content of parent nodes.
The logic for orchestrating this entire process is poorly described; it rests on many unstated assumptions and is hard to reproduce in its pure form. At the very least, there should be clear tracking of which nodes the model has already visited, to avoid returning to the same node after backtracking to the parent. Alternatively, this could be handled implicitly by the search procedure, but none of this is spelled out.
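To make that gap concrete, here is one plausible reconstruction of the navigation loop, reusing the Node class and TRIAGE_PROMPT from the sketches above. The visited-set bookkeeping, the reply parsing, the retry counter, and the leaf prompt wording are all my assumptions about how this could be orchestrated, not the authors' procedure.

```python
import re

MAX_RETRIES = 3   # after three consecutive format failures, stop with "no answer"

# Hypothetical leaf prompt; the working-memory section follows the paper's
# description (content of parent nodes), but the exact wording is my guess.
LEAF_PROMPT = """\
Question: {query}

Working memory (summaries of the parent nodes visited so far):
{working_memory}

Passage:
{passage}

If the passage contains enough information, reply with "Answer: <your answer>".
Otherwise reply with "Action: back" to return to the parent node.
"""

def query_with_retries(llm, prompt):
    # Re-query the model if the reply does not follow the expected format.
    for _ in range(MAX_RETRIES):
        reply = llm(prompt)
        if "Answer:" in reply or "Action:" in reply:
            return reply
    return None

def navigate(root, query, llm) -> str:
    path = [root]         # stack of nodes from the root to the current node
    visited = set()       # nodes already explored and rejected
    memory = []           # summaries of parent nodes along the current path

    while path:
        node = path[-1]

        if not node.children:                       # leaf: try to answer
            reply = query_with_retries(llm, LEAF_PROMPT.format(
                query=query,
                working_memory="\n".join(memory) or "(empty)",
                passage=node.text))
            if reply and "Answer:" in reply:
                return reply.split("Answer:", 1)[1].strip()
            visited.add(id(node)); path.pop()       # revert to the parent
            if memory: memory.pop()
            continue

        # Internal node: offer only the children that were not rejected yet.
        candidates = [c for c in node.children if id(c) not in visited]
        if not candidates:                          # dead end, backtrack
            visited.add(id(node)); path.pop()
            if memory: memory.pop()
            continue

        summaries = "\n".join(f"Summary {i}: {c.summary}"
                              for i, c in enumerate(candidates))
        reply = query_with_retries(llm, TRIAGE_PROMPT.format(
            query=query, numbered_child_summaries=summaries))
        match = re.search(r"Action:\s*(-?\d+)", reply or "")
        if match and 0 <= int(match.group(1)) < len(candidates):
            memory.append(node.summary)             # carry the parent summary down
            path.append(candidates[int(match.group(1))])
        else:                                       # "go back", bad index, or no reply
            visited.add(id(node)); path.pop()
            if memory: memory.pop()

    return "no answer"
```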
Evaluation
The performance was tested for accuracy on three datasets: QuALITY, SummScreenFD, and GovReport from the SCROLLS benchmark.
QuALITY is a multiple-choice question-answering dataset based on long texts from Project Gutenberg, from which 187 examples were taken.
SummScreenFD contains TV series transcripts (dialogue between the characters). Originally created for summarization, it was here transformed into a question-and-answer format, with 306 examples taken.
GovReport includes documents from the Congressional Research Service and the U.S. Government Accountability Office, along with expert summaries. This was also converted into a question-and-answer format, and 101 examples were used.
The results were reported on these datasets, as well as on subsets with longer documents (each dataset had its own threshold, ranging from 6k to 12k tokens).
As for the LLM, Stable Beluga 2 was used (https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models), which is a fine-tuned version of Llama 2 with 70 billion parameters. The context length was 4096 tokens. There was no further fine-tuning, nor few-shot learning; the model was used in zero-shot mode.
The maximum number of nodes per tree level was 8, 5, and 8, and the segment size was 1000, 1000, and 1200 tokens, respectively, for these three datasets.
Three baselines were used for comparison:
Full context window with a trim down to 4096 tokens.
Recurrence through summarization, where each segment is 2500 tokens and the maximum summary size is 500 tokens.
Retrieval, using Contriever (https://github.com/facebookresearch/contriever) to select segments relevant to the query.
Recurrence through summarization performs the worst, and this particular retrieval baseline is middling. Full context performs quite well, and, depending on the dataset, it is better to trim either from the left or from the right. Sometimes it is comparable to MemWalker, but overall the latter is better. On the subset of particularly long documents, MemWalker is consistently better.
The authors also compared with LongChat 13B (16k) and MPT 13B (8k). They are worse, but they are also significantly lighter compared to the 70B model. Running MemWalker on LLaMA 2 Chat 13B also yields pretty poor results.
In general, it's hard to really evaluate; it would be interesting to compare all this on one model with a larger context. Or even better, on different ones, including Claude, which has a context of 100k tokens, and GPT-4 with 32k. The fact that full context yields a very high result suggests that a model with a larger context might work well out of the box.
The authors separately checked how useful the CoT instruction ("First provide reasoning…") is. As it turns out, LLaMA 2 Chat 13B and 70B do better without it, while Stable Beluga 2 70B does better with it. Adding the working memory also gives a noticeable improvement.
The authors believe that a large instruction-tuned model with reasoning ability is necessary for the method to work. But honestly, it seems to me that not enough checks were made to support this claim; they simply happened to have a model that works better with CoT. Whether it's truly necessary, who knows.
During navigation in the tree, a rollback to the parent node and changing the path in the tree happens in 15-20% of cases, and out of these cases, 60-80% yield the correct result.
Well, in general, it's an intriguing technique. It allows working with inputs larger than the model's context window, and there's no need to additionally train the model; only the logic orchestrating this activity is needed. This again ties back to the concept of LLM Programs, such as Tree-of-Thoughts (ToT). Essentially, MemWalker is a variation of ToT, just with a preprocessing stage (building the tree).
It doesn't seem like a direct game-changer, but it may take its place in the arsenal. There seems to be a growing need for a library of standard algorithms on top of LLMs, like STL or Boost for the new era.