

Spatial Transformer Networks

Reimplementations of:

Although other implementations already exist, this one focuses on simplicity and on making the vision transforms and the model easy to understand.

Results

During training, a random homography perturbation is applied to each image in the minibatch. Each perturbation is composed of component transformations (rotation, translation, shear, and projection). Each component parameter is sampled from uniform(-1, 1) and scaled by a multiplicative factor of 0.25.
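The perturbation described above can be sketched as a product of 3x3 matrices. This is a minimal pure-Python sketch. The parameterization of each component is an assumption for illustration: one rotation angle, two translation, two shear, and two projective parameters. The order in which the components are multiplied is also assumed, not taken from this repository's code.

```python
import math
import random


def matmul3(a, b):
    """Multiply two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def sample_perturbation(scale=0.25, rng=random):
    """Sample a random homography composed of rotation, translation, shear and projection.

    Every parameter is uniform(-1, 1) * scale. The per-component parameterization
    and the multiplication order are assumptions for illustration.
    """
    def u():
        return rng.uniform(-1.0, 1.0) * scale

    theta = u()
    rotation = [[math.cos(theta), -math.sin(theta), 0.0],
                [math.sin(theta), math.cos(theta), 0.0],
                [0.0, 0.0, 1.0]]
    translation = [[1.0, 0.0, u()], [0.0, 1.0, u()], [0.0, 0.0, 1.0]]
    shear = [[1.0, u(), 0.0], [u(), 1.0, 0.0], [0.0, 0.0, 1.0]]
    projection = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [u(), u(), 1.0]]
    # Only the projection has a non-trivial bottom row, so H[2][2] stays 1.
    return matmul3(translation, matmul3(rotation, matmul3(shear, projection)))
```

With `scale=0.0` every parameter is zero and the result is the identity. Passing a seeded `random.Random` instance as `rng` makes the sampled perturbations reproducible.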

Example homography perturbation:<br> <img src='images/transform_test.png' alt='example perturbation' width=150>
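To produce a perturbed (or transformed) image, a spatial transformer inverse-warps the input with a 3x3 homography H. Each output pixel's normalized coordinate in [-1, 1] is mapped through H, followed by a perspective division. The input is then sampled bilinearly at the resulting location.

The sketch below is a minimal pure-Python illustration of that idea on a grayscale image stored as nested lists. The coordinate convention (pixel centers at the corners of [-1, 1]) and zero padding outside the image are assumptions.

```python
import math


def apply_homography(H, x, y):
    """Map a point (x, y) through a 3x3 homography, including perspective division."""
    xh = H[0][0] * x + H[0][1] * y + H[0][2]
    yh = H[1][0] * x + H[1][1] * y + H[1][2]
    w = H[2][0] * x + H[2][1] * y + H[2][2]
    return xh / w, yh / w


def warp_image(image, H):
    """Inverse-warp a 2D grayscale image (list of rows) with bilinear sampling.

    H maps output coordinates to input coordinates; samples that fall outside
    the input image read as 0.
    """
    h, w = len(image), len(image[0])

    def pixel(r, c):
        # Explicit bounds check so negative indices do not wrap around.
        return image[r][c] if 0 <= r < h and 0 <= c < w else 0.0

    out = []
    for r in range(h):
        row = []
        for c in range(w):
            # Pixel index -> normalized coordinate in [-1, 1].
            xn = 2.0 * c / (w - 1) - 1.0 if w > 1 else 0.0
            yn = 2.0 * r / (h - 1) - 1.0 if h > 1 else 0.0
            xs, ys = apply_homography(H, xn, yn)
            # Normalized coordinate -> fractional pixel index in the input.
            cs = (xs + 1.0) * (w - 1) / 2.0
            rs = (ys + 1.0) * (h - 1) / 2.0
            c0, r0 = math.floor(cs), math.floor(rs)
            dc, dr = cs - c0, rs - r0
            # Bilinear interpolation between the four neighbouring pixels.
            row.append((1 - dr) * (1 - dc) * pixel(r0, c0)
                       + (1 - dr) * dc * pixel(r0, c0 + 1)
                       + dr * (1 - dc) * pixel(r0 + 1, c0)
                       + dr * dc * pixel(r0 + 1, c0 + 1))
        out.append(row)
    return out
```

In a trained STN, the same sampling is done with differentiable tensor operations, so gradients can flow back into the matrix predicted by the localization network.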

Test set accuracy:

| Model | Accuracy | Training params |
| --- | --- | --- |
| Basic affine STN | 91.59% | 10 epochs at learning rate 1e-3 (classifier and transformer) |
| Homography STN | 93.30% | 10 epochs at learning rate 1e-3 (classifier and transformer) |
| Homography ICSTN | 97.67% | 10 epochs at learning rate 1e-3 (classifier) and 5e-4 (transformer) |

Sample alignment results:

Basic affine STN

Samples (rows, top to bottom: original, perturbed, transformed)

Homography STN

Samples (rows, top to bottom: original, perturbed, transformed)

Homography ICSTN

Samples (rows, top to bottom: original, perturbed, transformed)

Mean and variance of the aligned results (cf. Lin, ICSTN paper)

Mean image

Columns: Basic affine STN, Homography STN, Homography ICSTN; rows, top to bottom: original, perturbed, transformed

Variance

Columns: Basic affine STN, Homography STN, Homography ICSTN; rows, top to bottom: original, perturbed, transformed
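The mean and variance images are per-pixel statistics over a set of images. Good alignment gives a sharp mean image and low variance. A poorly aligned set gives a blurry mean and high variance.

The sketch below is a minimal pure-Python version that assumes equally sized grayscale images stored as nested lists. Using the population variance (dividing by n) is an assumption. Which images are grouped together, for example all test images of one digit class, is not stated here.

```python
def pixelwise_mean_and_variance(images):
    """Per-pixel mean and population variance over equally sized 2D images."""
    n = len(images)
    h, w = len(images[0]), len(images[0][0])
    mean = [[sum(img[r][c] for img in images) / n for c in range(w)]
            for r in range(h)]
    var = [[sum((img[r][c] - mean[r][c]) ** 2 for img in images) / n
            for c in range(w)]
           for r in range(h)]
    return mean, var
```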

Usage

To train a model:

python train.py --output_dir=[path to params.json] \
                --restore_file=[path to .pt checkpoint if resuming training] \
                --cuda=[cuda device id]

params.json provides training parameters and specifies which spatial transformer module to use:

  1. BasicSTNModule -- affine transform localization network
  2. STNModule -- homography transform localization network
  3. ICSTNModule -- inverse compositional homography transform localization network (cf. Lin, ICSTN paper)
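The three modules differ in how the localization network's output becomes a warp. An affine transform has 6 parameters. A homography has 8 parameters, because its bottom-right entry is fixed to 1. The inverse compositional variant does not predict the warp in one shot. Instead, it repeatedly warps the input by the current estimate, predicts an update from the warped image, and composes that update into the estimate (p ← p ∘ Δp).

This is a minimal pure-Python sketch under stated assumptions:

- The identity-at-zero parameterizations are assumed.
- `predict_update` and `warp_fn` are hypothetical stand-ins for the learned geometric predictor and the sampler.
- `num_steps` is a hypothetical value, not necessarily what `ICSTNModule` uses.
- Composition is written as `H @ dH`, assuming H maps output coordinates to input coordinates.

```python
def matmul3(a, b):
    """Multiply two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def affine_from_params(p):
    """6 parameters -> 3x3 affine matrix (identity when p is all zeros)."""
    return [[1.0 + p[0], p[1], p[2]],
            [p[3], 1.0 + p[4], p[5]],
            [0.0, 0.0, 1.0]]


def homography_from_params(p):
    """8 parameters -> 3x3 homography with H[2][2] fixed to 1 (identity at zero)."""
    return [[1.0 + p[0], p[1], p[2]],
            [p[3], 1.0 + p[4], p[5]],
            [p[6], p[7], 1.0]]


def ic_compose(image, predict_update, warp_fn, num_steps=4):
    """Inverse compositional refinement of a homography (hypothetical helper).

    predict_update(warped) -> 8 update parameters (stand-in for the predictor)
    warp_fn(image, H)      -> image warped by H (e.g. a bilinear inverse warp)
    """
    H = homography_from_params([0.0] * 8)  # start from the identity warp
    for _ in range(num_steps):
        warped = warp_fn(image, H)
        dH = homography_from_params(predict_update(warped))
        # If warped(x) = image(H x), then warping it by dH gives image(H dH x).
        H = matmul3(H, dH)
    return H
```

As a usage example, a predictor that always returns a small x-translation accumulates that translation once per step. The same predictor weights are reused at every step, which is what makes the refinement recurrent.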

To evaluate and visualize results:

python evaluate.py --output_dir=[path to params.json] \
                   --restore_file=[path to .pt checkpoint] \
                   --cuda=[cuda device id]

Dependencies

Useful resources