SpoT-Mamba: Learning Long-Range Dependency on Spatio-Temporal Graphs with Selective State Spaces

This code is the official implementation of the following paper:

Jinhyeok Choi, Heehyeon Kim, Minhyeong An, and Joyce Jiyoung Whang, SpoT-Mamba: Learning Long-Range Dependency on Spatio-Temporal Graphs with Selective State Spaces, Spatio-Temporal Reasoning and Learning (STRL) Workshop at the 33rd International Joint Conference on Artificial Intelligence (IJCAI 2024), 2024

All code was written by Jinhyeok Choi (cjh0507@kaist.ac.kr). If you use this code, please cite our paper.

@article{spotmamba,
  author={Jinhyeok Choi and Heehyeon Kim and Minhyeong An and Joyce Jiyoung Whang},
  title={{S}po{T}-{M}amba: Learning Long-Range Dependency on Spatio-Temporal Graphs with Selective State Spaces},
  year={2024},
  journal={arXiv preprint arXiv:2406.11244},
  doi={10.48550/arXiv.2406.11244}
}

Requirements

We used Python 3.8, PyTorch 1.13.1, and DGL 1.1.2 with cudatoolkit 11.7.

We also used the official implementation of Mamba (mamba-ssm 1.2.0.post1).

For Mamba installation instructions, please refer to the official repository.
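
For reference, a typical setup could look like the following. This is a sketch only; the exact wheels and index URLs depend on your CUDA version and platform, so adjust them as needed:

# Illustrative commands; versions and index URLs may need adjusting for your CUDA setup.
pip install torch==1.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install dgl==1.1.2 -f https://data.dgl.ai/wheels/cu117/repo.html
pip install mamba-ssm==1.2.0.post1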

Usage

SpoT-Mamba

We used an NVIDIA GeForce RTX 3090 (24GB) GPU for all our experiments. We provide a template configuration file (template.json).

To train SpoT-Mamba, use the run.py file as follows:

python run.py --config_path=./template.json

Results will be printed in the terminal and saved in the directory specified by the configuration file.

Log files and checkpoints from each experiment can be found in the f"experimental_results/{dataset}-{in_steps}-{out_steps}-{str(train_ratio).zfill(2)}-{seed}-{model}" directory, as illustrated in the sketch below.
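
A minimal sketch of how that directory name is assembled; the values below are illustrative placeholders, not the template defaults:

# Illustrative values only; the actual values come from your configuration file.
dataset, model = "pems04", "SpoTMamba"
in_steps, out_steps = 12, 12
train_ratio, seed = 70, 0

result_dir = (
    f"experimental_results/"
    f"{dataset}-{in_steps}-{out_steps}-{str(train_ratio).zfill(2)}-{seed}-{model}"
)
print(result_dir)  # experimental_results/pems04-12-12-70-0-SpoTMamba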

Training from Scratch

To train SpoT-Mamba from scratch, run run.py with the configuration file. Please refer to modules/experiment_handler.py, modules/data_handler.py, and models/models.py for examples of the arguments in the configuration file.

The configuration file accepts the following arguments:

{
    "setting": {
        "exp_name": "Name of the experiment.",
        "dataset": "The dataset to be used, e.g., 'pems04'.",
        "model": "The model type to be used, e.g., 'SpoTMamba'.",
        "in_steps": "Number of input time steps to use in the model.",
        "out_steps": "Number of output time steps (predictions) the model should generate.",
        "train_ratio": "Percentage of data to be used for training (expressed as an integer out of 100).",
        "val_ratio": "Percentage of data to be used for validation (expressed as an integer out of 100).",
        "seed": "Random seed for the reproducibility of results."
    },
    "hyperparameter": {
        "model": {
            "emb_dim": "Dimension of each embedding.",
            "ff_dim": "Dimension of the feedforward network within the model.",
            "num_walks": "Number of random walks to perform (M).",
            "len_walk": "Length of each random walk (K).",
            "num_layers": "Number of Mamba blocks / Number of layers in the Transformer encoder.",
            "dropout": "Dropout rate used in the model."
        },
        "training": {
            "lr_decay_rate": "Decay rate for learning rate.",
            "milestones": [
                "Epochs after which the learning rate will decay."
            ],
            "epochs": "Total number of training epochs.",
            "valid_epoch": "Number of epochs between each validation.",
            "patience": "Number of epochs to wait before early stopping if no progress on the validation set.",
            "batch_size": "Size of the batches used during training.",
            "lr": "Initial learning rate for training.",
            "weight_decay": "Weight decay rate used for regularization during training."
        }
    },
    "cuda_id": "CUDA device ID (GPU ID) to be used for training if available.",
    "force_retrain": "Flag to force the retraining of the model even if a trained model exists."
}
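
As a concrete reference, the following sketch writes out an example configuration. All values are illustrative placeholders chosen for this example, not recommended settings; see template.json for the actual defaults:

import json

# All values below are illustrative placeholders, not recommended settings.
config = {
    "setting": {
        "exp_name": "spotmamba_pems04_example",
        "dataset": "pems04",
        "model": "SpoTMamba",
        "in_steps": 12,
        "out_steps": 12,
        "train_ratio": 60,
        "val_ratio": 20,
        "seed": 0,
    },
    "hyperparameter": {
        "model": {
            "emb_dim": 32,
            "ff_dim": 256,
            "num_walks": 4,   # M
            "len_walk": 20,   # K
            "num_layers": 3,
            "dropout": 0.1,
        },
        "training": {
            "lr_decay_rate": 0.5,
            "milestones": [20, 40, 60],
            "epochs": 100,
            "valid_epoch": 1,
            "patience": 10,
            "batch_size": 16,
            "lr": 0.001,
            "weight_decay": 0.0005,
        },
    },
    "cuda_id": 0,
    "force_retrain": False,
}

with open("my_config.json", "w") as f:
    json.dump(config, f, indent=4)

# Then train with: python run.py --config_path=./my_config.json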

Hyperparameters

We tuned SpoT-Mamba with the following tuning ranges:

Description of each file

./

./datasets

./models

./modules

./trained_models

./utils