PyTorch-PKD-for-BERT-Compression

PyTorch implementation of the distillation method described in the paper Patient Knowledge Distillation for BERT Model Compression (Sun et al., 2019). This repository is heavily based on Pytorch-Transformers by Hugging Face.

Steps to run the code

1. Download the GLUE data

$ python download_glue_data.py
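The later commands read the data through $GLUE_DIR. As an optional quick check that a task was downloaded, something like the following can be used (the glue_data path and the MRPC task name are assumptions; the script's actual output directory may differ):

import os

# Hypothetical check that a GLUE task was downloaded; adjust the paths
# to wherever download_glue_data.py actually wrote the data.
glue_dir = "glue_data"   # assumption: the script's output directory
task = "MRPC"            # any GLUE task name, used as $TASK_NAME below
for split in ("train.tsv", "dev.tsv"):
    path = os.path.join(glue_dir, task, split)
    print(path, "found" if os.path.exists(path) else "missing")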

2. Fine-tune the teacher BERT model

Run the following command to fine-tune the teacher on a GLUE task and save the resulting model to --output_dir. Here $GLUE_DIR points to the downloaded GLUE data and $TASK_NAME is the name of a GLUE task (e.g. MRPC).

python run_glue.py \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --do_lower_case \
    --data_dir $GLUE_DIR/$TASK_NAME \
    --max_seq_length 128 \
    --per_gpu_eval_batch_size=8   \
    --per_gpu_train_batch_size=8   \
    --learning_rate 2e-5 \
    --num_train_epochs 3.0 \
    --output_dir /tmp/$TASK_NAME/
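run_glue.py saves the fine-tuned model and tokenizer to --output_dir, so the teacher can be reloaded for a quick sanity check before distillation. A minimal sketch, assuming the pytorch_transformers package this repository builds on, an SST-2 teacher saved to /tmp/SST-2/, and a placeholder input sentence:

import torch
from pytorch_transformers import BertForSequenceClassification, BertTokenizer

# Assumption: the teacher was fine-tuned on SST-2 and saved to /tmp/SST-2/.
teacher_dir = "/tmp/SST-2/"
model = BertForSequenceClassification.from_pretrained(teacher_dir)
tokenizer = BertTokenizer.from_pretrained(teacher_dir)
model.eval()

# Classify one placeholder sentence and print the class probabilities.
tokens = ["[CLS]"] + tokenizer.tokenize("a charming and affecting journey .") + ["[SEP]"]
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
with torch.no_grad():
    logits = model(input_ids)[0]
print(torch.softmax(logits, dim=-1))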

3. Distill the student model from the teacher BERT

Here $TEACHER_MODEL is the directory in which the fine-tuned teacher from step 2 was saved (its --output_dir), and --num_hidden_layers sets the depth of the student (6 layers here). --alpha weights the soft-target distillation loss against the hard-label cross-entropy, and --beta weights the patient loss on intermediate hidden states; see the sketch after the command.

python run_glue_distillation.py \
    --model_type bert \
    --teacher_model $TEACHER_MODEL \
    --student_model bert-base-uncased \
    --task_name $TASK_NAME \
    --num_hidden_layers 6 \
    --alpha 0.5 \
    --beta 100.0 \
    --do_train \
    --do_eval \
    --do_lower_case \
    --data_dir $GLUE_DIR/$TASK_NAME \
    --max_seq_length 128 \
    --per_gpu_eval_batch_size=8   \
    --per_gpu_train_batch_size=8   \
    --learning_rate 2e-5 \
    --num_train_epochs 4.0 \
    --output_dir /tmp/$TASK_NAME/
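For reference, --alpha and --beta are the weights of the combined objective described in the paper: (1 - alpha) times the hard-label cross-entropy, plus alpha times the soft-target distillation loss, plus beta times the patient loss computed on normalized [CLS] hidden states of matched teacher/student layers. The sketch below only illustrates that objective and is not the repository's exact implementation; the function name and defaults are placeholders, and the KL-divergence form of the soft-target term is a common stand-in for the paper's soft cross-entropy.

import torch.nn.functional as F

def pkd_loss(student_logits, teacher_logits, labels,
             student_cls_hidden, teacher_cls_hidden,
             alpha=0.5, beta=100.0, temperature=1.0):
    """Illustrative PKD objective: (1 - alpha) * CE + alpha * KD + beta * PT."""
    # Hard-label cross-entropy on the ground-truth labels.
    ce = F.cross_entropy(student_logits, labels)

    # Soft-target distillation loss against the teacher's predictions.
    kd = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)

    # Patient loss: MSE between L2-normalized [CLS] hidden states of the
    # student layers and the teacher layers they are mapped to.
    pt = 0.0
    for h_s, h_t in zip(student_cls_hidden, teacher_cls_hidden):
        pt = pt + F.mse_loss(F.normalize(h_s, dim=-1), F.normalize(h_t, dim=-1))

    return (1.0 - alpha) * ce + alpha * kd + beta * pt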

Experimental Results on the GLUE dev sets

| model | num_layers | SST-2 | MRPC (f1/acc) | QQP (f1/acc) | MNLI (m/mm) | QNLI | RTE |
| --- | --- | --- | --- | --- | --- | --- | --- |
| base | 12 | 0.9232 | 0.89/0.8358 | 0.8818/0.9121 | 0.8432/0.8479 | 0.916 | 0.6751 |
| finetuned | 6 | 0.9002 | 0.8741/0.8186 | 0.8672/0.901 | 0.8051/0.8033 | 0.8662 | 0.6101 |
| distill | 6 | 0.9071 | 0.8885/0.8382 | 0.8704/0.9016 | 0.8153/0.821 | 0.8642 | 0.6318 |