Home

Awesome

Text classification demos

Tensorflow 环境下,不同的神经网络模型对中文文本进行分类,本文中的 demo 都是字符级别的文本分类(增加了word-based 的统计结果),简化了文本分类的流程,字符级别的分类在有些任务上的效果可能不好,需要结合实际情况添加自定义的分词模块。

数据集

下载地址: https://pan.baidu.com/s/1hugrfRu 密码: qfud

使用 THUCNews 的一个子集进行训练与测试,使用了其中的 10 个分类,每个分类 6500 条数据。

类别如下:

体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐

数据集划分如下:

训练集: 5000 * 10
验证集: 500 * 10
测试集: 1000 * 10

具体介绍请参考:text-classification-cnn-rnn

分类效果

modelfasttextcnnrnnrcnnhandpcnnbert
val_acc92.9293.5693.5694.3693.9493.7097.84
test_acc93.1594.5794.3795.5393.6594.8796.93
modelfasttextcnnrnnrcnnhandpcnnbert
val_acc95.5295.2893.1095.6095.1095.68-
test_acc95.3495.7794.0596.3695.6695.97-

模型介绍

1、FastText

fasttext_model.py 文件为训练和测试 fasttext 模型的代码

图1 FastText 模型结构图

本代码简化了 fasttext 模型的结构,模型结构非常简单,运行速度简直飞快,模型准确率也不错,可根据实际需要优化模型结构

2、TextCNN

cnn_model.py 文件为训练和测试 TextCNN 模型的代码

图2 TextCNN 模型结构图

本代码实现了 TextCNN 模型的结构,通过 3 个不同大小的卷积核,对输入文本进一维卷积,分别 pooling 三个卷积之后的 feature, 拼接到一起,然后进行 dense 操作,最终输出模型结果。可实现速度和精度之间较好的折中。

3、RNN

rnn_model.py 文件为训练和测试 TextCNN 模型的代码

图8 TextRNN 模型结构图

本代码实现了 TextRNN 模型的结构,对输入序列进行embedding,然后输入两层的 rnn_cell中学习序列特征,取最后一个 word 的 state 作为进行后续的 fc 操作,最终输出模型结果。

4、RCNN

rcnn_model.py 文件为训练和测试 RCNN 模型的代码

图3 RCNN 模型结构图

Recurrent Convolutional Neural Network for Text Classification, 在学习 word representations 时候,同时采用了 rnn 结构来学习 word 的上下文,虽然模型名称为 RCNN,但并没有显式的存在卷积操作。

1、采用双向lstm学习 word 的上下文

   c_left = tf.concat([tf.zeros(shape), output_fw[:, :-1]], axis=1, name="context_left")
   c_right = tf.concat([output_bw[:, 1:], tf.zeros(shape)], axis=1, name="context_right")
   word_representation = tf.concat([c_left, embedding_inputs, c_right], axis=2, name="last")

2、pooling + softmax

word_representation 的维度是 batch_size * seq_length * 2 * context_dim + embedding_dim

在 seq_length 维度进行 max pooling,然后进行 fc 操作就可以进行分类了,可以将该网络看成是 fasttext 的改进版本

5、HAN

han_model.py 文件为训练和测试 HAN 模型的代码

图4 HAN 模型结构图

HAN 为 Hierarchical Attention Networks,将待分类文本,分为一定数量的句子,分别在 word level 和 sentence level 进行 encoder 和 attention 操作,从而实现对较长文本的分类。

本文是按照句子长度将文本分句的,实际操作中可按照标点符号等进行分句,理论上效果能好一点。

图4 attention

(1) 将输入的lstm编码结果做一次非线性变换,可以看做是输入编码的hidden representation, shape = (64*10) * 60 * 256

(2) 将 hidden representation 与一个学习得到的 word level context vector 的相似性进行 softmax,得到每个单词在句子中的权重

(3) 对输入的lstm 编码进行加权求和,得到句子的向量表示

数据维度变化:(64*10) * 60 * 512 --- (64*10) * 512

6、DPCNN

dpcnn_model.py 文件为训练和测试 DPCNN 模型的代码

图5 DPCNN 模型结构图

DPCNN 通过卷积和残差连接增加了以往用于文本分类 CNN 网络的深度,可以有效提取文本中的远程关系特征,并且复杂度不高,实验表名,效果比以往的 CNN 结构要好一点。

7、BERT

bert_model.py 文件为训练和测试 BERT 模型的代码

google官方提供用于文本分类的demo写的比较抽象,所以本文基于 google 提供的代码和初始化模型,重写了文本分类模型的训练和测试代码,bert 分类模型在小数据集下效果很好,通过较少的迭代次数就能得到很好的效果,但是训练和测试速度较慢,这点不如基于 CNN 的网络结构。

bert_model.py 将训练数据和验证数据存储为 tfrecord 文件,然后进行训练

由于 bert 提供的预训练模型较大,需要自己去 google-research/bert 中下载预训练好的模型,本实验采用的是 "BERT-Base, Chinese" 模型。

图6 BERT 输入数据格式

图7 BERT 下游任务介绍

参考