# Rate-based Backpropagation
This repository hosts the code implementation of rate-based backpropagation, as described in the paper "Advancing Training Efficiency of Deep Spiking Neural Networks through Rate-based Backpropagation", accepted at NeurIPS 2024. [arXiv][OpenReview]
<img src="doc/figure/fig1.png" alt="introduction_figure" style="zoom:100%;" />

## Dependencies
| Name         | Version    |
| ------------ | ---------- |
| python       | 3.9.19     |
| torch        | 2.3.1      |
| torchvision  | 0.18.1     |
| tensorboard  | 2.17.0     |
| spikingjelly | 0.0.0.0.14 |
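For a quick environment setup, the pinned dependencies above can be installed via pip (a minimal sketch; the repository does not specify install commands, and PyTorch may require a platform-specific index URL):

```bash
# Assumed install command matching the version table above
pip install torch==2.3.1 torchvision==0.18.1 tensorboard==2.17.0 spikingjelly==0.0.0.0.14
```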
## Directory Tree
```
.
├── experiment
│   ├── cifar
│   │   ├── config
│   │   └── main.py
│   └── dvs
│       ├── config
│       └── main.py
├── model
│   ├── layer.py
│   ├── resnet.py
│   └── vgg.py
└── util
    ├── data.py
    ├── image_augment.py
    ├── misc.py
    └── util.py
```
The experiment code for each dataset is located in the corresponding directory under `experiment` (CIFAR-10/CIFAR-100 in `experiment/cifar`, CIFAR10-DVS in `experiment/dvs`).

Code related to neurons is defined in `model/layer.py`; code related to batch normalization is defined in `util/util.py`.

The computational graph using rate-based gradients is implemented via model hooks, which are encapsulated in `util/util.py`.
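As a rough illustration of the hook-based mechanism (a minimal sketch only; `attach_rate_hook` and the tensor shapes are assumptions, and the repository's actual logic lives in `util/util.py`), a forward hook can register a tensor hook on a spiking layer's output so that the backward pass averages gradients over the time dimension and routes that rate-level gradient back to every timestep:

```python
import torch
import torch.nn as nn

def attach_rate_hook(module: nn.Module) -> None:
    """Sketch: make gradients flow through firing rates rather than the
    fully unrolled spike train. Hypothetical helper, not the repo's API."""
    def forward_hook(mod, inputs, output):
        # `output` is assumed to be a tensor shaped [T, B, ...] (multi-step).
        if output.requires_grad:
            # Tensor hook: average incoming gradients over time, then
            # broadcast the rate-level gradient back to each timestep.
            output.register_hook(
                lambda g: g.mean(0, keepdim=True).expand_as(g)
            )
    module.register_forward_hook(forward_hook)
```

This captures the intuition of replacing BPTT's unrolled temporal credit assignment with rate-based gradients; the actual hooks in `util/util.py` handle further details such as single-step mode.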
## Usage
- **Reproducing BPTT Results.** To reproduce the results using backpropagation through time (BPTT) within the current project framework, use the following commands:

  ```bash
  ##### BPTT with multi-step #####
  # for CIFAR-10/100
  python experiment/cifar/main.py --dataset CIFAR10 --data_path [data_path] --arch resnet18 --T 4 --step_mode m
  # for CIFAR10-DVS
  python experiment/dvs/main.py --dataset CIFAR10_DVS_Aug --data_path [data_path] --arch vggsnn_dvs --step_mode m

  ##### BPTT with single-step #####
  # for CIFAR-10/100
  python experiment/cifar/main.py --dataset CIFAR10 --data_path [data_path] --arch resnet18 --T 4 --step_mode s
  # for CIFAR10-DVS
  python experiment/dvs/main.py --dataset CIFAR10_DVS_Aug --data_path [data_path] --arch vggsnn_dvs --step_mode s
  ```
- **Using Rate-Based Backpropagation.** To enable rate-based backpropagation as a replacement for BPTT, use the following commands:

  ```bash
  ##### Rate-BP with multi-step #####
  # for CIFAR-10/100
  python experiment/cifar/main.py --dataset CIFAR10 --data_path [data_path] --arch resnet18 --T 4 --step_mode m --rate_flag
  # for CIFAR10-DVS
  python experiment/dvs/main.py --dataset CIFAR10_DVS_Aug --data_path [data_path] --arch vggsnn_dvs --step_mode m --rate_flag

  ##### Rate-BP with single-step #####
  # for CIFAR-10/100
  python experiment/cifar/main.py --dataset CIFAR10 --data_path [data_path] --arch resnet18 --T 4 --step_mode s --rate_flag
  # for CIFAR10-DVS
  python experiment/dvs/main.py --dataset CIFAR10_DVS_Aug --data_path [data_path] --arch vggsnn_dvs --step_mode s --rate_flag
  ```
- **Options for Hyper-Parameters:**
  - `--arch`: the SNN model to use; supported values: `resnet18`, `resnet19`, `vggsnn_cifar`, `vggsnn_dvs`.
  - `--T`: the number of timesteps for the SNN model. Fixed at 10 on CIFAR10-DVS.
  - `--step_mode`: the training mode (illustrated in the sketch after this list).
    - `m`: multi-step training mode, where the T loops are embedded within the layers.
    - `s`: single-step training mode, where the T loops are executed outside the layers.
  - `--rate_flag`: whether to use rate-based backpropagation instead of BPTT.
    - Include this flag to enable Rate-BP.
    - Omit this flag to use standard backpropagation (BPTT).
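The following sketch contrasts the two step modes (hypothetical helper names; layers are treated as stateless callables for brevity, whereas real spiking layers also carry membrane state across timesteps):

```python
import torch

def forward_multi_step(layers, x_seq):
    # step_mode 'm': each layer consumes the whole sequence [T, B, ...],
    # i.e. the T loop runs inside the layer (made explicit here).
    for layer in layers:
        x_seq = torch.stack([layer(x_t) for x_t in x_seq])
    return x_seq

def forward_single_step(layers, x_seq):
    # step_mode 's': the T loop wraps the whole network, so each layer
    # processes one timestep at a time.
    outputs = []
    for x_t in x_seq:
        for layer in layers:
            x_t = layer(x_t)
        outputs.append(x_t)
    return torch.stack(outputs)
```

The two modes are intended to compute the same results; they differ mainly in speed and memory footprint.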
## Citation
If you find this work useful for your research, please consider citing it as follows:
```bibtex
@inproceedings{yu2024advancing,
  title={Advancing Training Efficiency of Deep Spiking Neural Networks through Rate-based Backpropagation},
  author={Yu, Chengting and Liu, Lei and Wang, Gaoang and Li, Erping and Wang, Aili},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024}
}
```
## Acknowledgement
Some of our code implementations for the SNN models and data preprocessing are based on, and adapted from, the following repositories: SpikingJelly, OTTT, SLTT, ASGL.