Pruning of Deep Spiking Neural Networks through Gradient Rewiring

This directory contains the code for the paper. The pretrained model is too large to fit within the 50 MB limit for supplementary files, so it is not included. Nonetheless, we have made every effort to keep the results reproducible by fixing the random seeds used in our experiments and by documenting the dependencies and environment.

Directory Tree

.
├── c10
│   ├── c10.py
│   ├── __init__.py
│   └── model.py
├── deeprewire.py
├── gradrewire.py
├── mnist
│   ├── __init__.py
│   ├── mnist.py
│   └── model.py
└── README.md

The training (and test) code and the model definitions for CIFAR-10 and MNIST are located in two separate directories, c10 and mnist, respectively. The proposed Gradient Rewiring (Grad Rewiring) algorithm is integrated with the Adam optimizer as a PyTorch optimizer in gradrewire.py. The code for the Deep Rewiring algorithm (Deep R) is organized in the same way, in deeprewire.py.
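The following sketch illustrates how a rewiring rule can be folded into an Adam update. It follows our reading of the paper's reparameterization. Each weight is w = s·max(θ, 0), where s is a fixed sign and θ is a hidden parameter. The gradient with respect to θ is taken as s·∂L/∂w even where θ ≤ 0, so pruned connections keep receiving gradient and can regrow.

This is a standalone, plain-Python illustration on lists, not the code in gradrewire.py. The names `GradRewiringAdam`, `effective_weights` and `sparsity` are hypothetical. The prior term is also a hypothetical stand-in, penalty·sign(θ − mu). The paper's actual prior, and how the target sparsity `-s` enters it, are not reproduced here. The defaults for `lr` and `penalty` mirror the `-lr` and `-penalty` defaults in the Running Arguments table.

```python
import math
from typing import List, Sequence


def effective_weights(theta: Sequence[float], sign: Sequence[float]) -> List[float]:
    """w = s * max(theta, 0); a connection with theta <= 0 is pruned (w == 0)."""
    return [s * max(t, 0.0) for t, s in zip(theta, sign)]


def sparsity(theta: Sequence[float]) -> float:
    """Fraction of connections currently pruned (theta <= 0)."""
    return sum(1 for t in theta if t <= 0.0) / len(theta)


class GradRewiringAdam:
    """Illustrative Adam update on the hidden parameters theta (not gradrewire.py).

    The gradient w.r.t. theta is taken as sign * dL/dw for every connection,
    i.e. the zero derivative of max(theta, 0) is bypassed, so pruned
    connections keep receiving gradient and can grow back.
    `penalty` and `mu` define a HYPOTHETICAL Laplace-style prior term.
    """

    def __init__(self, theta, sign, lr=1e-4, penalty=1e-3, mu=0.0,
                 betas=(0.9, 0.999), eps=1e-8):
        self.theta = list(theta)
        self.sign = list(sign)
        self.lr, self.penalty, self.mu, self.eps = lr, penalty, mu, eps
        self.b1, self.b2 = betas
        self.m = [0.0] * len(self.theta)  # first-moment (mean) estimates
        self.v = [0.0] * len(self.theta)  # second-moment estimates
        self.t = 0                        # step count for bias correction

    def step(self, grad_w: Sequence[float]) -> None:
        """Apply one update given dL/dw for each effective weight."""
        self.t += 1
        for i, (g_w, s) in enumerate(zip(grad_w, self.sign)):
            d = self.theta[i] - self.mu
            prior = self.penalty * ((d > 0) - (d < 0))  # penalty * sign(theta - mu)
            g = s * g_w + prior  # straight-through gradient + prior term
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
            m_hat = self.m[i] / (1 - self.b1 ** self.t)
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            self.theta[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


if __name__ == "__main__":
    # Connection 1 starts pruned (theta < 0); its loss gradient favors a larger |w|,
    # and because the gradient is not blocked by max(theta, 0) it regrows.
    opt = GradRewiringAdam([0.5, -0.2], [1.0, -1.0], lr=0.1, penalty=0.0)
    print(sparsity(opt.theta))  # 0.5
    for _ in range(3):
        opt.step([0.0, 1.0])
    print(sparsity(opt.theta), effective_weights(opt.theta, opt.sign))
```

In the demo, the pruned connection's θ climbs above zero within a few steps. With the ordinary ReLU derivative, its gradient would be zero and it would stay pruned forever.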

Dependency

The major dependencies of this repo are listed below:

# Name                    Version
cudatoolkit               10.1.243
cudnn                     7.6.5
numpy                     1.19.1
python                    3.7.9 
pytorch                   1.6.0
spikingjelly              <Specific Version>
tensorboard               2.2.1
torchvision               0.7.0

Note: the required version of spikingjelly is specified in the Usage section.

Environment

The code requires an NVIDIA GPU and has been tested with CUDA 10.1 on Ubuntu 16.04. To run the code with the same batch sizes as in our paper and reproduce the results, you may need a GPU with more than 6 GB of memory.

We use a single Tesla V100 GPU for each experiment. If you want to reproduce exactly the same results (e.g., the training curves shown in the paper), we recommend a GPU with ECC enabled.
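The experiments rely on fixed random seeds. If you adapt the code for your own runs, a helper along the following lines seeds the usual random number generators. `seed_everything` is a hypothetical name, not a function from this repo. NumPy and PyTorch are seeded only if they are installed. Seeding alone does not guarantee bit-identical results across different GPU models or library versions.

```python
import random


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed the RNGs a PyTorch training script typically uses."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)           # CPU (and default CUDA) generator
        torch.cuda.manual_seed_all(seed)  # every visible GPU; safe without CUDA
        # Ask cuDNN for deterministic algorithms and disable autotuning.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


if __name__ == "__main__":
    seed_everything(2021)
    print([random.random() for _ in range(2)])
```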

Epoch Time (Wall Clock Time)

The approximate per-epoch running times below were measured on the platform described above and should be regarded only as a reference.

Dataset     Train & Test (s)    Train Only (s)
CIFAR-10    1150                540
MNIST       12.8                12.3

The test stage involves a lot of intricate data processing, which is time-consuming.
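Multiplying these per-epoch times by the default epoch counts gives a rough budget for a full run. The defaults are the `-N` values listed under Running Arguments: 2048 for CIFAR-10 and 512 for MNIST. On CIFAR-10 this comes to about 1150 × 2048 s ≈ 654 h (roughly 27 days) with testing, or about 307 h (12.8 days) for training alone. On MNIST a full run takes under 2 h. The snippet below reproduces that arithmetic. It assumes every epoch takes exactly the tabulated time and that the Train & Test figure includes one test pass per epoch.

```python
# Per-epoch wall-clock times (seconds) from the Epoch Time table.
EPOCH_TIME_S = {
    "CIFAR-10": {"train_test": 1150.0, "train_only": 540.0},
    "MNIST": {"train_test": 12.8, "train_only": 12.3},
}
# Default number of training epochs (-N, --epoch).
DEFAULT_EPOCHS = {"CIFAR-10": 2048, "MNIST": 512}


def total_hours(dataset: str, mode: str = "train_test") -> float:
    """Rough full-run estimate, assuming every epoch takes the tabulated time."""
    return EPOCH_TIME_S[dataset][mode] * DEFAULT_EPOCHS[dataset] / 3600.0


if __name__ == "__main__":
    for ds in DEFAULT_EPOCHS:
        for mode in ("train_test", "train_only"):
            h = total_hours(ds, mode)
            print(f"{ds:8s} {mode:10s} {h:8.1f} h ({h / 24:.1f} days)")
```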

Usage

This code requires a legacy version of SpikingJelly, an open-source SNN framework. To install it, first clone the repository from GitHub:

$ git clone https://github.com/fangwei123456/spikingjelly.git

or OpenI:

$ git clone https://git.openi.org.cn/OpenI/spikingjelly.git

Then check out the version used in our experiments and install it:

$ cd spikingjelly
$ git checkout c8a9ba8
$ python setup.py install

With the dependencies above installed, you should be able to run the following commands:

Grad Rewiring on CIFAR-10:

$ cd <repo_path>/c10
$ python c10.py -s 0.95 -gpu <gpu_id> --dataset-dir <dataset_path> --dump-dir <dump_logs&models_path> -m grad

Grad Rewiring on MNIST:

$ cd <repo_path>/mnist
$ python mnist.py -s 0.95 -gpu <gpu_id> --dataset-dir <dataset_path> --dump-dir <dump_logs&models_path> -m grad
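To sweep several sparsity levels, the commands above can be scripted. The helpers `build_command` and `sweep` below are hypothetical, not part of this repo. They use only the flags that appear in these commands and in the Running Arguments table. By default the sweep is a dry run that just prints the command lines. The sparsity values and paths are placeholders. Giving each run its own dump directory is our choice for keeping logs and models apart, not a requirement of the scripts.

```python
import shlex
import subprocess
from typing import List, Optional, Sequence


def build_command(script: str, sparsity: float, gpu: str, dataset_dir: str,
                  dump_dir: str, mode: str = "grad") -> List[str]:
    """Assemble argv from the flags documented under Running Arguments."""
    return ["python", script, "-s", str(sparsity), "-gpu", gpu,
            "--dataset-dir", dataset_dir, "--dump-dir", dump_dir, "-m", mode]


def sweep(script: str, sparsities: Sequence[float], gpu: str, dataset_dir: str,
          dump_root: str, workdir: Optional[str] = None,
          dry_run: bool = True) -> List[List[str]]:
    """Run (or just print) one training job per sparsity level, sequentially."""
    cmds = []
    for s in sparsities:
        # A separate dump directory per run keeps logs and models apart.
        cmd = build_command(script, s, gpu, dataset_dir, f"{dump_root}/s{s}")
        cmds.append(cmd)
        if dry_run:
            # shlex.join needs Python 3.8+; the repo targets 3.7, so quote manually.
            print(" ".join(shlex.quote(c) for c in cmd))
        else:
            # workdir should be <repo_path>/c10 or <repo_path>/mnist.
            subprocess.run(cmd, check=True, cwd=workdir)
    return cmds


if __name__ == "__main__":
    sweep("c10.py", [0.9, 0.95], gpu="0",
          dataset_dir="/path/to/cifar10", dump_root="/path/to/dumps")
```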

The TensorBoard logs will be placed in <dump-dir>/logs.

Running Arguments

Argument              Type    Default                       Description
-b, --batch-size      int     128 (MNIST), 16 (CIFAR-10)    Training batch size
-lr, --learning-rate  float   1e-4                          Learning rate
-penalty              float   1e-3                          L1 penalty for Deep R; prior term for Grad Rewiring
-s, --sparsity        float                                 Maximum sparsity for Deep R; target sparsity for soft Deep R and Grad Rewiring
-gpu                  str                                   GPU id
--dataset-dir         str                                   Path to the datasets
--dump-dir            str                                   Path for dumping models and logs
-T                    int     8                             Simulation time-steps
-N, --epoch           int     512 (MNIST), 2048 (CIFAR-10)  Number of training epochs
-m, --mode            str     'no_prune'                    Pruning method ('deep', 'grad', or 'no_prune')
-soft                 bool    False                         Whether to use soft Deep R (only takes effect when mode='deep')
-test                 bool    False                         Whether to run the test only
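For reference, the options in this table map onto an `argparse` parser roughly as follows. This is a hypothetical reconstruction from the table alone, not the parser in c10.py or mnist.py. The real scripts may use different `dest` names, may parse `-soft` and `-test` differently (`store_true` is an assumption here), or may set defaults where the table is blank. Blank defaults are left as `None` in this sketch. The `dataset` parameter of `build_parser` is also hypothetical and only selects the dataset-dependent defaults.

```python
import argparse


def build_parser(dataset: str = "CIFAR-10") -> argparse.ArgumentParser:
    """Hypothetical reconstruction of the documented options (not the repo's parser)."""
    is_mnist = dataset == "MNIST"
    p = argparse.ArgumentParser(description=f"SNN pruning on {dataset} (sketch)")
    p.add_argument("-b", "--batch-size", type=int, default=128 if is_mnist else 16)
    p.add_argument("-lr", "--learning-rate", type=float, default=1e-4)
    p.add_argument("-penalty", type=float, default=1e-3,
                   help="L1 penalty (Deep R) / prior term (Grad Rewiring)")
    p.add_argument("-s", "--sparsity", type=float,
                   help="max sparsity (Deep R) / target sparsity (soft Deep R, Grad Rewiring)")
    p.add_argument("-gpu", type=str, help="GPU id")
    p.add_argument("--dataset-dir", type=str, help="path to the datasets")
    p.add_argument("--dump-dir", type=str, help="path for dumping models and logs")
    p.add_argument("-T", type=int, default=8, help="simulation time-steps")
    p.add_argument("-N", "--epoch", type=int, default=512 if is_mnist else 2048)
    p.add_argument("-m", "--mode", choices=["deep", "grad", "no_prune"],
                   default="no_prune", help="pruning method")
    # Boolean switches; parsing them with store_true is an assumption.
    p.add_argument("-soft", action="store_true", help="soft Deep R (mode='deep' only)")
    p.add_argument("-test", action="store_true", help="test only")
    return p


if __name__ == "__main__":
    args = build_parser("MNIST").parse_args(["-s", "0.95", "-gpu", "0", "-m", "grad"])
    print(vars(args))
```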

Citation

If you find this work useful for your research, please cite it as follows:

@inproceedings{ijcai2021-236,
  title     = {Pruning of Deep Spiking Neural Networks through Gradient Rewiring},
  author    = {Chen, Yanqi and Yu, Zhaofei and Fang, Wei and Huang, Tiejun and Tian, Yonghong},
  booktitle = {Proceedings of the Thirtieth International Joint Conference on
               Artificial Intelligence, {IJCAI-21}},
  publisher = {International Joint Conferences on Artificial Intelligence Organization},
  editor    = {Zhi-Hua Zhou},
  pages     = {1713--1721},
  year      = {2021},
  month     = {8},
  note      = {Main Track},
  doi       = {10.24963/ijcai.2021/236},
  url       = {https://doi.org/10.24963/ijcai.2021/236},
}