Parametric-Leaky-Integrate-and-Fire-Spiking-Neuron

Chinese README

This repository contains the original code and TensorBoard logs for the paper Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks. The trained models are too large to upload to this repository. However, we used an identical seed during training, so users should obtain almost the same accuracy when training with our code.

Accuracy

This table shows the accuracy obtained with PLIF neurons, tau_0=2 and max pooling:

| | MNIST | Fashion-MNIST | CIFAR10 | N-MNIST | CIFAR10-DVS | DVS128 Gesture |
|---|---|---|---|---|---|---|
| accuracy-A | 99.72% | 94.38% | 93.50% | 99.61% | 74.80% | 97.57% |
| accuracy-B | 99.63% | 93.85% | 92.58% | 99.57% | 69.00% | 96.53% |

This table shows the accuracy-A obtained with PLIF/LIF neurons, different tau/tau_0 values and average/max pooling:

| neuron | pooling | MNIST | Fashion-MNIST | CIFAR-10 | N-MNIST | CIFAR10-DVS | DVS128 Gesture |
|---|---|---|---|---|---|---|---|
| PLIF, tau_0=2 | max | 99.72% | 94.38% | 93.5% | 99.61% | 74.8% | 97.57% |
| PLIF, tau_0=16 | max | 99.73% | 94.65% | 93.23% | 99.53% | 70.5% | 92.01% |
| LIF, tau=2 | max | 99.69% | 94.17% | 93.03% | 99.64% | 73.6% | 96.88% |
| LIF, tau=16 | max | 99.49% | 94.47% | 47.5% | 99.15% | 62.4% | 76.74% |
| PLIF, tau_0=2 | avg | 99.71% | 94.74% | 93.3% | 99.66% | 72.7% | 97.22% |
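For orientation, the role of tau and tau_0 in the tables above can be illustrated with a minimal pure-Python sketch of a single neuron update. It assumes the formulation described in the paper, where 1/tau = sigmoid(a) for a learnable parameter a that is initialised as a = -ln(tau_0 - 1). The function names, the threshold of 1.0, the reset voltage of 0.0 and the constant input are illustrative assumptions, not the code in models.py.

```python
import math


def init_a(tau_0: float) -> float:
    """Initial value of the learnable parameter a, chosen so that 1 / sigmoid(a) == tau_0."""
    return -math.log(tau_0 - 1.0)


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def plif_step(v, x, a, v_threshold=1.0, v_reset=0.0):
    """One time-step of a single neuron: charge, fire, hard reset.

    Returns (spike, new_membrane_potential).
    """
    k = sigmoid(a)                    # k = 1 / tau; learnable in PLIF, fixed in LIF
    h = v + k * (x - (v - v_reset))   # charging
    spike = 1 if h >= v_threshold else 0
    v_new = v_reset if spike else h   # hard reset after a spike
    return spike, v_new


a = init_a(2.0)   # tau_0 = 2 gives a = 0, so 1 / tau = 0.5
v = 0.0
spikes = []
for x in [1.5, 1.5, 1.5, 1.5]:  # hypothetical constant input current
    s, v = plif_step(v, x, a)
    spikes.append(s)
print(spikes)
```

In this sketch a plain LIF neuron corresponds to keeping a fixed, while a PLIF neuron would update a by gradient descent together with the network weights.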

Directory Structure

codes contains the original code:

models.py defines the networks.

train.py alternately trains models on the training set and tests them on the test set, and records the maximum test accuracy, which is accuracy-A in the paper.

train_val.py splits the original training set into a new training set and a validation set, alternately trains on the new training set and evaluates on the validation set, and reports the accuracy on the test set only once, using the model that achieved the maximum validation accuracy. This is accuracy-B in the paper.

logs contains the A and B directories, which hold the TensorBoard logs for accuracy-A and accuracy-B, respectively.
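The difference between the two reported accuracies can be made concrete with a short sketch. The function names and the accuracy values below are hypothetical and only illustrate the selection rules described above. Note that train_val.py evaluates on the test set only once, so the per-epoch test list is used here purely for comparison.

```python
def accuracy_a(test_acc):
    """accuracy-A: the maximum test accuracy over all epochs."""
    return max(test_acc)


def accuracy_b(val_acc, test_acc):
    """accuracy-B: the test accuracy of the model with the best validation accuracy."""
    best_epoch = max(range(len(val_acc)), key=val_acc.__getitem__)
    return test_acc[best_epoch]


# Hypothetical per-epoch accuracies, not taken from the logs in this repository.
val_acc = [0.90, 0.95, 0.93]
test_acc = [0.89, 0.92, 0.94]

print(accuracy_a(test_acc))           # best test epoch, selected on the test set itself
print(accuracy_b(val_acc, test_acc))  # test accuracy at the best validation epoch
```

Because accuracy-B never uses the test set for model selection, it cannot exceed accuracy-A when both are computed on the same training run.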

Dependency

The original code uses an old version of SpikingJelly. To maximize reproducibility, clone the latest SpikingJelly and roll back to the version that we used for training:

git clone https://github.com/fangwei123456/spikingjelly.git
cd spikingjelly
git reset --hard 73f94ab983d0167623015537f7d4460b064cfca1
python setup.py install

Here is the commit information:

commit 73f94ab983d0167623015537f7d4460b064cfca1
Author: fangwei123456 <fangwei123456@pku.edu.cn>
Date:   Wed Sep 30 16:42:25 2020 +0800

    增加detach reset的选项 (Add the detach reset option)

Datasets

Line 64 of train.py and line 84 of train_val.py define the dataset path:

dataset_dir = '/userhome/datasets/' + dataset_name

where /userhome/datasets/ is the root path of all datasets.

The root path of all datasets should have the following directory structure:

|-- CIFAR10
|   |-- cifar-10-batches-py
|   `-- cifar-10-python.tar.gz
|-- CIFAR10DVS
|   |-- airplane.zip
|   |-- automobile.zip
|   |-- bird.zip
|   |-- cat.zip
|   |-- deer.zip
|   |-- dog.zip
|   |-- events
|   |-- frames_num_20_split_by_number_normalization_None
|   |-- frog.zip
|   |-- horse.zip
|   |-- ship.zip
|   `-- truck.zip
|-- DVS128Gesture
|   |-- DvsGesture.tar.gz
|   |-- LICENSE.txt
|   |-- README.txt
|   |-- events_npy
|   |-- extracted
|   |-- frames_num_20_split_by_number_normalization_None
|   `-- gesture_mapping.csv
|-- FashionMNIST
|   |-- FashionMNIST
|-- MNIST
|   `-- MNIST
`-- NMNIST
    |-- Test.zip
    |-- Train.zip
    |-- events
    `-- frames_num_10_split_by_number_normalization_None
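Before a long training run it can be useful to confirm that the root path contains one directory per dataset. The following is a minimal sketch that only checks the six top-level directory names shown in the tree above; the helper name missing_datasets is hypothetical and is not part of this repository.

```python
from pathlib import Path

# Top-level dataset directories, as listed in the tree above.
DATASET_NAMES = [
    "CIFAR10",
    "CIFAR10DVS",
    "DVS128Gesture",
    "FashionMNIST",
    "MNIST",
    "NMNIST",
]


def missing_datasets(root):
    """Return the dataset directory names that do not exist under root."""
    root = Path(root)
    return [name for name in DATASET_NAMES if not (root / name).is_dir()]


if __name__ == "__main__":
    # '/userhome/datasets/' is the root path used in train.py and train_val.py.
    print(missing_datasets("/userhome/datasets/"))
```

An empty list means every expected top-level directory is present. The contents of each directory are not inspected.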

The MNIST, Fashion-MNIST and CIFAR10 datasets are available from torchvision. For the installation of the neuromorphic datasets, see

https://spikingjelly.readthedocs.io/zh_CN/0.0.0.0.4/spikingjelly.datasets.html

Running codes

Here are the original commands used to obtain accuracy-B:

| Dataset | Command |
|---|---|
| MNIST | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -use_plif -device cuda:0 -dataset_name MNIST -log_dir_prefix /userhome/plif_test/logsd -T 8 -max_epoch 1024 -detach_reset |
| Fashion-MNIST | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -use_plif -device cuda:0 -dataset_name FashionMNIST -log_dir_prefix /userhome/plif_test/logsd -T 8 -max_epoch 1024 -detach_reset |
| CIFAR10 | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -use_plif -device cuda:0 -dataset_name CIFAR10 -log_dir_prefix /userhome/plif_test/logsd -T 8 -max_epoch 1024 -detach_reset |
| N-MNIST | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -device cuda:0 -dataset_name NMNIST -log_dir_prefix /userhome/plif_test/logsd -T 10 -max_epoch 1024 -detach_reset -channels 128 -number_layer 2 -split_by number -normalization None -use_plif |
| CIFAR10-DVS | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -device cuda:0 -dataset_name CIFAR10DVS -log_dir_prefix /userhome/plif_test/logsd -T 20 -max_epoch 1024 -detach_reset -channels 128 -number_layer 4 -split_by number -normalization None -use_plif |
| DVS128 Gesture | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -device cuda:0 -dataset_name DVS128Gesture -log_dir_prefix /userhome/plif_test/logsd -T 20 -max_epoch 1024 -detach_reset -channels 128 -number_layer 5 -split_by number -normalization None -use_plif |

The code can resume training after an interruption: it loads the existing model and continues training from the last epoch.
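The resume behaviour can be pictured with the following control-flow sketch. It is not the implementation in train.py or train_val.py: the JSON state file and the names load_state, save_state and train are assumptions made for illustration, and a real checkpoint would also need to store the model and optimizer parameters.

```python
import json
import os


def load_state(path):
    """Load the saved training state, or start from scratch if none exists."""
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    return {"epoch": 0, "max_test_accuracy": 0.0}


def save_state(path, state):
    """Write the state to a temporary file first, so an interruption cannot corrupt it."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(state, f)
    os.replace(tmp, path)


def train(path, max_epoch, run_epoch):
    """Run epochs from the last saved epoch up to max_epoch.

    run_epoch(epoch) is a caller-supplied function that trains for one epoch
    and returns the test accuracy.
    """
    state = load_state(path)
    for epoch in range(state["epoch"], max_epoch):
        acc = run_epoch(epoch)
        state["max_test_accuracy"] = max(state["max_test_accuracy"], acc)
        state["epoch"] = epoch + 1
        save_state(path, state)
    return state
```

Calling train a second time with the same path skips the epochs that were already completed, which mirrors the behaviour described above.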

Arguments Definition

This table shows the definition of all arguments:

| argument | meaning | type | default |
|---|---|---|---|
| init_tau | tau of all LIF neurons, or tau_0 of PLIF neurons | float | - |
| batch_size | training batch size | int | 16 |
| learning_rate | learning rate | float | 1e-3 |
| T_max | period of the learning rate schedule | int | 64 |
| use_plif | if given, use PLIF neurons | action='store_true' | False |
| alpha_learnable | if given, alpha in the surrogate function is learnable | action='store_true' | False |
| use_max_pool | if given, the network uses max pooling, otherwise average pooling | action='store_true' | False |
| device | device used for training | str | - |
| dataset_name | dataset to use | str (MNIST, FashionMNIST, CIFAR10, NMNIST, CIFAR10DVS or DVS128Gesture) | - |
| log_dir_prefix | path where TensorBoard saves the logs | str | - |
| T | number of simulation time-steps | int | - |
| channels | number of output channels of the Conv2d layers for neuromorphic datasets | int | - |
| number_layer | number of Conv2d layers for neuromorphic datasets | int | - |
| split_by | how to split the events when integrating them into frames | str (time or number) | - |
| normalization | normalization applied to the frames during integration | str (frequency, max, norm, sum or None) | - |
| max_epoch | maximum number of training epochs | int | - |
| detach_reset | if given, detach the voltage reset during the backward pass | action='store_true' | False |
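The argument table can be read as an argparse definition. The sketch below is reconstructed from the table only and may differ from the actual parser in train.py and train_val.py (for example in help strings or in which arguments are required). Arguments listed without a default are left as None here, which is an assumption.

```python
import argparse


def build_parser():
    """Build a parser that mirrors the argument table (a sketch, not the repository's parser)."""
    p = argparse.ArgumentParser()
    p.add_argument("-init_tau", type=float)
    p.add_argument("-batch_size", type=int, default=16)
    p.add_argument("-learning_rate", type=float, default=1e-3)
    p.add_argument("-T_max", type=int, default=64)
    p.add_argument("-use_plif", action="store_true")
    p.add_argument("-alpha_learnable", action="store_true")
    p.add_argument("-use_max_pool", action="store_true")
    p.add_argument("-device", type=str)
    p.add_argument("-dataset_name", type=str)
    p.add_argument("-log_dir_prefix", type=str)
    p.add_argument("-T", type=int)
    p.add_argument("-channels", type=int)
    p.add_argument("-number_layer", type=int)
    p.add_argument("-split_by", type=str)
    p.add_argument("-normalization", type=str)
    p.add_argument("-max_epoch", type=int)
    p.add_argument("-detach_reset", action="store_true")
    return p


# Parse the MNIST command from the "Running codes" section.
mnist_cmd = (
    "-init_tau 2.0 -use_max_pool -use_plif -device cuda:0 -dataset_name MNIST "
    "-log_dir_prefix /userhome/plif_test/logsd -T 8 -max_epoch 1024 -detach_reset"
)
args = build_parser().parse_args(mnist_cmd.split())
print(args.init_tau, args.T, args.batch_size, args.use_plif)
```

With this sketch, -normalization None is parsed as the string 'None', so any conversion to a Python None would have to happen after parsing.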

For more details about split_by and normalization, see

https://spikingjelly.readthedocs.io/zh_CN/0.0.0.0.4/spikingjelly.datasets.html#integrate-events-to-frames-init-en

New Implementation

SpikingJelly (0.0.0.0.12 or the latest version) includes the network with LIF/max-pooling as an example:

0.0.0.0.12: https://spikingjelly.readthedocs.io/zh_CN/0.0.0.0.12/clock_driven_en/14_classify_dvsg.html

latest: https://spikingjelly.readthedocs.io/zh_CN/latest/activation_based_en/classify_dvsg.html

This code is written for the new version of SpikingJelly and runs faster than the code in this repository.

All networks in this paper are available in SpikingJelly:

0.0.0.0.12: https://github.com/fangwei123456/spikingjelly/blob/0.0.0.0.12/spikingjelly/clock_driven/model/parametric_lif_net.py

latest: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/model/parametric_lif_net.py

Cite

@InProceedings{Fang_2021_ICCV,
    author    = {Fang, Wei and Yu, Zhaofei and Chen, Yanqi and Masquelier, Timothee and Huang, Tiejun and Tian, Yonghong},
    title     = {Incorporating Learnable Membrane Time Constant To Enhance Learning of Spiking Neural Networks},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {2661-2671}
}