GST

This is the implementation for the paper

Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction

Zhe Huang, Ruohua Li, Kazuki Shin, Katherine Driggs-Campbell

published in IEEE Robotics and Automation Letters (RA-L).

[Paper] [arXiv] [Project]

GST stands for Gumbel Social Transformer, the name of our model. All code was developed and tested on Ubuntu 18.04 with CUDA 10.2, Python 3.6.9, and PyTorch 1.7.1.

Citation

If you find this repo useful, please cite:

@article{huang2022learning,
  title={Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction},
  author={Huang, Zhe and Li, Ruohua and Shin, Kazuki and Driggs-Campbell, Katherine},
  journal={IEEE Robotics and Automation Letters},
  year={2022},
  volume={7},
  number={2},
  pages={1198-1205},
  doi={10.1109/LRA.2021.3138547}
}

Setup

1. Create a Virtual Environment. (Optional)

virtualenv -p /usr/bin/python3 myenv
source myenv/bin/activate

2. Install Packages

You can run either

pip install -r requirements.txt

or

pip install numpy
pip install scipy
pip install matplotlib
pip install tensorboardX
pip install torch==1.7.1

To check training curves with tensorboard --logdir results, install TensorFlow by running

pip install tensorflow

3. Create Folders and Dataset Files.

sh run/make_dirs.sh
sh run/create_datasets.sh
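
Before training, it can help to confirm that the packages from step 2 are importable in the active environment. The following is a minimal sketch using only the Python standard library; the script and the helper name `missing_packages` are hypothetical and are not part of this repo.

```python
import importlib.util
import sys

# Import names of the packages listed in step 2 of this README.
REQUIRED = ["numpy", "scipy", "matplotlib", "tensorboardX", "torch"]


def missing_packages(names):
    """Return the top-level module names that cannot be located."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```

This only checks that each package can be found; it does not verify versions or CUDA availability.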

Training and Evaluation on Various Configurations

To train and evaluate a model with n=1, i.e., the target pedestrian pays attention to at most one partially observed pedestrian, run

sh run/train_sparse.sh
sh run/eval_sparse.sh

To train and evaluate a model with n=1 and a temporal convolution network as the temporal component, run

sh run/train_sparse_tcn.sh
sh run/eval_sparse_tcn.sh

To train and evaluate a model with full connection, i.e., the target pedestrian pays attention to all partially observed pedestrians in the scene, run

sh run/train_full_connection.sh
sh run/eval_full_connection.sh

To train and evaluate a model in which the target pedestrian pays attention to all fully observed pedestrians in the scene, run

sh run/train_full_connection_fully_observed.sh
sh run/eval_full_connection_fully_observed.sh
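
The configurations above differ in how many neighbors the target pedestrian may attend to. As a rough illustration only, the sketch below contrasts a sparse selection (at most n neighbors) with a full connection, using the Gumbel-max trick on a list of attention logits. It is not the repo's implementation: the function names, the pure-Python formulation, and the assumption that n independent draws are merged into one neighbor set are all hypothetical.

```python
import math
import random


def gumbel_noise(rng):
    """Draw one sample from the standard Gumbel(0, 1) distribution."""
    # Clamp away from 0 so that log(u) stays finite.
    u = max(rng.random(), 1e-12)
    return -math.log(-math.log(u))


def sample_sparse_neighbors(logits, n, rng):
    """Pick at most n neighbor indices via n independent Gumbel-max draws.

    Assumption for illustration: repeated draws may select the same
    neighbor, so the result holds between 1 and n distinct indices.
    """
    if not logits:
        return []
    chosen = set()
    for _ in range(n):
        noisy = [logit + gumbel_noise(rng) for logit in logits]
        chosen.add(max(range(len(noisy)), key=noisy.__getitem__))
    return sorted(chosen)


def full_connection(logits):
    """Attend to every neighbor in the scene."""
    return list(range(len(logits)))


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [2.0, 0.5, -1.0, 0.0]  # hypothetical scores for 4 neighbors
    print("sparse, n=1:", sample_sparse_neighbors(logits, 1, rng))
    print("full connection:", full_connection(logits))
```

Adding Gumbel noise to the logits and taking the argmax draws an index with probability proportional to the softmax of the logits, which is why a higher-scoring neighbor is selected more often but not always.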

Important Arguments for Building Customized Configurations

Credits

Part of the code is based on the following works and repos:

[1] Mohamed, Abduallah, et al. "Social-STGCNN: A social spatio-temporal graph convolutional neural network for human trajectory prediction." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020. [GitHub]

[2] PyTorch implementation of Multi-head Attention. [Modules] [Functional]

Contact

Please feel free to open an issue or send an email to zheh4@illinois.edu.