ST-SSL: Spatio-Temporal Self-Supervised Learning for Traffic Prediction
This is a PyTorch implementation of ST-SSL, proposed in the following paper:
- J. Ji, J. Wang, C. Huang, et al. "Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction". in AAAI 2023.
27/10/2023: This paper was featured by leading WeChat official accounts in the fields of data mining and transportation: 当交通遇上机器学习 | 时空实验室 | AI蜗牛车
22/04/2023: The post about this paper was selected as a headline tweet by PaperWeekly, a leading AI academic platform in China, and received nearly 7,000 reads.
09/02/2023: The video replay of the academic presentation at AAAI 2023.
04/02/2023: J. Ji was invited to give a talk at the AAAI 2023 Beijing Pre-Conference. The talk was about Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction.
Requirement
This project is built with Python 3.8 and the following packages:
numpy==1.21.2
pandas==1.3.5
PyYAML==6.0
torch==1.10.1
Datasets
The available datasets are {NYCBike1, NYCBike2, NYCTaxi, BJTaxi}. You can download them from the GitHub repo, Beihang Cloud Drive, or Google Drive.
Each dataset is composed of 4 files, namely train.npz, val.npz, test.npz, and adj_mx.npz.
|----NYCBike1\
| |----train.npz # training data
| |----adj_mx.npz # predefined graph structure
| |----test.npz # test data
| |----val.npz # validation data
The train/val/test data is composed of 4 numpy.ndarray objects:
- X: input data. It is a 4D tensor of shape (#samples, #lookback_window, #nodes, #flow_types), where # denotes "number of".
- Y: data to be predicted. It is a 4D tensor of shape (#samples, #predict_horizon, #nodes, #flow_types). Note that X and Y are paired in the sample dimension. For instance, (X_i, Y_i) is the i-th data sample, with i indexing the sample dimension.
- X_offset: a list indicating the offsets of X's lookback window relative to the current time, which has offset 0.
- Y_offset: a list indicating the offsets of Y's prediction horizon relative to the current time, which has offset 0.
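To check the layout described above on a downloaded file, a small helper like the following can list every array in an .npz archive together with its shape. This is a minimal sketch, not part of the repo: the function name describe_npz is hypothetical, and it reads the array names from the archive itself rather than assuming how the keys are spelled.

```python
import numpy as np


def describe_npz(path):
    """Return {array_name: shape} for every array stored in an .npz archive.

    Hypothetical helper for illustration; it makes no assumption about
    the key names used inside the file.
    """
    with np.load(path) as archive:
        return {name: archive[name].shape for name in archive.files}
```

For example, describe_npz("NYCBike1/train.npz") (path assumed from the directory tree above) should report two 4D arrays that share the same first (sample) dimension, plus the two offset arrays. The same helper can be pointed at adj_mx.npz.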
For all datasets, previous 2-hour flows as well as previous 3-day flows around the predicted time are used to forecast flows for the next time step.
adj_mx.npz is the graph adjacency matrix that indicates the spatial relation between every pair of regions/nodes in the studied area.
⚠️ Note that all datasets are processed as a sliding window view. Raw data of NYCBike1 and BJTaxi are collected from STResNet. Raw data of NYCBike2 and NYCTaxi are collected from STDN. If needed, one can download the original datasets from this link.
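As a rough illustration of what a sliding window view with offsets means, the sketch below pairs each time step t of a toy 1D series with inputs taken at t plus each input offset and targets taken at t plus each output offset. It is written in plain Python for readability. The function name sliding_window_samples and the offsets used in the usage note are hypothetical; the exact offsets used to build the released files are stored in the files themselves and are not reproduced here.

```python
def sliding_window_samples(series, x_offsets, y_offsets):
    """Build (input, target) pairs from a 1D series using time offsets.

    Offsets are relative to the current time step t (offset 0), so an
    input offset of -1 means "one step before t" and an output offset
    of 1 means "one step after t". Illustrative sketch only.
    """
    offsets = list(x_offsets) + list(y_offsets)
    first_t = max(0, -min(offsets))               # earliest t with all indices >= 0
    end_t = len(series) - max(0, max(offsets))    # one past the latest valid t
    samples = []
    for t in range(first_t, end_t):
        x = [series[t + dx] for dx in x_offsets]  # lookback window
        y = [series[t + dy] for dy in y_offsets]  # prediction horizon
        samples.append((x, y))
    return samples
```

With series = list(range(10)), x_offsets = [-2, -1, 0], and y_offsets = [1], this yields seven samples, the first being ([0, 1, 2], [3]). In the actual datasets each element would be a (#nodes, #flow_types) slice rather than a single number.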
Model training and Evaluation
If the environment is ready, please run the following commands to train the model on a specific dataset from {NYCBike1, NYCBike2, NYCTaxi, BJTaxi}.
>> cd ST-SSL
>> ./runme 0 NYCBike1 # 0 specifies the GPU id, NYCBike1 gives the dataset
Note that this repo only contains the NYCBike1 data, because including all datasets would make the repo too large.
Cite
If you find the paper useful, please cite the following:
@article{ji2023spatio,
title={Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction},
author={Ji, Jiahao and Wang, Jingyuan and Huang, Chao and Wu, Junjie and Xu, Boren and Wu, Zhenhe and Zhang, Junbo and Zheng, Yu},
journal={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={37},
number={4},
pages={4356-4364},
year={2023}
}