DRIVE (Deep ReInforced Accident Anticipation with Visual Explanation)
Project | Paper & Supp | Demo
International Conference on Computer Vision (ICCV), 2021.
Introduction
We propose the DRIVE model, which uses Deep Reinforcement Learning (DRL) to solve the explainable traffic accident anticipation problem. The model simulates both bottom-up and top-down visual attention mechanisms in a dashcam observation environment, so that the decisions of the proposed stochastic multi-task agent can be visually explained by attentive regions. Moreover, the proposed dense anticipation reward and sparse fixation reward are effective for training the DRIVE model with our improved Soft Actor-Critic (SAC) DRL algorithm.
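For readers new to SAC, the sketch below shows the soft (entropy-regularized) Bellman target of the standard Soft Actor-Critic algorithm, using two target critics and a temperature `alpha`. It does not reproduce the improvements made in DRIVE, and the function name `soft_q_target` and the default values of `gamma` and `alpha` are illustrative assumptions rather than this repo's settings.

```python
def soft_q_target(
    reward: float,
    done: bool,
    next_q1: float,
    next_q2: float,
    next_log_prob: float,
    gamma: float = 0.99,  # illustrative default, not DRIVE's hyperparameter
    alpha: float = 0.2,   # illustrative default, not DRIVE's hyperparameter
) -> float:
    """Soft Bellman target of standard SAC.

    y = r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s'))
    where Q1', Q2' are the two target critics evaluated at the next state and
    a' is an action sampled from the current policy at that state.
    """
    # Taking the minimum of two critics reduces overestimation bias.
    soft_value = min(next_q1, next_q2) - alpha * next_log_prob
    # Terminal transitions do not bootstrap from the next state.
    return reward + gamma * (0.0 if done else 1.0) * soft_value
```

The entropy term `-alpha * log pi` rewards the stochastic policy for keeping its action distribution broad, which is what makes the agent "soft".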
Installation
Note: This repo was developed with PyTorch 1.4.0 on Ubuntu 18.04 LTS in a CUDA 10.1 GPU environment. More recent PyTorch and CUDA versions, such as PyTorch 1.7.1 and CUDA 11.3, are also compatible with this repo.
a. Create a conda virtual environment for this repo and activate it:
conda create -n pyRL python=3.7 -y
conda activate pyRL
b. Install the official PyTorch package. Taking pytorch==1.4.0 as an example:
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
c. Install the remaining dependencies:
pip install -r requirements.txt
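After installation, a quick sanity check of the installed PyTorch build and GPU visibility can save debugging time later. This is a hypothetical helper, not shipped with the repo; treating 1.4.0 as the minimum version is an assumption based on the note above. Versions are compared as integer tuples because plain string comparison would wrongly rank "1.10.0" below "1.4.0".

```python
import importlib


def parse_version(version: str) -> tuple:
    """Turn e.g. '1.7.1+cu110' into (1, 7, 1); anything after '+' is ignored."""
    core = version.split("+")[0]
    parts = []
    for piece in core.split("."):
        # Keep only the leading digits, so '0a0' (pre-release) becomes 0.
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def check_torch(min_version: str = "1.4.0") -> str:
    """Report the installed torch version and whether CUDA is visible."""
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return "torch is not installed"
    ok = parse_version(torch.__version__) >= parse_version(min_version)
    cuda = torch.cuda.is_available()
    return f"torch {torch.__version__} (>= {min_version}: {ok}), CUDA available: {cuda}"
```

Run `print(check_torch())` inside the activated pyRL environment.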
Datasets
This repo currently supports the down-sized version of the DADA-2000 dataset. Specifically, we reduced the image size by half and trimmed the videos into accident clips of at most 450 frames. For more details, please refer to the script data/reduce_data.py.
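The two reductions described above can be sketched as follows. This is not the logic of data/reduce_data.py: `half_size`, `trim_clip`, and in particular the policy of keeping the last 450 frames before an end frame are hypothetical choices for illustration. Actual image resizing would be done with an image library; the sketch only computes target sizes and the frame indices to keep.

```python
from typing import Optional, Tuple

MAX_FRAMES = 450  # upper bound on clip length stated in this README


def half_size(width: int, height: int) -> Tuple[int, int]:
    """Target (width, height) after halving the resolution."""
    return width // 2, height // 2


def trim_clip(num_frames: int, end_frame: Optional[int] = None,
              max_frames: int = MAX_FRAMES) -> range:
    """Return the indices of the frames kept for one clip.

    Hypothetical policy: keep at most `max_frames` consecutive frames ending
    at `end_frame` (exclusive), which defaults to the end of the video.
    """
    end = num_frames if end_frame is None else min(end_frame, num_frames)
    start = max(0, end - max_frames)
    return range(start, end)
```

For example, `trim_clip(1000, end_frame=600)` keeps frames 150 through 599.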
Since the original DADA-2000 dataset has been updated, we provide the processed DADA-2000-small.zip for your convenience. Simply download it and unzip it into the data folder:
cd data
unzip DADA-2000-small.zip -d ./
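On machines without the unzip command, Python's standard zipfile module can perform the same extraction. A minimal sketch, assuming it is run from inside the data folder; the helper name `unzip_dataset` is hypothetical, and no particular folder layout inside the archive is assumed.

```python
import zipfile


def unzip_dataset(archive: str = "DADA-2000-small.zip", dest: str = ".") -> list:
    """Extract `archive` into `dest` and return its sorted top-level entries."""
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)  # creates `dest` and any sub-folders as needed
        top_level = {name.split("/")[0] for name in zf.namelist() if name}
    return sorted(top_level)
```

The returned list shows which top-level folders the archive unpacked into, which is a quick way to confirm the extraction landed where expected.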
Testing
a. Download the pre-trained saliency models.
The pre-trained saliency models are provided here: saliency_models, where mlnet_25.pth is used by default in this repo. Please place the file at models/saliency/mlnet_25.pth.
b. Download the pre-trained DRIVE model:
The pre-trained DRIVE model is provided here: DADA2KS_Full_SACAE_Final. Please place the model file at output/DADA2KS_Full_SACAE_Final/checkpoints/sac_epoch_50.pt.
c. Run the DRIVE testing.
bash script_RL.sh test 0 4 DADA2KS_Full_SACAE_Final
After a while, the results will be reported.
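If the test script cannot find a model file, it helps to confirm that both pre-trained files sit at the paths given in steps a and b. The sketch below is a hypothetical helper (`missing_files` and `REQUIRED_FILES` are not part of the repo) and assumes it is run from the repository root, which is also assumed to be where script_RL.sh is invoked.

```python
from pathlib import Path

# Paths stated in the Testing steps of this README, relative to the repo root.
REQUIRED_FILES = [
    "models/saliency/mlnet_25.pth",
    "output/DADA2KS_Full_SACAE_Final/checkpoints/sac_epoch_50.pt",
]


def missing_files(repo_root: str = ".") -> list:
    """Return the required files that do not exist under `repo_root`."""
    root = Path(repo_root)
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_files()
    print("All pre-trained files found." if not missing else f"Missing: {missing}")
```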
Training
This repo supports training DRIVE models with two DRL algorithms, i.e., REINFORCE and SAC, and two kinds of visual saliency features, i.e., MLNet and TASED-Net. By default, we use SAC + MLNet, which gives the best trade-off between speed and accuracy.
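As background on the first of the two algorithms, the sketch below computes the textbook REINFORCE (Monte Carlo policy-gradient) loss for one episode. It is a generic illustration, not this repo's REINFORCE code, which may add baselines or other variance-reduction terms; the function names and the default `gamma` are assumptions. The same code also works when `log_probs` are PyTorch tensors, in which case calling `backward()` on the result yields the policy gradient.

```python
from typing import List


def discounted_returns(rewards: List[float], gamma: float = 0.99) -> List[float]:
    """G_t = r_t + gamma * G_{t+1}, computed backwards over one episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def reinforce_loss(log_probs: List[float], rewards: List[float],
                   gamma: float = 0.99) -> float:
    """Loss -sum_t log pi(a_t|s_t) * G_t; minimizing it ascends expected return."""
    returns = discounted_returns(rewards, gamma)
    return -sum(lp * g for lp, g in zip(log_probs, returns))
```

For example, `discounted_returns([1.0, 0.0, 1.0], gamma=0.5)` returns `[1.25, 0.5, 1.0]`.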
a. Download the pre-trained saliency models.
The pre-trained saliency models are provided here: saliency_models, where mlnet_25.pth is used by default in this repo. Please place the file at models/saliency/mlnet_25.pth.
b. Run the DRIVE training.
bash script_RL.sh train 0 4 DADA2KS_Full_SACAE_Final
c. Monitor the training on TensorBoard.
Visualize the training curves (losses, accuracies, etc.) on TensorBoard with the following commands:
cd output/DADA2KS_Full_SACAE_Final/tensorboard
tensorboard --logdir=./ --port 6008
TensorBoard will then print the URL http://localhost:6008. Open this address in your web browser (e.g., Chrome) to monitor the training status.
Tips:
If you are connected via SSH to a remote server without a display, you can view TensorBoard on your local machine by forwarding the SSH port:
ssh -L 16008:localhost:6008 {your_remote_name}@{your_remote_ip}
Then, you can monitor TensorBoard through local port 16008 by opening http://localhost:16008 in your browser.
Citation
If you find the code useful in your research, please cite:
@inproceedings{BaoICCV2021DRIVE,
author = "Bao, Wentao and Yu, Qi and Kong, Yu",
title = "Deep Reinforced Accident Anticipation with Visual Explanation",
booktitle = "International Conference on Computer Vision (ICCV)",
year = "2021"
}
License
See MIT License
Acknowledgement
We sincerely thank all of the following great repos: pytorch-soft-actor-critic, pytorch-REINFORCE, MLNet-Pytorch, and TASED-Net.