Awesome
Training neural networks with local error signals
This repo contains PyTorch code for training neural networks without global backprop. Experiments are performed by Arild Nøkland and Lars Hiller Eidnes.
A more detailed description of the experiments is available on arXiv here: https://arxiv.org/abs/1901.06656
Supervised training of neural networks for classification is typically performed with a global loss function. The loss function provides a gradient for the output layer, and this gradient is back-propagated to hidden layers to dictate an update direction for the weights. An alternative approach is to train the network with layer-wise loss functions. In this paper we demonstrate, for the first time, that layer-wise training can approach the state-of-the-art on a variety of image datasets. We use single-layer sub-networks and two different supervised loss functions to generate local error signals for the hidden layers, and we show that the combination of these losses help with optimization in the context of local learning. Using local errors could be a step towards more biologically plausible deep learning because the global error does not have to be transported back to hidden layers.
In the tables below, 'pred' indicates a layer-wise cross-entropy loss, 'sim' indicates a layer-wise similarity matching loss, and 'predsim' indicates a combination of these losses. For the local losses, the computational graph is detached after each hidden layer.
Experiments
Results on MNIST with 2 pixel jittering:
Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
---|---|---|---|---|---|
mlp | 2.9M | 0.75 | 0.68 | 0.80 | 0.62 |
vgg8b | 7.3M | 0.26 | 0.40 | 0.65 | 0.31 |
vgg8b + cutout | 7.3M | - | - | - | 0.26 |
Results on Fashion-MNIST with 2 pixel jittering and horizontal flipping:
Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
---|---|---|---|---|---|
mlp | 2.9M | 8.37 | 8.60 | 9.70 | 8.54 |
vgg8b | 7.3M | 4.53 | 5.66 | 5.12 | 4.65 |
vgg8b (2x) | 28.2M | 4.55 | 5.11 | 4.92 | 4.33 |
vgg8b (2x) + cutout | 28.2M | - | - | - | 4.14 |
Results on Kuzusjiji-MNIST with no data augmentation:
Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
---|---|---|---|---|---|
mlp | 2.9M | 5.99 | 7.26 | 9.80 | 7.33 |
vgg8b | 7.3M | 1.53 | 2.22 | 2.19 | 1.36 |
vgg8b + cutout | 7.3M | - | - | - | 0.99 |
Results on Cifar-10 with data augmentation:
Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
---|---|---|---|---|---|
mlp | 27.3M | 33.56 | 32.33 | 33.48 | 30.93 |
vgg8b | 8.9M | 5.99 | 8.40 | 7.16 | 5.58 |
vgg11b | 11.6M | 5.56 | 8.39 | 6.70 | 5.30 |
vgg11b (2x) | 42.0M | 4.91 | 7.30 | 6.66 | 4.42 |
vgg11b (3x) | 91.3M | 5.02 | 7.37 | 9.34 | 3.97 |
vgg11b (3x) + cutout | 91.3M | - | - | - | 3.60 |
Results on Cifar-100 with data augmentation:
Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
---|---|---|---|---|---|
mlp | 27.3M | 62.57 | 58.87 | 62.46 | 56.88 |
vgg8b | 9.0M | 26.24 | 29.32 | 32.64 | 24.07 |
vgg11b | 11.7M | 25.18 | 29.58 | 30.82 | 24.05 |
vgg11b (2x) | 42.1M | 23.44 | 26.91 | 28.03 | 21.20 |
vgg11b (3x) | 91.4M | 23.69 | 25.90 | 28.01 | 20.13 |
Results on SVHN with extra training data, but no augmentation:
Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
---|---|---|---|---|---|
vgg8b | 8.9M | 2.29 | 2.12 | 1.89 | 1.74 |
vgg8b + cutout | 8.9M | - | - | - | 1.65 |
Results on STL-10 with no data augmentation:
Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim' |
---|---|---|---|---|---|
vgg8b | 11.5M | 33.08 | 26.83 | 23.15 | 20.51 |
vgg8b + cutout | 11.5M | - | - | - | 19.25 |
Training recipes
To replicate training of MLP on MNIST with local loss 'predsim':
python train.py --model mlp --dataset MNIST --dropout 0.1 --lr 5e-4 --num-layers 3 --epochs 100 --lr-decay-milestones 50 75 89 94 --nonlin leakyrelu
To replicate training of VGG8b on MNIST with local loss 'predsim':
python train.py --model vgg8b --dataset MNIST --dropout 0.2 --lr 5e-4 --epochs 100 --lr-decay-milestones 50 75 89 94 --nonlin leakyrelu --dim-in-decoder 1024
To replicate training of MLP on CIFAR10 with local loss 'predsim':
python train.py --model mlp --dataset CIFAR10 --dropout 0.1 --lr 5e-4 --num-layers 3 --num-hidden 3000 --nonlin leakyrelu
To replicate training of VGG8b on CIFAR10 with local loss 'predsim':
python train.py --model vgg8b --dataset CIFAR10 --dropout 0.2 --lr 5e-4 --nonlin leakyrelu --dim-in-decoder 2048
To replicate training of VGG11b (3x) on CIFAR10 with local loss 'predsim':
python train.py --model vgg11b --dataset CIFAR10 --dropout 0.3 --lr 3e-4 --feat-mult 3 --nonlin leakyrelu
For all the above recipes, to train with local cross-entropy loss, add argument
--loss-sup pred
For all the above recipes, to train with local similarity matching loss, add argument
--loss-sup sim
For all the above recipes, to train with global loss, add argument
--backprop
For all the above recipes, to train with a more biologically plausible version of local loss, add argument
--bio
To add cutout regularization with cutout hole size 14, add arguments
--cutout --length 14
To replicate all the above experiments, run
./run_experiments.sh