Structural Pruning of Pre-trained Language Models via Neural Architecture Search
This package provides code to reproduce experiments from Klein et al., which proposes multi-objective Neural Architecture Search (NAS) to prune pre-trained language models by searching for sub-networks that minimize both validation error and parameter count. If you use this code, please cite the original paper:
@article{klein-tmlr24,
title={Structural Pruning of Pre-trained Language Models via Neural Architecture Search},
author={Aaron Klein and Jacek Golebiowski and Xingchen Ma and Valerio Perrone and Cedric Archambeau},
journal={Submitted to Transactions on Machine Learning Research},
year={2024},
url={https://openreview.net/forum?id=XiK8tHDQNX},
note={Under review}
}
We distinguish between standard NAS, which fine-tunes each sub-network in isolation, and weight-sharing based NAS. Our weight-sharing based NAS approach consists of two stages:
- We first fine-tune the pre-trained network (dubbed super-network) via weight-sharing based NAS strategies. In a nutshell, in each update step we only update parts of the network in order to train different sub-networks.
- In the second stage, we run a multi-objective search to find the Pareto set of sub-networks of the super-network. To evaluate each sub-network we use the shared weights of the super-network, without any further training. This is relatively cheap compared to standard NAS, since we only do a single pass over the validation data without computing gradients (see the sketch below).
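To make the second stage more concrete, here is a minimal sketch of how a single sub-network could be scored with the shared super-network weights. This is not the package's actual API: extract_sub_network is a hypothetical placeholder, and the validation loader is assumed to yield (inputs, labels) batches.

import torch

def evaluate_sub_network(super_network, config, validation_loader, device="cpu"):
    # Hypothetical helper: select the sub-network described by `config`
    # (e.g. number of heads / units per layer) directly from the shared weights.
    sub_network = extract_sub_network(super_network, config)
    sub_network.to(device).eval()

    correct, total = 0, 0
    with torch.no_grad():  # single pass over the validation data, no gradients
        for inputs, labels in validation_loader:
            logits = sub_network(inputs.to(device))
            correct += (logits.argmax(dim=-1) == labels.to(device)).sum().item()
            total += labels.numel()

    validation_error = 1.0 - correct / total
    num_params = sum(p.numel() for p in sub_network.parameters())
    return validation_error, num_params  # the two objectives of the search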
Install
To get started, first install whittle following the installation instructions here
Afterwards, we install the dependencies via:
cd src
pip install -r requirements.txt
Benchmarking Details
At the moment we support models from the BERT and RoBERTa families and the following datasets:
['rte', 'mrpc', 'cola', 'stsb', 'sst2', 'qnli', 'imdb', 'swag', 'mnli', 'qqp']
You can also use the following multi-objective methods from Syne Tune, both for standard NAS and weight-sharing based NAS:
['random_search', 'morea', 'local_search', 'nsga2', 'moasha', 'ehvi']
Standard NAS
To run standard NAS, use the following script. It runs NAS with Syne Tune to prune a bert-base-cased model using random search for 3600 seconds on the RTE dataset.
python src/run_nas.py --output_dir=./output_standard_nas --model_name bert-base-cased --dataset rte --runtime 3600 --method random_search --num_train_epochs 5 --seed 0 --dataset_seed 0
Weight-sharing NAS
As described above, weight-sharing NAS runs in two phases. We first fine-tune the super-network and store the checkpoint on disk. Afterwards, we can run our multi-objective search using the same algorithms as for standard NAS, except for MO-ASHA, which only works in a multi-fidelity setting.
Super-Network Training
To run the training of the super-network, execute the following script:
python src/train_supernet.py --learning_rate 2e-05 --model_name_or_path bert-base-cased --num_train_epochs 5 --output_dir ./supernet_model_checkpoint --save_strategy epoch --per_device_eval_batch_size 8 --per_device_train_batch_size 4 --sampling_strategy one_shot --search_space small --seed 0 --task_name rte --num_random_sub_nets 2 --temperature 10
This runs the super-network training ('one_shot') on the RTE dataset for 5 epochs. Checkpoints are saved in output_dir, so that we can load them later for the multi-objective search.
Most hyperparameters follow the HuggingFace training arguments. At this point we support the following super-network training strategies: ['standard', 'random', 'linear_random', 'one_shot', 'sandwich', 'kd'].
See the paper for a detailed description.
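The following sketch illustrates the idea behind the sampling-based strategies. It is not the actual training loop of this package, and sample_sub_network is a hypothetical placeholder for whatever the chosen strategy does, e.g. 'random' draws one random sub-network per step, while 'sandwich' additionally updates the smallest and the largest sub-network.

import torch
import torch.nn.functional as F

def supernet_training_step(super_network, batch, optimizer, num_random_sub_nets=2):
    optimizer.zero_grad()
    for _ in range(num_random_sub_nets):
        # Only the weights that belong to the sampled sub-network take part in this
        # forward/backward pass, so only those shared parameters receive gradients.
        sub_network = sample_sub_network(super_network)  # hypothetical helper
        loss = F.cross_entropy(sub_network(batch["inputs"]), batch["labels"])
        loss.backward()
    optimizer.step()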
Multi-Objective Search
Next, we use the model checkpoint from the previous step to perform the multi-objective search:
python src/run_offline_search.py --model_name_or_path bert-base-cased --num_samples 100 --output_dir ./results_nas --checkpoint_dir_model ./supernet_model_checkpoint --search_space small --search_strategy random_search --seed 0 --task_name rte
Make sure that checkpoint_dir_model points to the directory with the model checkpoint from the previous step. Results will be saved as a json file in output_dir.
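Once the search has finished, the results file can be post-processed to extract the Pareto set. The snippet below is only a sketch: the file name and the field names ('error', 'params') are assumptions, so adapt them to the json that run_offline_search.py actually writes to output_dir.

import json

with open("./results_nas/results.json") as f:  # assumed file name
    results = json.load(f)

points = [(r["error"], r["params"]) for r in results]  # assumed field names

# Keep the Pareto-optimal sub-networks: no other sub-network is at least as good
# in both validation error and parameter count (and strictly better in one of them).
pareto_front = [
    p for p in points
    if not any(q[0] <= p[0] and q[1] <= p[1] and q != p for q in points)
]
print(sorted(pareto_front, key=lambda p: p[1]))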