Text-AutoAugment (TAA)

This repository contains the code for our paper Text AutoAugment: Learning Compositional Augmentation Policy for Text Classification (EMNLP 2021 main conference).

Overview

  1. We present a learnable and compositional framework for data augmentation. Our proposed algorithm automatically searches for the optimal compositional policy, which improves the diversity and quality of augmented samples.

  2. In low-resource and class-imbalanced regimes of six benchmark datasets, TAA significantly improves the generalization ability of deep neural networks like BERT and effectively boosts text classification performance.
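
The toy sketch below illustrates what a "compositional policy" means. It assumes the AutoAugment-style structure this line of work builds on:

- a policy is a list of sub-policies;
- each sub-policy is a sequence of (operation, probability, magnitude) triples;
- for each sample, one sub-policy is drawn and its operations are applied in order, each with its own probability.

The operations random_swap and random_delete and the helper apply_policy are simplified hypothetical stand-ins written for illustration. They are not this repository's implementation, and TAA's actual set of operations may differ.

```python
import random
from typing import Callable, List, Sequence, Tuple

# An operation maps (tokens, magnitude, rng) to a new list of tokens.
Operation = Callable[[List[str], float, random.Random], List[str]]
# A sub-policy is a sequence of (operation, probability, magnitude) triples.
SubPolicy = Sequence[Tuple[Operation, float, float]]


def random_swap(tokens: List[str], magnitude: float, rng: random.Random) -> List[str]:
    """Toy op (hypothetical): swap about magnitude * len(tokens) random pairs."""
    out = list(tokens)
    if len(out) < 2:
        return out
    for _ in range(max(1, int(magnitude * len(out)))):
        i, j = rng.randrange(len(out)), rng.randrange(len(out))
        out[i], out[j] = out[j], out[i]
    return out


def random_delete(tokens: List[str], magnitude: float, rng: random.Random) -> List[str]:
    """Toy op (hypothetical): drop each token with probability `magnitude`."""
    kept = [t for t in tokens if rng.random() >= magnitude]
    return kept or list(tokens[:1])  # keep at least one token


def apply_policy(tokens: List[str], policy: Sequence[SubPolicy], rng: random.Random) -> List[str]:
    """Draw one sub-policy, then apply each of its ops with its own probability."""
    sub_policy = rng.choice(policy)
    for op, prob, magnitude in sub_policy:
        if rng.random() < prob:
            tokens = op(tokens, magnitude, rng)
    return tokens


# A hand-written two-sub-policy example; in TAA such values would come from the search.
policy = [
    [(random_swap, 0.5, 0.2), (random_delete, 0.3, 0.1)],
    [(random_delete, 0.8, 0.2)],
]
rng = random.Random(0)
sentence = "the movie was surprisingly good".split()
augmented = [" ".join(apply_policy(sentence, policy, rng)) for _ in range(4)]
print(augmented)
```

Drawing a sub-policy per sample is what adds diversity. The probabilities and magnitudes control how strongly each sample is perturbed. In TAA, these quantities (together with the choice of operations) are what the policy search optimizes.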

Getting Started

Prepare environment

Install PyTorch and a few other small dependencies, then install this repo as a Python package. Note that the CUDA version of the PyTorch build (cu102, i.e., CUDA 10.2, in the commands below) should match the CUDA version on your machine.

```shell
# Clone this repo
git clone https://github.com/lancopku/text-autoaugment.git
cd text-autoaugment

# Create a conda environment
conda create -n taa python=3.6
conda activate taa

# Install dependencies
pip install torch==1.10.1+cu102 -f https://download.pytorch.org/whl/cu102/torch_stable.html
pip install git+https://github.com/wbaek/theconf
pip install git+https://github.com/ildoonet/pystopwatch2.git
pip install -r requirements.txt

# Install this library (no need to re-build if the source code is modified)
python setup.py develop

# Download the models in NLTK
python -c "import nltk; nltk.download('wordnet'); nltk.download('averaged_perceptron_tagger'); nltk.download('omw-1.4')"
```

Please make sure your PyTorch installation supports GPUs. You can check this by running python -c "import torch; print(torch.cuda.is_available())", which should print True.
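
A small script like the following can check the whole environment at once. It is a sketch, not part of this repo, and it assumes the following import names for the packages installed above:

- torch, theconf, pystopwatch2, nltk and taa;
- ray, listed because the policy search relies on it.

It only reports whether each module can be found. It does not check versions.

```python
import importlib
import importlib.util

# Assumed import names for the dependencies above; taa is this repo's package.
REQUIRED_MODULES = ("torch", "theconf", "pystopwatch2", "nltk", "ray", "taa")


def check_environment(modules=REQUIRED_MODULES):
    """Return ({module: importable?}, cuda_available) without importing everything."""
    report = {name: importlib.util.find_spec(name) is not None for name in modules}
    cuda = False
    if report.get("torch"):
        # Only import torch when it is installed, then ask whether it sees a GPU.
        torch = importlib.import_module("torch")
        cuda = torch.cuda.is_available()
    return report, cuda


if __name__ == "__main__":
    found, cuda_ok = check_environment()
    for name, ok in found.items():
        print(f"{name:14s} {'ok' if ok else 'MISSING'}")
    print("CUDA available:", cuda_ok)
```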

Use TAA with Huggingface

1. Get an augmented training dataset with a TAA policy

<details> <summary><b>Option 1: Search for the optimal policy</b></summary>

You can search for the optimal policy on classification datasets supported by huggingface/datasets:

```python
from taa.search_and_augment import search_and_augment

# return the augmented train dataset in the form of torch.utils.data.Dataset
augmented_train_dataset = search_and_augment(configfile="/path/to/your/config.yaml")
```

The configfile (a YAML file) contains all the arguments, including paths, the model, the dataset and optimization hyper-parameters. To run the code successfully, please set these arguments carefully.


configfile example 1: TAA for huggingface dataset

bert_sst2_example.yaml is an example configfile for the BERT model and the SST-2 dataset. You can follow it to create your own configfile for other huggingface datasets.

For instance, to change the dataset from sst2 to imdb:

- delete sst2 from the 'path' argument;
- change 'name' to imdb;
- change 'text_key' to text.

The result should look like bert_imdb_example.yaml.

configfile example 2: TAA for custom (local) dataset

bert_custom_data_example.yaml is an example configfile for the BERT model and a custom (local) dataset. The custom dataset should be in CSV format, with columns named text and label. custom_data.csv is an example of such a dataset.
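
The sketch below writes and checks a CSV in that format using only Python's standard csv module. The helper names write_custom_dataset and validate_custom_dataset are hypothetical. Storing labels as integers is an assumption; compare with custom_data.csv for the exact label format TAA expects.

```python
import csv
import os
import tempfile

# Column names required by the custom-dataset format described above.
REQUIRED_COLUMNS = ("text", "label")


def write_custom_dataset(path, rows):
    """Write (text, label) pairs to a CSV whose header is: text,label."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=REQUIRED_COLUMNS)
        writer.writeheader()
        for text, label in rows:
            writer.writerow({"text": text, "label": label})


def validate_custom_dataset(path):
    """Raise ValueError if a required column is missing; else return the row count."""
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        missing = set(REQUIRED_COLUMNS) - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"{path}: missing column(s) {sorted(missing)}")
        return sum(1 for _ in reader)


if __name__ == "__main__":
    rows = [("a gripping, well-acted thriller", 1), ("dull and far too long", 0)]
    path = os.path.join(tempfile.mkdtemp(), "custom_data.csv")
    write_custom_dataset(path, rows)
    print(validate_custom_dataset(path))  # prints 2
```

DictWriter quotes fields that contain commas, such as the first example sentence. Free-form text therefore does not break the two-column layout.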

WARNING: The policy optimization framework is based on Ray. By default, we use 4 GPUs and 40 CPUs for policy optimization. Make sure your computing resources meet this requirement; otherwise, you will need to create a new configuration file. Also specify which GPUs to use before running the code above, e.g., by setting CUDA_VISIBLE_DEVICES=0,1,2,3. TPUs do not appear to be supported at the moment.
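
The following sketch is a pre-flight check. It compares CUDA_VISIBLE_DEVICES and the CPU count against the default 4 GPUs and 40 CPUs mentioned above. The helper names are hypothetical. The check only inspects the environment variable and os.cpu_count(); it does not confirm that the listed GPU ids actually exist. Set CUDA_VISIBLE_DEVICES before importing torch or taa, the usual safe practice since CUDA reads it at initialization.

```python
import os

# Defaults stated in the WARNING above; lower them if your config asks for less.
REQUIRED_GPUS = 4
REQUIRED_CPUS = 40


def visible_gpus(env=None):
    """Return the device ids listed in CUDA_VISIBLE_DEVICES (empty if unset or blank)."""
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES", "")
    return [d.strip() for d in value.split(",") if d.strip()]


def check_resources(env=None, cpu_count=None):
    """Return a list of problems; an empty list means the defaults appear to fit."""
    problems = []
    gpus = visible_gpus(env)
    if not gpus:
        problems.append("CUDA_VISIBLE_DEVICES lists no GPUs; set it, e.g. to 0,1,2,3")
    elif len(gpus) < REQUIRED_GPUS:
        problems.append(f"CUDA_VISIBLE_DEVICES lists {len(gpus)} GPU(s), need {REQUIRED_GPUS}")
    cpus = os.cpu_count() if cpu_count is None else cpu_count
    if (cpus or 0) < REQUIRED_CPUS:
        problems.append(f"{cpus} CPU(s) available, need {REQUIRED_CPUS}")
    return problems


if __name__ == "__main__":
    # Keep an existing setting; otherwise default to GPUs 0-3 before torch/taa is imported.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3")
    for problem in check_resources():
        print("WARNING:", problem)
```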

</details> <details> <summary><b>Option 2: Use our pre-searched policy</b></summary>

To train a model on a dataset augmented with our pre-searched policy, use the following (taking IMDB as an example):

```python
from taa.search_and_augment import augment_with_presearched_policy

# return the augmented train dataset in the form of torch.utils.data.Dataset
augmented_train_dataset = augment_with_presearched_policy(configfile="/path/to/your/config.yaml")
```

Pre-searched policies are currently available for IMDB, SST5, TREC, YELP2 and YELP5. See archive.py for details.

The table below lists the test accuracy (%) of the pre-searched TAA policies on the full datasets:

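| Dataset | IMDB | SST-5 | TREC | YELP-2 | YELP-5 |
|---|---|---|---|---|---|
| No Aug | 88.77 | 52.29 | 96.40 | 95.85 | 65.55 |
| TAA | 89.37 | 52.55 | 97.07 | 96.04 | 65.73 |
| n_aug | 4 | 4 | 4 | 2 | 2 |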
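
The snippet below computes the absolute gain of TAA over No Aug, in percentage points, for each dataset. The values are copied from the table. The unweighted mean over the five datasets is our own summary statistic, not a number reported in the paper.

```python
# Test accuracy (%) copied from the table: dataset -> (No Aug, TAA).
RESULTS = {
    "IMDB": (88.77, 89.37),
    "SST-5": (52.29, 52.55),
    "TREC": (96.40, 97.07),
    "YELP-2": (95.85, 96.04),
    "YELP-5": (65.55, 65.73),
}

# Absolute gain in percentage points, rounded to undo float noise.
gains = {name: round(taa - base, 2) for name, (base, taa) in RESULTS.items()}
for name, gain in gains.items():
    print(f"{name:7s} +{gain:.2f}")

# Unweighted mean across datasets (a summary, not a reported result).
mean_gain = round(sum(gains.values()) / len(gains), 2)
print("mean:", mean_gain)
```

The gains range from 0.18 points (YELP-5) to 0.67 points (TREC).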

More pre-searched policies and their performance will be COMING SOON.

</details>
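
If you switch between datasets often, you may want to use the pre-searched policy when one exists and fall back to searching otherwise. The sketch below does that, calling the two functions shown above with only the configfile argument. Two parts are our own illustrative assumptions:

- the name normalization in has_presearched_policy;
- the separate dataset_name argument.

In practice the dataset is specified inside the configfile, and the actual policy keys are defined in archive.py.

```python
# Datasets with a pre-searched policy, per the supported list above. The
# lowercase, alphanumeric-only spelling is our own normalization (hypothetical).
PRESEARCHED = {"imdb", "sst5", "trec", "yelp2", "yelp5"}


def has_presearched_policy(dataset_name: str) -> bool:
    """True if dataset_name (e.g. 'SST-5' or 'yelp_2') is in the supported list."""
    key = "".join(ch for ch in dataset_name.lower() if ch.isalnum())
    return key in PRESEARCHED


def get_augmented_train_dataset(configfile: str, dataset_name: str):
    """Use the pre-searched policy when available; otherwise search for one."""
    # Imported lazily so has_presearched_policy works without TAA installed.
    from taa.search_and_augment import augment_with_presearched_policy, search_and_augment

    if has_presearched_policy(dataset_name):
        return augment_with_presearched_policy(configfile=configfile)
    return search_and_augment(configfile=configfile)
```

For example, get_augmented_train_dataset("/path/to/your/config.yaml", "IMDB") would take the pre-searched path. For a dataset like sst2, it would run the full search, which needs the GPU and CPU resources described in the warning above.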

2. Fine-tune a new model on the augmented training dataset

Once you have augmented_train_dataset, you can pass it directly to the huggingface Trainer. Please refer to search_augment_train.py for details.

Reproduce results in the paper

See examples/reproduce_experiment.py. To reproduce the results, run script/huggingface_lowresource.sh for the low-resource setting or script/huggingface_imbalanced.sh for the class-imbalanced setting.

Contact

If you have any questions related to the code or the paper, feel free to open an issue.

Acknowledgments

Our code is based on fast-autoaugment.

Citation

If you find this code useful for your research, please consider citing:

```bibtex
@inproceedings{ren2021taa,
    title = "Text {A}uto{A}ugment: Learning Compositional Augmentation Policy for Text Classification",
    author = "Ren, Shuhuai and Zhang, Jinchao and Li, Lei and Sun, Xu and Zhou, Jie",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
    year = "2021",
}
```

License

MIT