Home

Awesome

APT

APT: Adaptive Pruning and Tuning Pretrained Language Models for Efficient Training and Inference

News

Overview

The APT framework We propose APT, a methodology that Adaptively selects model parameters for Pruning and fine-Tuning. APT combines the benefits of PEFT and structured pruning to make fine-tuning and inference more efficient.

How APT Works

Our intuition is that pretrained language model (LM) parameters contain general knowledge, but their importance to downstream tasks varies. Therefore, we can remove the parameters irrelevant to the fine-tuning task in the early training stage. Early-removing these parameters improves training and inference efficiency while not substantially hurting model accuracy. Meanwhile, continuously adding more parameters for fine-tuning can improve LM performance because task-specific skills live in a subset of LM parameters.

Based on this setup, we find that using self-distillation where the main parameters between the teacher and student models are shared can vasly prune small LMs with high end-task performance retained. Meanwhile, considering in-block outliers by calculating kurtosis when pruning large LMs before training can accurately prune them with less training memory footprint.

Main Results

RoBERTa-base experiment results:

MethodMNLISST2SQuAD v2Train TimeTrain Mem.Inf TimeInf Mem.
FT87.694.882.9100.0%100.0%100.0%100.0%
LoRA87.595.183.02137.0%60.5%100.0%100.0%
LoRA+Prune84.093.079.25128.3%60.5%38.0%75.1%
Prune+Distill87.394.5-1495.3%168.5%38.6%79.2%
LoRA+Prune+Distill84.291.9-6534.6%141.4%39.4%82.3%
APT86.494.581.8592.1%70.1%41.3%78.1%

T5-base experiment results:

MethodMNLISST2CNN/DMTrain TimeTrain Mem.Inf TimeInf Mem.
FT87.195.242.1/20.3/39.4100.0%100.0%100.0%100.0%
LoRA87.095.038.7/17.2/36.0255.5%62.0%100.0%100.0%
LoRA+Prune80.992.336.7/15.7/33.94523.5%62.0%47.1%73.4%
APT87.095.038.6/17.0/35.8484.7%73.9%74.6%81.5%

LLaMA-7B experiment results:

MethodARCHellaSwagMMLUTruthfulQAAvg.Train TimeTrain Mem.Inf TimeInf Mem.
LLaMA 2 7B53.177.743.839.053.4----
LoRA55.679.346.949.957.9100.0%100.0%100.0%100.0%
LoRA+Prune46.865.223.946.245.5180.9%100.0%115.5%68.9%
LLMPruner39.267.024.940.642.986.9%253.6%114.8%74.2%
APT45.471.136.946.650.0106.0%75.8%117.0%67.2%

Setup

Installation

conda env create -f environment.yml
conda activate apt

Training

For finetuning RoBERTa-base models with APT, please run:

bash scripts/adaptpruning/roberta_base_sst2_momentum.sh

For finetuning T5-base models with APT, please run:

bash scripts/adaptpruning/t5_base_lm_adapt_cnndm_momentum.sh

For finetuning LLaMA2 models on Alpaca with APT, please run:

bash scripts/adaptpruning/llama_2_7b_alpaca_gpt4.sh

Citation

If you use this code or our tuned models, please cite our paper:

@misc{zhao2024apt,
      title={APT: Adaptive Pruning and Tuning Pretrained Language Models for Efficient Training and Inference}, 
      author={Bowen Zhao and Hannaneh Hajishirzi and Qingqing Cao},
      year={2024},
      eprint={2401.12200},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Acknowledgements

This project uses modified code from the following projects: