CRAT-Pred: Vehicle Trajectory Prediction with Crystal Graph Convolutional Neural Networks and Multi-Head Self-Attention

Official repository of the paper:
CRAT-Pred: Vehicle Trajectory Prediction with Crystal Graph Convolutional Neural Networks and Multi-Head Self-Attention
Julian Schmidt, Julian Jordan, Franz Gritschneder and Klaus Dietmayer
Accepted at 2022 IEEE International Conference on Robotics and Automation (ICRA)

Citation

If you use our source code, please cite:

@InProceedings{schmidt2022cratpred,
  author={Julian Schmidt and Julian Jordan and Franz Gritschneder and Klaus Dietmayer},
  booktitle={2022 IEEE International Conference on Robotics and Automation (ICRA)}, 
  title={CRAT-Pred: Vehicle Trajectory Prediction with Crystal Graph Convolutional Neural Networks and Multi-Head Self-Attention}, 
  year={2022},
  pages={7799--7805},}

License

<a rel="license" href="http://creativecommons.org/licenses/by-nc/4.0/"> <img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc/4.0/88x31.png" /></a><br />CRAT-Pred is licensed under <a rel="license" href="http://creativecommons.org/licenses/by-nc/4.0/" >Creative Commons Attribution-NonCommercial 4.0 International License</a>.

Check LICENSE for more information.

Installation

Install Anaconda

We recommend using Anaconda. The installation is described on the following page:
https://docs.anaconda.com/anaconda/install/linux/

Install Required Packages

conda env create -f environment.yml

Activate Environment

conda activate crat-pred

Install Argoverse API

pip install git+https://github.com/argoai/argoverse-api.git
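
After the installation steps above, a quick import check can confirm that the environment is usable before downloading the dataset. The following is only a minimal sketch: the helper name missing_modules is hypothetical, and the module list is an assumption (argoverse is the package installed by the Argoverse API; torch and pytorch_lightning are guessed from the lightning_logs/ folder mentioned below and may not match environment.yml exactly).

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    # find_spec returns None when a top-level module is not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed module list; adjust it to the contents of environment.yml.
    required = ["torch", "pytorch_lightning", "argoverse"]
    missing = missing_modules(required)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All listed modules were found.")
```

Run it inside the activated crat-pred environment; any module reported as missing points to an installation step that did not complete.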

Setup Argoverse Dataset

Download and Extract Dataset

bash fetch_dataset.sh

Preprocess the Dataset

Both online and offline preprocessing are implemented. To preprocess the dataset offline before training, run:

python3 preprocess.py

You can also skip this step and run the preprocessing online during training.

Train Model

python3 train.py

or, to train on the dataset preprocessed offline:

python3 train.py --use_preprocessed=True

Checkpoints are saved in the lightning_logs/ folder. To view metrics and losses in TensorBoard, first start the server:

tensorboard --logdir lightning_logs/

Then open http://localhost:6006/ in a browser to access TensorBoard.

Test Model on Validation Set

python3 test.py --weight=/path/to/checkpoint.ckpt
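
The checkpoint path passed to --weight has to be looked up in lightning_logs/ by hand. As a convenience, the sketch below searches that folder for the most recently modified .ckpt file. It is not part of the repository: the helper name latest_checkpoint is hypothetical, and it only assumes that checkpoints are stored somewhere below lightning_logs/ with a .ckpt extension.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(log_dir: str = "lightning_logs") -> Optional[Path]:
    """Return the most recently modified .ckpt file below log_dir, or None."""
    # Search recursively so the exact sub-folder layout does not matter.
    checkpoints = list(Path(log_dir).rglob("*.ckpt"))
    if not checkpoints:
        return None
    # Pick the file with the newest modification time.
    return max(checkpoints, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    ckpt = latest_checkpoint()
    print(ckpt if ckpt is not None else "No checkpoint found.")
```

The printed path can then be used as the value of --weight. Note that the newest checkpoint is not necessarily the best-performing one, so checking the validation metrics in TensorBoard is still advisable.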

Generate Predictions on Test Set

python3 test.py --weight=/path/to/checkpoint.ckpt --split=test