Deep Networks on classification tasks using Torch

This is a complete training example for BinaryNets using the Binary-Backpropagation algorithm, as explained in "Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1" by Matthieu Courbariaux, Itay Hubara, Daniel Soudry, Ran El-Yaniv, and Yoshua Bengio, on the following datasets: Cifar10/100, SVHN, MNIST.
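
To make "constrained to +1 or -1" concrete, here is a minimal Python sketch of the two binarization functions described in the paper: a deterministic sign function, and a stochastic variant that picks +1 with probability given by a hard sigmoid. This is an illustration only, not the Torch code in this repository; the function names are hypothetical, and the link to the stcNeurons and stcWeights flags is an assumption based on their descriptions.

```python
import random


def hard_sigmoid(x):
    """Clip (x + 1) / 2 to the range [0, 1]."""
    return max(0.0, min(1.0, (x + 1.0) / 2.0))


def binarize_deterministic(x):
    """Return +1 when x >= 0, otherwise -1."""
    return 1 if x >= 0 else -1


def binarize_stochastic(x, rng=random):
    """Return +1 with probability hard_sigmoid(x), otherwise -1."""
    return 1 if rng.random() < hard_sigmoid(x) else -1


if __name__ == "__main__":
    rng = random.Random(0)  # seeded only to make the demo repeatable
    values = [-1.5, -0.3, 0.0, 0.4, 2.0]
    print([binarize_deterministic(v) for v in values])
    print([binarize_stochastic(v, rng) for v in values])
```

Inputs at or beyond -1 and +1 saturate the hard sigmoid, so the stochastic version only differs from the deterministic one for values strictly between -1 and +1.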

Data

We use the dp library to extract all the data; see the Dependencies section below for installation.

Dependencies

To install all dependencies (assuming torch is installed) use:

luarocks install https://raw.githubusercontent.com/eladhoffer/DataProvider.torch/master/dataprovider-scm-1.rockspec
luarocks install cudnn
luarocks install dp
luarocks install unsup

Training

Create pre-processing folder:

cd BinaryNet
mkdir PreProcData

Start training using:

th Main_BinaryNet_Cifar10.lua -network BinaryNet_Cifar10_Model

or,

th Main_BinaryNet_MNIST.lua -network BinaryNet_MNIST_Model

Run with Docker

The Docker image is built from nvidia/cuda:8.0-cudnn5-devel with Torch at commit 0219027e6c4644a0ba5c5bf137c989a0a8c9e01b.

Additional flags

| Flag | Default Value | Description |
|------|---------------|-------------|
| modelsFolder | ./Models/ | Models Folder |
| network | Model.lua | Model file - must return valid network. |
| LR | 0.1 | learning rate |
| LRDecay | 0 | learning rate decay (in # samples) |
| weightDecay | 1e-4 | L2 penalty on the weights |
| momentum | 0.9 | momentum |
| batchSize | 128 | batch size |
| stcNeurons | true | using stochastic binarization for the neurons or not |
| stcWeights | false | using stochastic binarization for the weights or not |
| optimization | adam | optimization method |
| SBN | true | use shift based batch-normalization or not |
| runningVal | true | use running mean and std or not |
| epoch | -1 | number of epochs to train (-1 for unbounded) |
| threads | 8 | number of threads |
| type | cuda | float or cuda |
| devid | 1 | device ID (if using CUDA) |
| load | none | load existing net weights |
| save | time-identifier | save directory |
| dataset | Cifar10 | Dataset - Cifar10, Cifar100, STL10, SVHN, MNIST |
| dp_prepro | false | preprocessing using dp lib |
| whiten | false | whiten data |
| augment | false | Augment training data |
| preProcDir | ./PreProcData/ | Data for pre-processing (means,Pinv,P) |
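
For readers unfamiliar with the idea behind the SBN flag: shift-based batch normalization, as described in the paper, replaces multiplications with binary shifts by approximating scale factors with powers of two. The sketch below shows only that power-of-two approximation (called AP2 in the paper, commonly written as sign(x) * 2^round(log2|x|)). It is a Python illustration under that assumption, not this repository's Torch implementation, and the function name `ap2` is hypothetical.

```python
import math


def ap2(x):
    """Approximate x by a signed power of two: sign(x) * 2**round(log2(|x|))."""
    if x == 0:
        return 0.0  # log2(0) is undefined, so zero is passed through
    sign = 1.0 if x > 0 else -1.0
    return sign * 2.0 ** round(math.log2(abs(x)))


if __name__ == "__main__":
    for value in [3.0, 0.3, -5.0, 1.0]:
        print(value, ap2(value))
```

Multiplying by a value returned from `ap2` is equivalent to shifting by `round(log2(|x|))` bits, which is what makes the approximation cheap in hardware.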