Revisiting Batch Norm Initialization

This repo contains the official code for the paper "Revisiting Batch Norm Initialization" by Jim Davis and Logan Frank, which was accepted to the European Conference on Computer Vision (ECCV) 2022.

In this work, we observed that the learned scale (γ) and shift (β) affine transformation parameters of batch normalization tend to remain close to their initialization, and further noticed that the normalization operation of batch normalization can yield overly large values, which, due to the previous observation, are preserved through the remainder of the forward pass. We first examined the batch normalization gradient equations and derived the influence of the batch normalization scale parameter on training, then empirically showed across multiple datasets and network architectures that initializing the BN scale parameter to a value < 1 and reducing the learning rate on the BN scale parameter yields statistically significant gains in performance (according to a one-sided paired t-test).
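
For reference (standard batch normalization notation, nothing specific to this repo), a batch normalization layer transforms each channel of its input x as

y = γ · (x − μ_B) / √(σ²_B + ε) + β

where μ_B and σ²_B are the per-channel batch mean and variance, ε is a small constant, and γ and β are the learned scale and shift. Conventionally γ is initialized to 1 and β to 0; our approach instead initializes γ below 1 (e.g., 0.1) and trains γ with a reduced learning rate.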

Overview

The contents of this repo are organized as follows: train.py contains the main training script, batch_norm.py contains our proposed ScaleBatchNorm2d layer, networks.py contains the construct_network helper, seeds.py contains the make_deterministic seeding function, t_test.py contains the evaluate significance-testing function, and requirements.txt lists the package dependencies.

Requirements

Assuming you have already created an environment with Python 3.8 and pip, install the necessary package requirements using pip install -r requirements.txt. The main requirements are PyTorch, torchvision, and NumPy, with specific versions given in requirements.txt.

Training

An example for running our training algorithm using everything proposed in the paper is:

python train.py \
    --path '' \
    --name 'train_example' \
    --dataset 'cifar10' \
    --network 'resnet18' \
    --batch_size 128 \
    --num_epochs 180 \
    --learning_rate 0.1 \
    --scheduler 'cos' \
    --momentum 0.9 \
    --weight_decay 0.0001 \
    --device 'cuda:0' \
    --bn_weight 0.1 \
    --c 100 \
    --input_norm 'bn' \
    --seed '1' 

where the arguments --path through --device are general arguments for training a CNN (dataset, network architecture, batch size, number of epochs, optimizer settings, and compute device). The last four command-line arguments are specific to our work:

--bn_weight is the initial value of the batch normalization scale parameter (our work proposes values < 1, e.g., 0.1).
--c is the factor by which the learning rate on the batch normalization scale parameters is reduced.
--input_norm selects the input normalization scheme: 'bn' for our proposed batch normalization-based input normalization or 'dataset' for the precomputed global dataset statistics.
--seed is the string used to seed the random number generators (see Setting Seeds below).

Batch Normalization

To instantiate a single batch normalization layer using the ScaleBatchNorm2d class in batch_norm.py, call

from batch_norm import ScaleBatchNorm2d

bn1 = ScaleBatchNorm2d(num_features=64, eps=1e-5, momentum=0.1, affine=True, bn_weight=0.1)

which creates a batch normalization layer that takes 64 feature maps as input and initializes the scale parameter to a value of 0.1.
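
ScaleBatchNorm2d itself is defined in batch_norm.py. As a rough illustration only (not the repo's exact implementation), such a layer can be thought of as a standard BatchNorm2d whose scale parameter is initialized to bn_weight instead of 1:

import torch.nn as nn

class ScaleBatchNorm2dSketch(nn.BatchNorm2d):
    """Illustrative sketch only: BatchNorm2d with gamma initialized to bn_weight."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, bn_weight=0.1):
        super().__init__(num_features, eps=eps, momentum=momentum, affine=affine)
        if affine:
            nn.init.constant_(self.weight, bn_weight)  # gamma <- bn_weight (e.g., 0.1)
            nn.init.zeros_(self.bias)                  # beta stays 0 (standard initialization)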

In many cases, a partial function may be useful (e.g., when calling the torchvision constructors for ResNet). An example of creating and then using such a partial function is

from functools import partial

import torchvision

norm_layer = partial(ScaleBatchNorm2d, eps=1e-5, momentum=0.1, affine=True, bn_weight=0.1)
network = torchvision.models.resnet18(num_classes=num_classes, norm_layer=norm_layer)

which creates a ResNet18 network from the torchvision library that utilizes our proposed ScaleBatchNorm2d.
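
As an optional sanity check (this snippet is not part of the repo), you can confirm that every batch normalization layer in the network created above starts with a scale of 0.1:

import torch

for name, module in network.named_modules():
    if isinstance(module, ScaleBatchNorm2d):
        assert torch.allclose(module.weight, torch.full_like(module.weight, 0.1)), name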

Network

To create a network using our construct_network function in networks.py, call

from networks import construct_network

network = construct_network(network_name='resnet18', num_classes=10, dataset='cifar10', bn_weight=0.1, input_norm='bn')

The above function call will instantiate a ResNet18 network with 10 output classes, modify it to account for the smaller imagery of CIFAR10, initialize all batch normalization layers with an initial scale value of 0.1, and employ our proposed batch normalization-based input normalization scheme. network_name can be any of the base ResNet architectures (18, 34, 50, 101, 152) following a similar string value as provided, bn_weight can be any value > 0 (though our work proposes setting this value < 1), and input_norm can be 'bn' to employ our proposed input normalization or 'dataset' to employ the precomputed global dataset statistics for CIFAR10.
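
Continuing from the call above (as a quick sanity check rather than part of the training script), the returned network can be applied directly to CIFAR10-sized imagery:

import torch

logits = network(torch.randn(8, 3, 32, 32))  # a batch of 8 CIFAR10-sized images
print(logits.shape)                          # expected: torch.Size([8, 10])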

Reducing Learning Rate

Once a network has been instantiated, the learning rate for the batch normalization scale parameters can be reduced and provided to an optimizer as in the following example. Note that this example will also properly apply weight decay only to the convolutional and fully-connected weights; all network biases and batch normalization parameters will have a weight decay of 0.

from torch import optim

# only conv/fc weights receive weight decay; BN scale parameters get a learning rate reduced by c
parameters = adjust_weight_decay_and_learning_rate(network, weight_decay=1e-4, learning_rate=0.1, c=100)
optimizer = optim.SGD(parameters, lr=0.1, momentum=0.9)
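
The actual grouping logic lives in the repo's adjust_weight_decay_and_learning_rate function; the sketch below only illustrates the kind of parameter groups it is described as producing (assuming batch normalization layers can be identified by module type and that conv/fully-connected weights are the only multi-dimensional parameters):

from torch import nn

def adjust_weight_decay_and_learning_rate_sketch(network, weight_decay, learning_rate, c):
    """Illustrative sketch only: decayed conv/fc weights, undecayed biases and BN
    parameters, and a learning rate reduced by a factor of c on the BN scale parameters."""
    decay, no_decay, bn_scale = [], [], []
    for module in network.modules():
        for name, param in module.named_parameters(recurse=False):
            if isinstance(module, nn.BatchNorm2d) and name == 'weight':
                bn_scale.append(param)   # BN scale (gamma): no decay, reduced learning rate
            elif param.ndim <= 1:
                no_decay.append(param)   # biases and BN shift (beta): no decay
            else:
                decay.append(param)      # convolutional / fully-connected weights: decayed
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
        {'params': bn_scale, 'weight_decay': 0.0, 'lr': learning_rate / c},
    ]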

Setting Seeds

Setting the seeds for important random number generators using our functions provided in seeds.py is easily done by calling

from seeds import make_deterministic

make_deterministic('1')

which will take the string '1', make it more complex using MD5, and then seed the various random number generators using that derived value.
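
make_deterministic is provided in seeds.py; as a rough sketch of the general idea only (the repo's function may seed additional generators or set other determinism flags), the behavior described above could look like:

import hashlib
import random

import numpy as np
import torch

def make_deterministic_sketch(seed_str):
    # derive an integer seed from the MD5 digest of the provided string
    seed = int(hashlib.md5(seed_str.encode('utf-8')).hexdigest(), 16) % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)     # numpy's legacy global RNG requires a seed below 2**32
    torch.manual_seed(seed)  # seeds the CPU (and CUDA) generators in PyTorch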

Evaluating Significance

Determining whether improvements are statistically significant is crucial; this can be done by calling the evaluate function in t_test.py. For example,

import numpy as np
from t_test import evaluate

my_cool_new_method = np.array([91.7, 93.4, 92.2, 90.0, 91.9])
baseline = np.array([91.0, 92.7, 92.2, 90.1, 91.5])

p_value = evaluate(my_cool_new_method, baseline)

if p_value <= 0.05:
    print('My approach is significantly greater than the baseline!')
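
The paper's significance claims use a one-sided paired t-test; a minimal sketch of such a test using SciPy (the repo's evaluate function may differ; assumes scipy >= 1.6 for the alternative argument) is:

from scipy import stats

def evaluate_sketch(method_accuracies, baseline_accuracies):
    # one-sided paired t-test: is the method significantly greater than the baseline?
    return stats.ttest_rel(method_accuracies, baseline_accuracies, alternative='greater').pvalue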

Citation

Please cite our paper "Revisiting Batch Norm Initialization" with

@inproceedings{Davis2022revisiting,
  title={Revisiting Batch Norm Initialization},
  author={Davis, Jim and Frank, Logan},
  booktitle={European Conference on Computer Vision},
  year={2022}
}