Dual-Space Knowledge Distillation for Large Language Models (EMNLP 2024)

<small>Songming Zhang, Xue Zhang, Zengkui Sun, Yufeng Chen*, Jinan Xu</small>

<a href="https://arxiv.org/abs/2406.17328"><img src="https://img.shields.io/badge/Paper-arXiv:2406.17328-Green"></a> <a href=#bibtex><img src="https://img.shields.io/badge/Paper-BibTex-yellow"></a>

Some of our code builds on MiniLLM and DistiLLM.

News

Requirements

Data

The processed data used in our paper can be downloaded here.

Models

You can download the corresponding model files (e.g., pytorch_model.bin or model.safetensors) of the LLMs used in this paper into model_hub/*/*/.
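If you prefer to fetch the weights programmatically, here is a minimal sketch using huggingface_hub; the repo ids and local sub-directories below are illustrative placeholders, so adjust them to the exact models and layout you use:

```python
# Illustrative sketch (not part of this repo): download model weights from the
# Hugging Face Hub into model_hub/. The repo ids and local paths are examples.
from huggingface_hub import snapshot_download

models = {
    "gpt2/gpt2-base": "gpt2",                 # example: GPT2 student
    "qwen/qwen1.5-1.8b": "Qwen/Qwen1.5-1.8B", # example: Qwen1.5 teacher
}

for local_dir, repo_id in models.items():
    snapshot_download(repo_id=repo_id, local_dir=f"model_hub/{local_dir}")
```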

Here are the links to these models on Hugging Face:

Training

SFT for teacher models

For Qwen1.5-1.8B (full fine-tuning), run:

bash scripts/gpt2/sft_teacher_qwen.sh

For LLaMA2-7B (LoRA), run:

bash scripts/tinyllama/sft_teacher_llama2.sh

For Mistral-7B (LoRA), run:

bash scripts/tinyllama/sft_teacher_mistral.sh

SFT for student models

For GPT2-base (full fine-tuning), run:

bash scripts/gpt2/sft_gpt2_base.sh

For TinyLLaMA-1.1B (LoRA), run:

bash scripts/tinyllama/sft_tinyllama.sh

P.S. You may encounter an error when directly loading the TinyLLaMA model checkpoint. This is caused by a mismatch between the transformers version suggested by TinyLLaMA (4.31) and the one you are using. A concise fix is described in this issue.

KD for the Same Vocabulary

Vanilla KD framework

For GPT2-base, run:

bash scripts/gpt2/vanilla_kd_gpt2_base.sh

For TinyLLaMA-1.1B, run:

bash scripts/tinyllama/vanilla_kd_tinyllama.sh

You can change the distance function (e.g., KL divergence, reverse KL divergence, JS divergence) via KD_OBJ in the above scripts.
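For reference, the sketch below shows what these objectives typically compute over teacher and student logits. It illustrates the standard formulas rather than the exact implementation in code/criterions/, and the string names here may not match the actual KD_OBJ values:

```python
import math

import torch
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, kd_obj="forward_kl", temperature=1.0):
    """Distillation objectives over logits of shape [batch, seq_len, vocab].
    Illustrative only; the KD_OBJ names and the masking in this repo may differ."""
    s_logp = F.log_softmax(student_logits / temperature, dim=-1)
    t_logp = F.log_softmax(teacher_logits / temperature, dim=-1)
    if kd_obj == "forward_kl":   # KL(teacher || student)
        return F.kl_div(s_logp, t_logp, log_target=True, reduction="batchmean")
    if kd_obj == "reverse_kl":   # KL(student || teacher)
        return F.kl_div(t_logp, s_logp, log_target=True, reduction="batchmean")
    if kd_obj == "js":           # Jensen-Shannon divergence with mixture M
        m_logp = torch.logaddexp(s_logp, t_logp) - math.log(2.0)
        return 0.5 * (F.kl_div(m_logp, s_logp, log_target=True, reduction="batchmean")
                      + F.kl_div(m_logp, t_logp, log_target=True, reduction="batchmean"))
    raise ValueError(f"unknown KD objective: {kd_obj}")
```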

Dual-Space KD framework

For GPT2-base, run:

bash scripts/gpt2/dskd_gpt2_base.sh

For TinyLLaMA-1.1B, run:

bash scripts/tinyllama/dskd_tinyllama.sh

Also, you can change the distance functions using KD_OBJ in the above scripts.

KD for Different Vocabularies

Logits Alignment by Minimum Edit Distance (paper, original implementation)

The original implementation (linked above) pre-processes the logit alignment before distillation; we re-implement this method in code/criterions/min_edit_dis_kld.py and compute the alignment on the fly during distillation, which is faster.
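For intuition, the idea behind this alignment can be sketched with the standard edit-distance dynamic program over the two token sequences; this is an illustrative sketch, not our exact implementation in code/criterions/min_edit_dis_kld.py:

```python
def min_edit_distance_alignment(teacher_tokens, student_tokens):
    """Edit distance between two token sequences via dynamic programming.
    Backtracking the table yields which teacher/student positions match,
    i.e., the alignment used to pair their logits. Illustrative sketch only."""
    n, m = len(teacher_tokens), len(student_tokens)
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0 if teacher_tokens[i - 1] == student_tokens[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # delete a teacher token
                           dp[i][j - 1] + 1,         # insert a student token
                           dp[i - 1][j - 1] + cost)  # match / substitute

    # Backtrack to collect matched (teacher_index, student_index) pairs.
    aligned, i, j = [], n, m
    while i > 0 and j > 0:
        if teacher_tokens[i - 1] == student_tokens[j - 1] and dp[i][j] == dp[i - 1][j - 1]:
            aligned.append((i - 1, j - 1))
            i, j = i - 1, j - 1
        elif dp[i][j] == dp[i - 1][j] + 1:
            i -= 1
        elif dp[i][j] == dp[i][j - 1] + 1:
            j -= 1
        else:
            i, j = i - 1, j - 1
    return dp[n][m], list(reversed(aligned))
```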

For GPT2-base, run:

bash scripts/gpt2/minedit_gpt2_base.sh

For TinyLLaMA-1.1B, run:

bash scripts/tinyllama/minedit_tinyllama.sh

Universal Logit Distillation (paper, original implementation)

We also re-implement this method in code/criterions/universal_logit_distillation.py.
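Conceptually, the ULD loss avoids any vocabulary mapping by comparing the sorted output distributions of the two models. Below is a minimal sketch of this idea; it does not necessarily match every detail of our re-implementation:

```python
import torch.nn.functional as F


def uld_loss(student_logits, teacher_logits):
    """Sketch of the Universal Logit Distillation idea: sort each model's
    probability distribution, pad the smaller vocabulary with zeros, and
    take the L1 distance between the sorted distributions."""
    s_probs = F.softmax(student_logits, dim=-1).sort(dim=-1, descending=True).values
    t_probs = F.softmax(teacher_logits, dim=-1).sort(dim=-1, descending=True).values
    # Pad the smaller vocabulary so both tensors share the same last dimension.
    gap = s_probs.size(-1) - t_probs.size(-1)
    if gap > 0:
        t_probs = F.pad(t_probs, (0, gap))
    elif gap < 0:
        s_probs = F.pad(s_probs, (0, -gap))
    return (s_probs - t_probs).abs().sum(dim=-1).mean()
```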

For GPT2-base, run:

bash scripts/gpt2/uld_gpt2_base.sh

For TinyLLaMA-1.1B, run:

bash scripts/tinyllama/uld_tinyllama.sh

Our Dual-Space KD with Cross-Model Attention (CMA)

For GPT2-base, run:

bash scripts/gpt2/dskd_cma_gpt2_base.sh

For TinyLLaMA-1.1B, run:

bash scripts/tinyllama/dskd_cma_tinyllama.sh
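For intuition only, the sketch below shows a generic cross-attention projector between two models' hidden states. It is not the exact CMA module from our paper (see the paper and code/criterions/ for the actual formulation); it merely illustrates how one model's hidden states can attend to another's when their hidden sizes and tokenizations differ:

```python
import torch
import torch.nn as nn


class CrossModelAttentionSketch(nn.Module):
    """Generic cross-attention from student hidden states (queries) to teacher
    hidden states (keys/values). Purely illustrative; the actual CMA design is
    described in the paper."""

    def __init__(self, student_dim, teacher_dim):
        super().__init__()
        self.q_proj = nn.Linear(student_dim, teacher_dim)
        self.k_proj = nn.Linear(teacher_dim, teacher_dim)
        self.v_proj = nn.Linear(teacher_dim, teacher_dim)

    def forward(self, student_hidden, teacher_hidden):
        # student_hidden: [batch, s_len, student_dim]
        # teacher_hidden: [batch, t_len, teacher_dim]
        q = self.q_proj(student_hidden)
        k = self.k_proj(teacher_hidden)
        v = self.v_proj(teacher_hidden)
        attn = torch.softmax(q @ k.transpose(-1, -2) / k.size(-1) ** 0.5, dim=-1)
        # Student positions re-expressed in the teacher's hidden space.
        return attn @ v
```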

File Structures in Output Directory

The output directory is created automatically under ./outputs when you run the training scripts. For full fine-tuning, its file structure is as follows (taking GPT2 SFT as an example):

./outputs/gpt2/gpt2-base/sft/criterion=cross_entropy__default-bf16__.../
│
├── epochA_step.../ (model files of epoch A; you can load them directly with AutoModelForCausalLM.from_pretrained(<this path>))
│   ├── config.json
│   ├── pytorch_model.bin
│   ├── tokenizer.json
│   └── ...
│
├── epochB_step.../ (only exists when SAVE_BEST_N_CKPTS >= 2; same layout as epochA_step.../)
│   ├── config.json
│   ├── pytorch_model.bin
│   ├── tokenizer.json
│   └── ...
│
├── ...
│
├── args.json (the training arguments)
└── train.log (the training log)
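
Since each epoch directory is a standard Hugging Face checkpoint, it can be loaded directly; the path below is a placeholder for one of your own runs:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path: replace <run_name> and the epoch directory with your own.
ckpt_dir = "./outputs/gpt2/gpt2-base/sft/<run_name>/<epochA_step...>"
model = AutoModelForCausalLM.from_pretrained(ckpt_dir)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
```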

For LoRA fine-tuning, the file structure of the output directory is as follows (taking TinyLLaMA LoRA SFT as an example):

./outputs/tinyllama/tinyllama-1.1b-3T/sft/criterion=cross_entropy__lora-rank=256-alpha=8.../
│
├── epochA_step.../ (adapter files of epoch A; you can load them directly with AutoModelForCausalLM.from_pretrained(<this path>))
│   ├── adapter_config.json
│   ├── adapter_model.bin
│   ├── tokenizer.json
│   └── ...
│
├── epochB_step.../ (only exists when SAVE_BEST_N_CKPTS >= 2; same layout as epochA_step.../)
│   ├── adapter_config.json
│   ├── adapter_model.bin
│   ├── tokenizer.json
│   └── ...
│
├── ...
│
├── args.json (the training arguments)
└── train.log (the training log)
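
Because only the adapter weights are saved here, another way to load such a checkpoint is to attach the adapter to its base model with peft; the paths below are placeholders:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder paths: point these at your base model and adapter directories.
base_dir = "model_hub/<model_family>/<base_model>"
adapter_dir = "./outputs/tinyllama/tinyllama-1.1b-3T/sft/<run_name>/<epochA_step...>"

base_model = AutoModelForCausalLM.from_pretrained(base_dir)
model = PeftModel.from_pretrained(base_model, adapter_dir)
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
model = model.merge_and_unload()  # optional: merge the adapter for faster inference
```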

Evaluation

Evaluate Full Fine-tuning Checkpoints

bash scripts/eval/run_eval.sh ${CKPT_PATH} ${EVAL_BATCH_SIZE}

Following the structure above, CKPT_PATH is the absolute path to a checkpoint directory, e.g., /home/xxx/DSKD/outputs/gpt2/gpt2-base/sft/criterion=cross_entropy__default-bf16__.../epochA_step....

Evaluate LoRA Fine-tuning Checkpoints

bash scripts/eval/run_eval_lora.sh ${LORA_ADAPTER_PATH} ${EVAL_BATCH_SIZE}

Please note that MODEL_PATH in run_eval_lora.sh should be changed for different base models (TinyLLaMA, LLaMA2, Mistral).

Similarly, LORA_ADAPTER_PATH is the absolute path to a LoRA adapter directory, e.g., /home/xxx/DSKD/outputs/tinyllama/tinyllama-1.1b-3T/sft/criterion=cross_entropy__lora-rank=256-alpha=8.../epochA_step....

BibTeX

If you find this repo useful for your research, please consider citing our paper:

@article{zhang2024dskd,
      title={Dual-Space Knowledge Distillation for Large Language Models}, 
      author={Songming Zhang and Xue Zhang and Zengkui Sun and Yufeng Chen and Jinan Xu},
      year={2024},
      journal={arXiv preprint arXiv:2406.17328},
}