Awesome
CRAT-Pred: Vehicle Trajectory Prediction with Crystal Graph Convolutional Neural Networks and Multi-Head Self-Attention
Official repository of the paper:
CRAT-Pred: Vehicle Trajectory Prediction with Crystal Graph Convolutional Neural Networks and Multi-Head Self-Attention
Julian Schmidt, Julian Jordan, Franz Gritschneder and Klaus Dietmayer
Accepted at 2022 IEEE International Conference on Robotics and Automation (ICRA)
Citation
If you use our source code, please cite:
@InProceedings{schmidt2022cratpred,
author={Julian Schmidt and Julian Jordan and Franz Gritschneder and Klaus Dietmayer},
booktitle={2022 IEEE International Conference on Robotics and Automation (ICRA)},
title={CRAT-Pred: Vehicle Trajectory Prediction with Crystal Graph Convolutional Neural Networks and Multi-Head Self-Attention},
year={2022},
pages={7799--7805},}
License
<a rel="license" href="http://creativecommons.org/licenses/by-nc/4.0/"> <img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc/4.0/88x31.png" /></a><br />CRAT-Pred is licensed under <a rel="license" href="http://creativecommons.org/licenses/by-nc/4.0/" >Creative Commons Attribution-NonCommercial 4.0 International License</a>.Check LICENSE for more information.
Installation
Install Anaconda
We recommend using Anaconda.
The installation is described on the following page:
https://docs.anaconda.com/anaconda/install/linux/
Install Required Packages
conda env create -f environment.yml
Activate Environment
conda activate crat-pred
Install Argoverse API
pip install git+https://github.com/argoai/argoverse-api.git
Setup Argoverse Dataset
Download and Extract Dataset
bash fetch_dataset.sh
Preprocess the Dataset
Online and offline preprocessing is implemented. If you want to train your model offline on the preprocessed dataset, run:
python3 preprocess.py
You can also skip this step and run the preprocessing online during training.
Train Model
python3 train.py
or
python3 train.py --use_preprocessed=True
Checkpoints are saved in the lightning_logs/
folder.
For accessing metrics and losses via Tensorboard, first start the server:
tensorboard --logdir lightning_logs/
Navigating to http://localhost:6006/ opens Tensorboard.
Test Model on Validation Set
python3 test.py --weight=/path/to/checkpoint.ckpt
Generate Predictions on Test Set
python3 test.py --weight=/path/to/checkpoint.ckpt --split=test