Home

Awesome

Gated-Transformer-on-MTS

基于Pytorch,使用改良的Transformer模型应用于多维时间序列的分类任务上

实验结果

对比模型选择 Fully Convolutional Networks (FCN) and Residual Net works (ResNet) <br>

DataSetMLPFCNResNetEncoderMCNNt-LeNetMCDCNNTime-CNNTWIESNGated Transformer
ArabicDigits96.999.499.698.110.010.095.995.885.398.8
AUSLAN93.397.597.493.81.11.185.472.672.497.5
CharacterTrajectories96.999.099.097.15.46.793.896.092.097.0
CMUsubject1660.010099.798.353.151.051.497.689.3100
ECG74.887.286.787.267.067.050.084.173.791.0
JapaneseVowels97.699.399.297.69.223.894.495.696.5
Libras78.096.495.478.36.76.765.163.779.488.9
UWave90.193.492.690.812.512.584.58.975.491.0
KickvsPunch61.054.051.061.054.050.056.062.067.090.0
NetFlow55.089.162.777.777.972.363.089.094.5100
PEMS---------93.6
Wafer89.498.298.998.689.489.465.894.894.999.1
WalkvsRun70.010010010075.060.045.0100.094.4100

实验环境

环境描述
语言Python3.7
框架Pytorch1.6
IDEPycharm and Colab
设备CPU and GPU

数据集

多元时间序列数据集, 文件为.mat格式,训练集与测试集在一个文件中,且预先定义为了测试集数据,测试集标签,训练集数据与训练集标签。 <br> 数据集下载使用百度云盘,连接如下:<br> 链接:https://pan.baidu.com/s/1u2HN6tfygcQvzuEK5XBa2A <br> 提取码:dxq6 <br> Google drive link:https://drive.google.com/drive/folders/1QFadJOmbOLWMjLrcebZQR_w2fBX7x0Vm?usp=share_link

UEA and UCR dataset:http://www.timeseriesclassification.com/index.php


数据集维度描述

DataSetNumber of ClassesSize of training SetSize of testing SetMax Time series LengthChannel
ArabicDigits10660022009313
AUSLAN951140142513622
CharacterTrajectories2030025582053
CMUsubject162292958062
ECG21001001522
JapaneseVowels92703702912
Libras15180180452
UWave820042783153
KickvsPunch2161084162
NetFlow28035349974
PEMS7267173144963
Wafer22988961986
WalkvsRun22816191862

数据预处理

详细数据集处理过程参看 dataset_process.py文件。<br>

模型描述

<img src="https://github.com/SY-Ma/Gated-Transformer-on-MTS/blob/main/images/GTN%20structure.png" style="zoom:50%">

超参描述

超参描述
d_model模型处理的为时间序列而非自然语言,所以省略了NLP中对词语的编码,仅使用一个线性层映射成d_model维的稠密向量,此外,d_model保证了在每个模块衔接的地方的维度相同
d_hiddenPosition-wise FeedForword 中隐藏层的维度
d_input时间序列长度,其实是一个数据集中最长时间步的维度 固定的,直接由数据集预处理决定
d_channel多元时间序列的时间通道数,即是几维的时间序列 固定的,直接由数据集预处理决定
d_output分类类别数 固定的,直接由数据集预处理决定
q,vMulti-Head Attention中线性层映射维度
hMulti-Head Attention中头的数量
NEncoder栈中Encoder的数量
dropout随机失活
EPOCH训练迭代次数
BATCH_SIZEmini-batch size
LR学习率 定义为1e-4
optimizer_name优化器选择 建议Adagrad和Adam

文件描述

文件名称描述
dataset_process数据集处理
font存储字体,用于结果图中的文字
gather_figure聚类结果图
heatmap_figure_in_test测试模型时绘制的score矩阵的热力图
module模型的各个模块
mytest各种测试代码
reslut_figure准确率结果图
saved_model保存的pkl文件
utils工具类文件
run.py训练模型
run_with_saved_model.py使用训练好的模型(保存为pkl文件)测试结果

utils工具描述

简单介绍几个

Tips

参考

[Wang et al., 2017] Z. Wang, W. Yan, and T. Oates. Time series classification from scratch with deep neural networks:A strong baseline. In 2017 International Joint Conference on Neural Networks (IJCNN), pages 1578–1585, 2017.

本人学识浅薄,代码和文字若有不当之处欢迎批评与指正!

联系方式:masiyuan007@qq.com