# PyTorch-PKD-for-BERT-Compression
A PyTorch implementation of the distillation method described in the paper *Patient Knowledge Distillation for BERT Model Compression*. This repository builds heavily on PyTorch-Transformers by Hugging Face.
## Steps to run the code
### 1. Download the GLUE data

```bash
$ python download_glue_data.py
```

The commands below assume `$GLUE_DIR` points to the downloaded GLUE data and `$TASK_NAME` is one of the GLUE tasks (e.g. `SST-2`, `MRPC`), so export both before continuing.
### 2. Fine-tune the teacher BERT model

Run the following command to fine-tune the teacher model and save it to `--output_dir`:
```bash
python run_glue.py \
  --model_type bert \
  --model_name_or_path bert-base-uncased \
  --task_name $TASK_NAME \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir $GLUE_DIR/$TASK_NAME \
  --max_seq_length 128 \
  --per_gpu_eval_batch_size=8 \
  --per_gpu_train_batch_size=8 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --output_dir /tmp/$TASK_NAME/
```
### 3. Distill the student model with the teacher BERT

`$TEACHER_MODEL` is the directory containing the fine-tuned teacher from step 2 (the `--output_dir` above, e.g. `/tmp/$TASK_NAME/`).
```bash
python run_glue_distillation.py \
  --model_type bert \
  --teacher_model $TEACHER_MODEL \
  --student_model bert-base-uncased \
  --task_name $TASK_NAME \
  --num_hidden_layers 6 \
  --alpha 0.5 \
  --beta 100.0 \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir $GLUE_DIR/$TASK_NAME \
  --max_seq_length 128 \
  --per_gpu_eval_batch_size=8 \
  --per_gpu_train_batch_size=8 \
  --learning_rate 2e-5 \
  --num_train_epochs 4.0 \
  --output_dir /tmp/$TASK_NAME/
```
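
Here `--num_hidden_layers 6` trains a 6-layer student, `--alpha` weights the soft-label (teacher logits) distillation term against the hard-label cross-entropy, and `--beta` scales the patient loss on intermediate hidden states from the PKD paper. The snippet below is a minimal sketch of how such a combined objective can be computed; it is not the repository's exact code, and the function name `pkd_loss`, the temperature, and the caller-supplied layer pairing are illustrative assumptions.

```python
# Minimal sketch of a PKD-style objective (illustrative, not this repo's code).
import torch
import torch.nn.functional as F


def pkd_loss(student_logits, teacher_logits, labels,
             student_hidden, teacher_hidden,
             alpha=0.5, beta=100.0, temperature=1.0):
    """student_hidden / teacher_hidden: already-paired lists of
    [batch, seq_len, hidden] states; the student-to-teacher layer
    mapping is assumed to be done by the caller."""
    # Hard-label cross-entropy on the student's own predictions.
    ce = F.cross_entropy(student_logits, labels)

    # Soft-label distillation against the teacher's temperature-scaled logits
    # (KL divergence with the usual T^2 scaling; T=1 makes the scaling a no-op).
    kd = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)

    # Patient loss: MSE between length-normalized [CLS] vectors of each
    # matched (student layer, teacher layer) pair.
    pt = torch.tensor(0.0)
    for s_h, t_h in zip(student_hidden, teacher_hidden):
        s_cls = F.normalize(s_h[:, 0], dim=-1)
        t_cls = F.normalize(t_h[:, 0], dim=-1)
        pt = pt + F.mse_loss(s_cls, t_cls)

    return (1.0 - alpha) * ce + alpha * kd + beta * pt


if __name__ == "__main__":
    # Toy shapes just to exercise the function.
    batch, seq_len, hidden, num_labels = 4, 16, 768, 2
    s_logits, t_logits = torch.randn(batch, num_labels), torch.randn(batch, num_labels)
    labels = torch.randint(0, num_labels, (batch,))
    s_hidden = [torch.randn(batch, seq_len, hidden) for _ in range(5)]
    t_hidden = [torch.randn(batch, seq_len, hidden) for _ in range(5)]
    print(pkd_loss(s_logits, t_logits, labels, s_hidden, t_hidden))
```

In the paper, the patient term compares the normalized `[CLS]` representations of selected student layers with those of the matched teacher layers, which is what the loop above sketches.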
## Experimental results on the dev set
| model | num_layers | SST-2 (acc) | MRPC (F1/acc) | QQP (F1/acc) | MNLI (m/mm acc) | QNLI (acc) | RTE (acc) |
|---|---|---|---|---|---|---|---|
| base | 12 | 0.9232 | 0.89/0.8358 | 0.8818/0.9121 | 0.8432/0.8479 | 0.916 | 0.6751 |
| finetuned | 6 | 0.9002 | 0.8741/0.8186 | 0.8672/0.901 | 0.8051/0.8033 | 0.8662 | 0.6101 |
| distill | 6 | 0.9071 | 0.8885/0.8382 | 0.8704/0.9016 | 0.8153/0.821 | 0.8642 | 0.6318 |
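
Both 6-layer rows above come from students built with `--num_hidden_layers 6` on top of `bert-base-uncased`. Below is a minimal sketch, assuming the Hugging Face transformers API, of how such a truncated student can be instantiated; this may differ from the repository's own loading code. Only the encoder layers present in the 6-layer config get their weights loaded from the checkpoint.

```python
# Sketch: build a 6-layer student from bert-base-uncased (assumed API,
# not necessarily the repository's loading code).
from transformers import BertConfig, BertForSequenceClassification

config = BertConfig.from_pretrained(
    "bert-base-uncased",
    num_hidden_layers=6,  # matches --num_hidden_layers 6 above
    num_labels=2,         # task-dependent, e.g. 2 for SST-2/MRPC/QQP/QNLI/RTE
)
student = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", config=config
)
print(student.config.num_hidden_layers)  # 6
```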