
Implementation of Structured-Self-Attentive-Sentence-Embedding

Introduction

This is an implementation of the paper A Structured Self-Attentive Sentence Embedding using MXNet/Gluon. It implements most of the details described in the paper. The model is evaluated on the review star-rating classification task from the original paper, using the same dataset: the Yelp reviews. The model structure is as follows:

[Figure: Bi_LSTM_Attention model structure]

Requirements

  1. MXNet
  2. GluonNLP
  3. NumPy
  4. scikit-learn
  5. Python 3

Implemented

1. Attention mechanism proposed in the original paper

$$ A = \mathrm{softmax}\left(W_{s2} \tanh\left(W_{s1} H^T\right)\right) $$

2. Penalization term to encourage diversity among the attention hops

$$ P = \left\| A A^T - I \right\|_F^2 $$

3. Parameter pruning proposed in the appendix of the paper

[Figure: prune weights]

4. Gradient clipping and learning rate decay

5. Softmax cross-entropy loss with class weights (WeightedSoftmaxCrossEntropy)
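As a rough illustration of items 1 and 2 above, the NumPy sketch below computes the annotation matrix A from a stand-in for the BiLSTM hidden states H, forms the sentence embedding matrix M = AH, pools it the way the `--pool_way` options 'flatten' and 'mean' suggest, and evaluates the penalization term P. It is not the repository's MXNet/Gluon code: the random H, the small weight initialization, the shapes (d_a = natt_units, r = natt_hops), and the helper names `self_attention`, `penalization`, and `pool` are assumptions for illustration, and the 'prune' option is not covered.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(H, W_s1, W_s2):
    """A = softmax(W_s2 tanh(W_s1 H^T)), with the softmax taken over the n tokens.

    H:    (n, 2u) BiLSTM hidden states for one sentence
    W_s1: (d_a, 2u)
    W_s2: (r, d_a)
    Returns A with shape (r, n); each row sums to 1.
    """
    return softmax(W_s2 @ np.tanh(W_s1 @ H.T), axis=-1)


def penalization(A):
    """P = ||A A^T - I||_F^2; it is small when different hops attend to different tokens."""
    r = A.shape[0]
    diff = A @ A.T - np.eye(r)
    return float(np.sum(diff ** 2))


def pool(M, pool_way):
    # Hypothetical stand-in for the --pool_way choices 'flatten' and 'mean'.
    if pool_way == 'flatten':
        return M.reshape(-1)      # (r * 2u,)
    if pool_way == 'mean':
        return M.mean(axis=0)     # (2u,)
    raise ValueError('this sketch only covers flatten and mean')


rng = np.random.default_rng(2018)
n, u, d_a, r = 20, 300, 350, 2          # tokens, hidden units per direction, natt_units, natt_hops
H = rng.standard_normal((n, 2 * u))     # random stand-in for the BiLSTM output
W_s1 = rng.standard_normal((d_a, 2 * u)) * 0.01
W_s2 = rng.standard_normal((r, d_a)) * 0.01

A = self_attention(H, W_s1, W_s2)       # (r, n) attention weights
M = A @ H                               # (r, 2u) sentence embedding matrix
P = penalization(A)

print(A.shape, M.shape, pool(M, 'flatten').shape, pool(M, 'mean').shape)
# → (2, 20) (2, 600) (1200,) (600,)
print(round(P, 4))
```

With near-uniform attention, as here, P stays large; it reaches 0 only when the rows of A are one-hot on distinct tokens, which is why adding it to the loss (scaled by `--penalization_coeff`) pushes the hops apart.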

For sentiment classification

1. Training parameter description

import argparse

parser = argparse.ArgumentParser(
    description='Self-attentive BiLSTM for Yelp review star classification')

# Model architecture
parser.add_argument('--nword_dims', type=int, default=300,
                    help='size of word embeddings')
parser.add_argument('--nhiddens', type=int, default=300,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=1,
                    help='number of layers in the BiLSTM')
parser.add_argument('--natt_units', type=int, default=350,
                    help='number of attention units')
parser.add_argument('--natt_hops', type=int, default=1,
                    help='number of attention hops, for the multi-hop attention model')
parser.add_argument('--nfc', type=int, default=512,
                    help='hidden (fully connected) layer size for the classifier MLP')
# choices (not choice) restricts the accepted values
parser.add_argument('--pool_way', type=str, choices=['flatten', 'mean', 'prune'],
                    default='flatten', help='how to pool the attention output')
parser.add_argument('--nprune_p', type=int, default=None, help='prune p size')
parser.add_argument('--nprune_q', type=int, default=None, help='prune q size')
parser.add_argument('--nclass', type=int, default=5, help='number of classes')
parser.add_argument('--wv_name', type=str, choices=['glove', 'w2v', 'fasttext', 'random'],
                    default='random', help='word embedding initialization')

# Training
parser.add_argument('--drop_prob', type=float, default=0.5,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--clip', type=float, default=0.5,
                    help='gradient clipping threshold to prevent exploding gradients in the LSTM')
parser.add_argument('--lr', type=float, default=.001, help='initial learning rate')
parser.add_argument('--nepochs', type=int, default=10, help='upper epoch limit')
# each choice is a separate string: 'sce' and 'wsce'
parser.add_argument('--loss_name', type=str, choices=['sce', 'wsce'],
                    default='sce',
                    help='loss function: sce (softmax cross-entropy) or wsce (weighted softmax cross-entropy)')
parser.add_argument('--freeze_embedding', default=False, action='store_true',
                    help='do not update the word embeddings during training')
parser.add_argument('--seed', type=int, default=2018, help='random seed')
parser.add_argument('--batch_size', type=int, default=64, help='batch size for training')
parser.add_argument('--optimizer', type=str, default='Adam', help='type of optimizer')
parser.add_argument('--penalization_coeff', type=float, default=0.1,
                    help='coefficient of the penalization term P')
parser.add_argument('--lr_decay_step', type=int, default=2,
                    help='step of learning rate decay')
parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                    help='rate of learning rate decay')
parser.add_argument('--log_interval', type=int, default=400,
                    help='interval (in steps) between log outputs')

# Data and output paths
parser.add_argument('--valid_rate', type=float, default=0.1,
                    help='proportion of samples used for validation')
parser.add_argument('--max_seq_len', type=int, default=100,
                    help='maximum length of each sample')
parser.add_argument('--model_root', type=str, default='../models',
                    help='directory to save the final model')
parser.add_argument('--model_name', type=str, default='self_att_bilstm_model',
                    help='file name of the saved model')
parser.add_argument('--data_json_path', type=str,
                    default='../data/sub_review_labels.json', help='raw data path')
parser.add_argument('--formated_data_path', type=str,
                    default='../data/formated_data.pkl', help='formatted data path')

args = parser.parse_args()
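The `--lr`, `--lr_decay_step`, `--lr_decay_rate`, and `--clip` options above describe a step-decay schedule plus gradient clipping. The NumPy sketch below shows one plausible reading, assuming the learning rate is multiplied by `lr_decay_rate` once every `lr_decay_step` epochs and that `clip` bounds the global L2 norm of all gradients. Both interpretations, and the helper names `step_decay_lr` and `clip_global_norm`, are assumptions for illustration rather than a description of what `train_model.py` does.

```python
import numpy as np


def step_decay_lr(epoch, lr=0.001, decay_step=2, decay_rate=0.1):
    # Assumed schedule: multiply the initial lr by decay_rate once every decay_step epochs.
    # epoch is 0-based.
    return lr * decay_rate ** (epoch // decay_step)


def clip_global_norm(grads, max_norm=0.5):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm measured before clipping.
    """
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm


# Learning rate for each of the default 10 epochs under the assumed schedule.
for epoch in range(10):
    print(epoch, step_decay_lr(epoch))

grads = [np.array([3.0, 4.0]), np.array([0.0])]   # joint norm is 5.0
clipped, norm_before = clip_global_norm(grads, max_norm=0.5)
print(norm_before, np.sqrt(sum(np.sum(g ** 2) for g in clipped)))
```

Clipping by the global norm keeps the direction of the update and only shrinks its length. In Gluon, `mxnet.gluon.utils.clip_global_norm` and `Trainer.set_learning_rate` cover the same two steps, though whether this repository uses them is not stated here.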

2. Training details

The original paper uses 500K reviews as the training set, 2,000 as the validation set, and 2,000 as the test set. Because of limited personal hardware, 200K reviews were randomly sampled as the training set and 2,000 as the validation set, keeping the class distribution consistent with the original data. The class weights of WeightedSoftmaxCrossEntropy are set according to the class proportions in this data; if you train on different data and want to use this loss function, you need to modify the class_weight values yourself.
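The paragraph above sets the WeightedSoftmaxCrossEntropy class weights from the class proportions but does not give the formula. The NumPy sketch below assumes inverse-frequency weights rescaled to average 1, which is one common choice. The toy labels are made up for illustration, and `class_weights_from_labels` and `weighted_softmax_ce` are hypothetical helpers, not functions from this repository.

```python
import numpy as np


def class_weights_from_labels(labels, nclass):
    # Assumed scheme: weight_c proportional to 1 / frequency_c, rescaled so the weights average 1.
    counts = np.bincount(labels, minlength=nclass).astype(float)
    freq = counts / counts.sum()
    w = 1.0 / np.maximum(freq, 1e-12)   # guard against classes with no samples
    return w * nclass / w.sum()


def weighted_softmax_ce(logits, labels, class_weight):
    """Batch mean of class_weight[y] * -log softmax(logits)[y]."""
    z = logits - logits.max(axis=1, keepdims=True)          # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]
    return float(np.mean(class_weight[labels] * nll))


# Made-up star labels 0..4 (1 to 5 stars) with an imbalanced distribution.
labels = np.array([4, 4, 4, 4, 3, 3, 2, 1, 0, 4])
w = class_weights_from_labels(labels, nclass=5)
print(np.round(w, 3))

logits = np.zeros((len(labels), 5))   # uniform predictions for every sample
print(weighted_softmax_ce(logits, labels, w))
```

Rare classes get larger weights, so mistakes on them cost more. In Gluon, a similar effect could be obtained by passing `class_weight[label]` as the `sample_weight` argument of `gluon.loss.SoftmaxCrossEntropyLoss`.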

Example training command (any of the parameters above can be overridden on the command line):

python train_model.py --nlayers 1 --nepochs 5 --natt_hops 2 --loss_name sce

Reference

  1. A Structured Self-Attentive Sentence Embedding (Lin et al., ICLR 2017)

  2. Yelp review data