ST-3: Are Straight-Through gradients and Soft-Thresholding all you need for Sparse Training?

Source code for our accepted paper at the WACV 2023 conference (paper here)

General Overview

The pruning literature has turned to increasingly complex methods for pruning weights during training (e.g. ProbMask), sometimes even taking cues from biological neuroregeneration (GraNet). This work takes a simpler approach (that nevertheless surpasses the previous state of the art), based on minimizing the mismatch between forward and backward propagation that occurs when a straight-through estimator is used to update the weights. To reduce this disparity, soft-thresholding and weight rescaling are applied during forward propagation only, and the pruning ratio is increased cubically during training to allow for a smoother transition.
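The three ingredients above can be sketched in a few lines of NumPy. This is a minimal illustration of the ideas only, not the repository's implementation: the function names, the l2-norm rescaling target, and the exact schedule parametrization are assumptions.

```python
import numpy as np

def soft_threshold(w, tau):
    # Soft-thresholding: shrink every weight toward zero by tau,
    # which zeroes out all weights with |w| <= tau.
    return np.sign(w) * np.maximum(np.abs(w) - tau, 0.0)

def cubic_sparsity(step, total_steps, final_sparsity):
    # Cubic schedule: sparsity ramps smoothly from 0 to final_sparsity,
    # pruning aggressively early and gently near the end.
    t = min(step / total_steps, 1.0)
    return final_sparsity * (1.0 - (1.0 - t) ** 3)

def ste_forward(w, sparsity):
    # Pick tau as the k-th smallest magnitude so that a `sparsity`
    # fraction of weights is zeroed by the soft threshold.
    k = int(sparsity * w.size)
    tau = np.sort(np.abs(w).ravel())[k - 1] if k > 0 else 0.0
    w_sparse = soft_threshold(w, tau)
    # Rescale so the sparse weights keep the dense weights' l2 norm
    # (the rescaling target is an assumption for this sketch).
    norm = np.linalg.norm(w_sparse)
    if norm > 0:
        w_sparse = w_sparse * (np.linalg.norm(w) / norm)
    # In training, the straight-through estimator passes gradients of the
    # loss w.r.t. w_sparse directly to the dense w; in PyTorch this is
    # typically written as w + (w_sparse - w).detach().
    return w_sparse
```

Because `w_sparse` is used only in the forward pass while gradients flow to the dense `w`, the dense weights keep receiving updates even while masked, which is what lets the mask change smoothly as the cubic schedule tightens.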

Results on ImageNet

These are the results of training on ImageNet for 100 epochs (no longer), with only RandomCropping and RandomFlipping as data augmentation during training, to align with results in the literature. ST-3 uses l1-magnitude pruning; ST-3 $^\sigma$ uses scaled l1-magnitude pruning to force the pruning to be more uniform across layers. ST-3 $^\sigma$ tends to achieve a larger FLOPs reduction, although its accuracy is slightly worse than unconstrained ST-3.
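To make the plain vs. scaled l1-magnitude distinction concrete, the sketch below compares both criteria under a single global pruning threshold. The exact per-layer scaling used by ST-3 $^\sigma$ is not spelled out here; dividing each layer's magnitudes by that layer's mean magnitude is an illustrative assumption, as are all names in this snippet.

```python
import numpy as np

def global_masks(scores, sparsity):
    # Keep-mask per layer: prune the globally smallest-scored
    # `sparsity` fraction of weights across all layers at once.
    flat = np.concatenate([s.ravel() for s in scores])
    tau = np.quantile(flat, sparsity)
    return [s > tau for s in scores]

# Two layers with very different weight scales.
rng = np.random.default_rng(0)
layers = [rng.normal(0, 1.0, 1000), rng.normal(0, 0.1, 1000)]

# Plain l1-magnitude: the small-scale layer loses almost everything.
plain = global_masks([np.abs(w) for w in layers], 0.5)

# Scaled l1-magnitude (assumed normalization): dividing by each layer's
# mean |w| puts layers on a common scale, so both keep roughly half.
scaled = global_masks([np.abs(w) / np.abs(w).mean() for w in layers], 0.5)
```

With plain l1 scoring, nearly all survivors come from the large-scale layer; with the scaled scores, each layer keeps close to 50% of its weights, which is the "more uniform across layers" behavior described above.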

ResNet-50

| Method | Accuracy [%] | Sparsity [%] | MFLOPS |
| --- | --- | --- | --- |
| Baseline | 77.10 | 0 | 4089 |
| ST-3 | 76.95 | 80 | 1215 |
| ST-3 $^\sigma$ | 76.44 | 80 | 739 |
| ST-3 | 76.03 | 90 | 764 |
| ST-3 $^\sigma$ | 75.28 | 90 | 397 |
| ST-3 | 74.46 | 95 | 436 |
| ST-3 $^\sigma$ | 73.69 | 95 | 219 |
| ST-3 | 73.31 | 96.5 | 351 |
| ST-3 $^\sigma$ | 72.62 | 96.5 | 167 |
| ST-3 | 70.46 | 98 | 220 |
| ST-3 $^\sigma$ | 69.75 | 98 | 116 |
| ST-3 | 63.88 | 99 | 120 |
| ST-3 $^\sigma$ | 63.25 | 99 | 69 |

Instructions

The code specific to the ST-3(sigma) method described in the paper is contained in the following 2 files:

These files also contain code to reproduce the results of ProbMask, GMP, and LRR; it is included for completeness and can be ignored. The other files are mostly boilerplate from a larger framework and are not necessary to understand ST-3.

Execution

$ ./launch <path/to/configs>

Weights

The weights will be made available soon.