ST-SSL: Spatio-Temporal Self-Supervised Learning for Traffic Prediction

This is a PyTorch implementation of ST-SSL from the paper "Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction" (AAAI 2023).

[Figure: ST-SSL framework]

27/10/2023: This paper was featured by leading WeChat official accounts in the fields of data mining and transportation: 当交通遇上机器学习 | 时空实验室 | AI蜗牛车

22/04/2023: The post about this paper was selected as a headline tweet by PaperWeekly, a leading AI academic platform in China, and received nearly 7,000 reads.

09/02/2023: The video replay of the academic presentation at AAAI 2023 is available.

04/02/2023: J. Ji was invited to give a talk at the AAAI 2023 Beijing Pre-Conference on Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction.

Requirement

This project is built with Python 3.8 and the following packages:

numpy==1.21.2
pandas==1.3.5
PyYAML==6.0
torch==1.10.1
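
To confirm that an environment matches the pinned versions above, a check along the following lines may help. This is a minimal sketch, not part of the repository: `PINS` and `check_requirements` are hypothetical names, and the comparison ignores a local build suffix (such as `+cu113`) that some PyTorch builds append to their version string.

```python
from importlib.metadata import PackageNotFoundError, version

# Pinned versions copied from the list above.
PINS = {"numpy": "1.21.2", "pandas": "1.3.5", "PyYAML": "6.0", "torch": "1.10.1"}


def check_requirements(pins):
    """Return {package: (wanted, found)} for each missing or mismatched package."""
    problems = {}
    for name, wanted in pins.items():
        try:
            # Drop a local build suffix such as "+cu113" before comparing.
            found = version(name).split("+")[0]
        except PackageNotFoundError:
            found = None  # package is not installed at all
        if found != wanted:
            problems[name] = (wanted, found)
    return problems
```

Calling `check_requirements(PINS)` returns an empty dict when every package matches. Otherwise each entry shows the wanted version next to the installed one, or `None` if the package is missing.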

Datasets

Four datasets are available: NYCBike1, NYCBike2, NYCTaxi, and BJTaxi. You can download them from the GitHub repo, Beihang Cloud Drive, or Google Drive.

Each dataset is composed of 4 files, namely train.npz, val.npz, test.npz, and adj_mx.npz.

|----NYCBike1\
|    |----train.npz    # training data
|    |----adj_mx.npz   # predefined graph structure
|    |----test.npz     # test data
|    |----val.npz      # validation data

The train/val/test data is composed of 4 numpy.ndarray objects:
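
The arrays stored in any of these files can be listed with NumPy without knowing their names in advance. The snippet below is a minimal sketch under that assumption: `describe_npz` is a hypothetical helper, not a function from this repository.

```python
import numpy as np


def describe_npz(path):
    """Map each array name stored in an .npz archive to its (shape, dtype)."""
    summary = {}
    # np.load on an .npz file returns a lazy archive; close it when done.
    with np.load(path) as archive:
        for name in archive.files:
            arr = archive[name]
            summary[name] = (arr.shape, str(arr.dtype))
    return summary
```

For example, `describe_npz("NYCBike1/train.npz")` would report the name, shape, and dtype of each array in the training split. The path is an assumption based on the directory layout shown above, so adjust it to wherever the data is stored.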

For all datasets, flows from the previous 2 hours, together with flows around the predicted time of day on the previous 3 days, are used to forecast flows for the next time step.
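
The lookback rule can be pictured as a set of time-step offsets relative to the step being predicted. The sketch below only illustrates the idea and is not the repository's preprocessing code. The function name and all parameters are hypothetical, and the width of the window "around the predicted time" (`half_window`) is an assumption, since it is not stated here.

```python
def lookback_offsets(steps_per_day, recent_steps, n_days, half_window):
    """Return sorted offsets (in time steps) relative to the target step at 0.

    steps_per_day: number of time steps in one day (depends on the interval)
    recent_steps:  how many immediately preceding steps to include
    n_days:        how many previous days to look back
    half_window:   assumed steps on each side of the same time of day
    """
    # Most recent steps, e.g. the previous 2 hours.
    offsets = set(range(-recent_steps, 0))
    # Steps around the same time of day on each previous day.
    for day in range(1, n_days + 1):
        center = -day * steps_per_day
        offsets.update(range(center - half_window, center + half_window + 1))
    return sorted(offsets)


# Hourly data, 2 recent steps, 3 previous days, 1 step on each side (all assumed).
print(lookback_offsets(steps_per_day=24, recent_steps=2, n_days=3, half_window=1))
# → [-73, -72, -71, -49, -48, -47, -25, -24, -23, -2, -1]
```

Collecting the offsets in a set keeps them unique if the recent window and a daily window happen to overlap.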

adj_mx.npz is the graph adjacency matrix, which encodes the spatial relation between each pair of regions/nodes in the studied area.

⚠️ Note that all datasets are processed as a sliding window view. Raw data of NYCBike1 and BJTaxi are collected from STResNet. Raw data of NYCBike2 and NYCTaxi are collected from STDN. If needed, one can download the original datasets from this link.

Model training and Evaluation

Once the environment is ready, run the following commands to train the model on one of the datasets {NYCBike1, NYCBike2, NYCTaxi, BJTaxi}.

>> cd ST-SSL
>> ./runme 0 NYCBike1   # 0 specifies the GPU id, NYCBike1 gives the dataset

Note that this repo contains only the NYCBike1 data, because including all datasets would make the repo too large.

Cite

If you find the paper useful, please cite the following:

@article{ji2023spatio, 
  title={Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction}, 
  author={Ji, Jiahao and Wang, Jingyuan and Huang, Chao and Wu, Junjie and Xu, Boren and Wu, Zhenhe and Zhang, Junbo and Zheng, Yu},
  journal={Proceedings of the AAAI Conference on Artificial Intelligence}, 
  volume={37},
  number={4},
  pages={4356-4364},
  year={2023}
}