BERT Defender

Introduction

This repo contains code for the following paper.

Learning to Discriminate Perturbations for Blocking Adversarial Attacks in Text Classification, <br> Yichao Zhou, Jyun-Yu Jiang, Kai-Wei Chang and Wei Wang. EMNLP 2019.

In this paper, we propose a novel framework, learning to discriminate perturbations (DISP), to identify and adjust malicious perturbations, thereby blocking adversarial attacks for text classification models.

<img src="https://yz-joey.github.io/images/flow.png" style="zoom:30%;" />

Requirements

Python 3.6
PyTorch 1.0.1+
CUDA 10.0+
numpy
hnswlib
tqdm

Pre-training Discriminator

We first attack the training data at the word level or the character level. Then we pre-train a discriminator on the resulting adversarial data.

python bert_discriminator.py \
--task_name sst-2 \
--do_train \
--do_lower_case \
--data_dir data/SST-2/ \
--bert_model bert-base-uncased \
--max_seq_length 128 \
--train_batch_size 8 \
--learning_rate 2e-5 \
--num_train_epochs 25 \
--output_dir ./tmp/disc/
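
For illustration only, the character-level attack step described above could look like the following sketch. The attack code of this repo is not shown here, so the function names (`perturb_word`, `attack_sentence`), the three edit types, and the per-token 0/1 label format are assumptions, not the authors' implementation.

```python
import random

ALPHABET = "abcdefghijklmnopqrstuvwxyz"


def perturb_word(word, method, rng):
    """Apply one character-level edit to `word` (hypothetical helper)."""
    if len(word) < 2:
        return word  # too short to edit meaningfully
    i = rng.randrange(len(word) - 1)
    if method == "add":
        # Insert a random lowercase letter before position i.
        return word[:i] + rng.choice(ALPHABET) + word[i:]
    if method == "delete":
        # Drop the character at position i.
        return word[:i] + word[i + 1:]
    if method == "swap":
        # Swap two adjacent characters (a no-op if they are identical).
        return word[:i] + word[i + 1] + word[i] + word[i + 2:]
    raise ValueError("unknown method: " + method)


def attack_sentence(tokens, method="add", n_words=1, seed=0):
    """Perturb `n_words` tokens and return (attacked_tokens, labels).

    labels[i] is 1 if token i was chosen for perturbation, else 0.
    This label format is an assumption made for the sketch.
    """
    rng = random.Random(seed)
    candidates = [i for i, t in enumerate(tokens) if len(t) >= 2]
    chosen = set(rng.sample(candidates, min(n_words, len(candidates))))
    attacked = [
        perturb_word(t, method, rng) if i in chosen else t
        for i, t in enumerate(tokens)
    ]
    labels = [1 if i in chosen else 0 for i in range(len(tokens))]
    return attacked, labels
```

Calling `attack_sentence(["a", "great", "movie"], method="add", n_words=1)` perturbs exactly one of the two longer tokens and marks it with a 1 in `labels`, which is the kind of token-level supervision a perturbation discriminator needs.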

Pre-training Embedding Estimator

We build a pre-training dataset for the embedding estimator by collecting, for each word in the dataset, its context within a fixed-size window. This can also be viewed as fine-tuning a BERT language model on a smaller corpus. The embedding estimator differs from a language model in that it only estimates the embedding of a masked token instead of using a huge softmax over the vocabulary to pinpoint the word.

python bert_generator.py \
--task_name sst-2 \
--do_train \
--do_lower_case \
--data_dir data/SST-2/ \
--bert_model bert-base-uncased \
--max_seq_length 64 \
--train_batch_size 8 \
--learning_rate 2e-5 \
--num_train_epochs 25 \
--output_dir ./tmp/gnrt/
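
As a rough sketch of the dataset construction described above, the snippet below collects a fixed-size context window around each word and replaces the word itself with a mask token. The function name `context_windows`, the default window size, and the `(masked_context, target_word)` pair format are assumptions for illustration; the repo's actual preprocessing may differ.

```python
def context_windows(tokens, window=2, mask_token="[MASK]"):
    """Return one (masked_context, target_word) pair per token.

    The context holds up to `window` tokens on each side of the target,
    with the target position replaced by `mask_token`.
    """
    examples = []
    for i, target in enumerate(tokens):
        left = tokens[max(0, i - window):i]
        right = tokens[i + 1:i + 1 + window]
        examples.append((left + [mask_token] + right, target))
    return examples


pairs = context_windows(["the", "movie", "was", "great"], window=1)
# pairs[1] == (["the", "[MASK]", "was"], "movie")
```

During training, the estimator would be asked to predict the embedding of the target word at the masked position rather than a distribution over the whole vocabulary.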

Inference

We first attack the test data using 5 different methods to reduce the model's performance as much as possible. The code for attacking the test sets will be available soon!

During the inference phase, we use the pre-trained discriminator to identify the words that have been attacked.

In the command below, add_1 is the dataset in which each instance was attacked with the "add character" method and only one word was attacked.

python bert_discriminator.py \
--task_name sst-2 \
--do_eval \
--eval_batch_size 32 \
--do_lower_case \
--data_dir data/SST-2/add_1/ \
--data_file data/SST-2/add_1/test.tsv \
--bert_model bert-base-uncased \
--max_seq_length 128 \
--train_batch_size 16 \
--learning_rate 2e-5 \
--num_eval_epochs 5 \
--output_dir models/ \
--single

Then, we recover the words with the pre-trained embedding estimator. Note that we use a small-world graph to conduct a KNN-based search for the closest word in the embedding space.

python bert_generator.py \
--task_name sst-2 \
--do_eval \
--do_lower_case \
--data_dir data/SST-2/add_1/ \
--bert_model bert-base-uncased \
--max_seq_length 64 \
--train_batch_size 8 \
--learning_rate 2e-5 \
--output_dir ./tmp/sst2-gnrt/ \
--num_eval_epochs 2
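
To make the recovery step concrete, here is a minimal sketch of mapping an estimated embedding back to a vocabulary word. It uses an exact linear scan with cosine similarity as a stand-in for the approximate small-world-graph search; the similarity measure, the function names, and the toy 2-dimensional embeddings are assumptions for illustration only. On a real vocabulary, an approximate index such as the one provided by hnswlib (listed in the requirements) would take the place of the scan.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def nearest_words(query, vocab_embeddings, k=1):
    """Return the k vocabulary words whose embeddings are closest to `query`.

    Exact brute-force search: O(|vocab|) per query.
    """
    ranked = sorted(
        vocab_embeddings.items(),
        key=lambda item: cosine(query, item[1]),
        reverse=True,
    )
    return [word for word, _ in ranked[:k]]


# Toy embeddings, made up for this example.
vocab = {"good": [1.0, 0.0], "bad": [-1.0, 0.0], "fine": [0.8, 0.6]}
estimated = [0.9, 0.1]  # pretend output of the embedding estimator
recovered = nearest_words(estimated, vocab, k=2)
# recovered == ["good", "fine"]
```

The word returned for each flagged position replaces the perturbed token before the sentence is passed to the downstream classifier.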

After recovering the test instances, we can run a model to check how effective the recovery is. The model in our settings is a sentiment classification model based on BERT contextualized embeddings.