Authors: Mrigank Raman, Pranav Mani, Davis Liang, Zachary C. Lipton
Paper: https://openreview.net/pdf?id=2fc5GOPYip
News on distillation. The paper is from the Instruction Tuning and Instruction Following workshop at NeurIPS 2023.
TL;DR The authors propose an LLM distillation method called SLIM, which uses only the top 5% of logit values at each decoding step, plus dynamic weighting of the KL and CE losses. It outperforms classical distillation, SFT, and MiniLLM. The method scales to teachers of around 70B parameters.
In more detail: modern LLMs have grown huge and are often used to annotate data and generate instructions for fine-tuning smaller models. The classic approach is supervised fine-tuning (SFT), i.e. training on the new (generated) texts as hard labels. We know that hard labels carry significantly less information about the teacher's distribution than they could (link).
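A toy sketch of why hard labels lose information, using made-up three-token distributions (not from the paper): two students that assign the same probability to the teacher's top token get exactly the same hard-label CE, while CE against the teacher's full distribution (soft targets) tells them apart. The function names are illustrative.

```python
import math

def hard_ce(student, label):
    # SFT-style loss: only the probability of the single target token matters
    return -math.log(student[label])

def soft_ce(teacher, student):
    # cross-entropy against the teacher's full distribution (soft targets)
    return -sum(p * math.log(q) for p, q in zip(teacher, student) if p > 0)

teacher = [0.6, 0.3, 0.1]      # toy teacher distribution over a 3-token vocabulary
label = 0                      # the hard label SFT would train on
student_a = [0.6, 0.3, 0.1]    # matches the teacher everywhere
student_b = [0.6, 0.05, 0.35]  # matches only on the top token

print(hard_ce(student_a, label), hard_ce(student_b, label))      # identical
print(soft_ce(teacher, student_a), soft_ce(teacher, student_b))  # A is lower
```

The hard-label loss cannot distinguish the two students; the soft-target loss rewards the one that also reproduces the teacher's secondary preferences.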
Distillation is already used for LLMs (many small LLMs are derived from larger ones; Gemini is one example), but it comes with its own complications. For example, a small model may not be expressive enough to cover all modes of the teacher's distribution.
The recent MiniLLM method (https://arxiv.org/abs/2306.08543) replaces the forward Kullback-Leibler divergence (KLD) with the reverse KLD to prevent the student from overestimating low-probability regions of the teacher's distribution. It trains with RL-style (policy-gradient) optimization, which adds complexity. The current work simplifies and improves on this with an approach called SLIM (Sparse Logit Infused Modeling).
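A toy numeric check of forward vs reverse KL, with made-up numbers rather than anything from MiniLLM: for a bimodal teacher, forward KL(teacher || student) prefers a spread-out student that puts extra mass on the teacher's low-probability tokens, while reverse KL(student || teacher) prefers a student that locks onto one mode.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

teacher = [0.45, 0.05, 0.05, 0.45]   # bimodal toy teacher
covering = [0.25, 0.25, 0.25, 0.25]  # spreads mass, overestimates the low-probability middle
seeking = [0.85, 0.05, 0.05, 0.05]   # concentrates on one mode

# forward KL(teacher || student): the objective behind classical distillation
print(kl(teacher, covering), kl(teacher, seeking))  # covering scores lower
# reverse KL(student || teacher): the objective MiniLLM uses
print(kl(covering, teacher), kl(seeking, teacher))  # seeking scores lower
```

This is the "overestimating low-probability regions" failure in miniature: the forward objective is happier with the student that gives 0.25 to tokens the teacher rates at 0.05.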
The idea is simple. Run the teacher model over the training dataset and build a dataset of logits (the final layer's outputs before the softmax). For each token in the sequence this gives V values (V is the vocabulary size), which serve as soft targets. The problem is that this takes a lot of storage. To reduce it, the authors keep only the top 5% of logits for each token and treat the rest as zeros, which yields sparse logits.
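A minimal sketch of the sparsification and the storage it saves, under my own assumptions: a 32,000-token vocabulary (the LLaMA tokenizer size), fp16 logit values, and int32 or uint16 indices. The helper names are hypothetical. `fill_value` defaults to 0.0 to follow the text; using `float("-inf")` instead would give dropped tokens exactly zero probability after a softmax.

```python
import heapq

def top_percent_logits(logits, top_percent=5):
    """Return (indices, values) of the top `top_percent`% logits for one token position."""
    k = max(1, (len(logits) * top_percent + 99) // 100)  # ceil(len * pct / 100), at least 1
    idx = heapq.nlargest(k, range(len(logits)), key=logits.__getitem__)
    return idx, [logits[i] for i in idx]

def densify(indices, values, vocab_size, fill_value=0.0):
    """Rebuild a dense logit vector; dropped positions get fill_value."""
    dense = [fill_value] * vocab_size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense

def storage_bytes(vocab_size, top_percent=5, value_bytes=2, index_bytes=4):
    """Bytes per token position: dense fp16 logits vs. sparse (value, index) pairs."""
    k = (vocab_size * top_percent + 99) // 100
    return vocab_size * value_bytes, k * (value_bytes + index_bytes)

print(storage_bytes(32000))                 # (64000, 9600)
print(storage_bytes(32000, index_bytes=2))  # (64000, 6400): uint16 indices fit a 32k vocabulary
```

At about 64 KB per token for dense fp16 logits, a hypothetical billion-token dataset would need roughly 64 TB; keeping 5% cuts that by about 6.7x with int32 indices or 10x with uint16 indices.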
Then comes the distillation itself: the targets are the soft targets, and the loss is a weighted sum of the usual cross-entropy (CE) loss and the traditional KL loss. The weight of the KL term depends on the ratio of the teacher's and the student's confidence (computed from their logits), which makes it adaptive: the KL component contributes more when the teacher is more confident in the prediction than the student.
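The summary does not give the paper's exact weighting formula, so the sketch below uses a hypothetical stand-in with the described property: alpha = p_t / (p_t + p_s), where p_t and p_s are the teacher's and student's probabilities of the target token, so alpha > 0.5 exactly when the teacher is more confident. It is one token position in pure Python, not the authors' implementation.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]  # -inf logits become exactly 0.0
    z = sum(exps)
    return [e / z for e in exps]

def kl(p, q):
    """KL(p || q); positions where p is 0 (e.g. dropped sparse logits) contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def adaptive_distill_loss(teacher_logits, student_logits, target):
    """Per-position loss: (1 - alpha) * CE + alpha * KL(teacher || student).

    HYPOTHETICAL weighting (not taken from the paper): alpha = p_t / (p_t + p_s).
    Returns (loss, alpha).
    """
    t, s = softmax(teacher_logits), softmax(student_logits)
    ce = -math.log(s[target])                  # hard-label cross-entropy
    alpha = t[target] / (t[target] + s[target])  # relative confidence on the target token
    return (1 - alpha) * ce + alpha * kl(t, s), alpha

# Confident teacher, uncertain student: the KL term gets weight > 0.5.
loss, alpha = adaptive_distill_loss([4.0, 0.0, 0.0], [0.0, 0.0, 0.0], target=0)
```

Swapping the two logit vectors drops alpha below 0.5, shifting weight back to the hard-label CE.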
The approach was tested on instruction-following tasks and evaluated with ROUGE-L and with GPT-4 feedback. It was compared with SFT on hard labels and with MiniLLM, using 7B LLaMA, LLaMA 2, and MPT models as students and 13-30B teachers. SLIM beats the baselines; SFT was the worst of all.
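For readers unfamiliar with ROUGE-L: it scores a generated answer against a reference by their longest common subsequence (LCS) of tokens. A minimal sketch, assuming lowercase whitespace tokenization and the balanced F1 variant; the paper's exact ROUGE-L configuration is not stated in this summary.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists (row-by-row DP)."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_f1(candidate, reference):
    """ROUGE-L F1: harmonic mean of LCS-based precision and recall."""
    c, r = candidate.lower().split(), reference.lower().split()
    lcs = lcs_length(c, r)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(c), lcs / len(r)
    return 2 * precision * recall / (precision + recall)

# LCS is "the cat on the mat" (5 of 6 tokens on each side), so F1 = 5/6.
score = rouge_l_f1("the cat sat on the mat", "the cat is on the mat")
```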
Next, it was tested on downstream tasks: ARC, HellaSwag, MMLU, TruthfulQA. Here it was compared with SFT, using LLaMA 2 70B as the teacher and LLaMA 2 7B/13B as students. SLIM came out ahead here too.
It was also tested on generating data for pre-training. The authors took Pythia-6.9B and generated a dataset of texts plus their top-5% logits. They then trained a randomly initialized Pythia-160M on subsets of this dataset of different sizes. The perplexity curves show that SLIM is more sample-efficient than SFT and vanilla distillation. It's not entirely clear to me what they mean by vanilla distillation here; perhaps the classical KL loss without the top-5% sparsification?
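For reference, perplexity is the exponentiated mean negative log-likelihood per token (lower is better), so "more sample-efficient" means reaching a given perplexity with fewer training tokens. A minimal sketch using natural logs; the paper's exact evaluation setup is not described here.

```python
import math

def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood over the evaluated tokens."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

# A model that is uniform over 4 choices at every step has perplexity 4.
ppl = perplexity([math.log(0.25)] * 10)
```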
Anyway, it's simple and effective. The method isn't very different from classical distillation; I wouldn't call it radically new, more of an incremental improvement. It would be interesting to know how companies like OpenAI and Google distill their models internally. Is there a big difference?
The catch is that this method is hard to apply to black-box models like those from OpenAI, although the new logprobs and top_logprobs parameters get you closer. However, the Terms of Use seem to prohibit this (“Use Output to develop models that compete with OpenAI”). What I don't quite understand is how other models are then trained on datasets generated by GPT-4; is that allowed?
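A rough sketch of how top_logprobs output could be turned into soft targets, setting the Terms of Use question aside. I'm assuming each returned entry exposes a token string and its log-probability (plain dicts with "token" and "logprob" keys stand in for the real response objects). Only a handful of alternatives come back per position, far fewer than 5% of the vocabulary, so the sketch renormalizes over what was returned and reports the unseen probability mass.

```python
import math

def sparse_targets_from_top_logprobs(top_logprobs):
    """Turn one position's top_logprobs entries into renormalized soft targets.

    Assumes entries shaped like {"token": str, "logprob": float} (an assumption
    mirroring the API response). Returns (targets, missing_mass), where
    missing_mass is the probability the teacher put on tokens not returned.
    """
    probs = {e["token"]: math.exp(e["logprob"]) for e in top_logprobs}
    kept = sum(probs.values())
    return {tok: p / kept for tok, p in probs.items()}, 1.0 - kept

targets, missing = sparse_targets_from_top_logprobs(
    [{"token": " cat", "logprob": math.log(0.5)}, {"token": " dog", "logprob": math.log(0.3)}]
)
```

Note that the API returns token strings rather than vocabulary ids, so using these as targets for a student with a different tokenizer would need an extra alignment step.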