<b>Antialiased CNNs</b> [Project Page] [Paper] [Talk]

<img src='https://richzhang.github.io/antialiased-cnns/resources/gifs2/video_00810.gif' align="right" width=300>

Making Convolutional Networks Shift-Invariant Again <br> Richard Zhang. In ICML, 2019.

Quick & easy start

Run pip install antialiased-cnns

import antialiased_cnns
model = antialiased_cnns.resnet50(pretrained=True) 

If you have a model already and want to antialias and continue training, copy your old weights over:

import torchvision.models as models
old_model = models.resnet50(pretrained=True) # old (aliased) model
antialiased_cnns.copy_params_buffers(old_model, model) # copy the weights over

If you want to modify your own model, use the BlurPool layer. More information about our provided models and how to use BlurPool is below.

import torch

C = 10 # example feature channel size
blurpool = antialiased_cnns.BlurPool(C, stride=2) # BlurPool layer; use to downsample a feature map
ex_tens = torch.randn(1, C, 128, 128) # example input feature map
print(blurpool(ex_tens).shape) # 1xCx64x64 tensor

Table of contents

  1. More information about antialiased models<br>
  2. Instructions for antialiasing your own model, using the BlurPool layer<br>
  3. ImageNet training and evaluation code. Achieving better consistency, while maintaining or improving accuracy, is an open problem. Help improve the results!

(0) Preliminaries

Pip install this package:

pip install antialiased-cnns

Or clone this repository and install the requirements (notably, PyTorch):

git clone https://github.com/adobe/antialiased-cnns.git
cd antialiased-cnns
pip install -r requirements.txt

(1) Loading an antialiased model

The following loads a pretrained antialiased model, perhaps as a backbone for your application.

import antialiased_cnns
model = antialiased_cnns.resnet50(pretrained=True, filter_size=4)

We also provide weights for antialiased AlexNet, VGG16 (with batch norm), ResNet18, 34, 50, and 101, DenseNet121, and MobileNetV2 (see example_usage.py).

(2) How to antialias your own architecture

The antialiased_cnns module contains the BlurPool class, which does blur+subsampling. Run pip install antialiased-cnns or copy the antialiased_cnns subdirectory.
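
As a rough illustration of what "blur+subsampling" means, here is a minimal 1D sketch in plain Python. It assumes a normalized binomial filter and reflection padding; the helper names (`binomial_kernel`, `reflect_pad`, `blur_pool_1d`) are made up for this example and are not part of the antialiased_cnns API, which operates on 2D feature maps.

```python
from math import comb

def binomial_kernel(size):
    """Normalized binomial filter, e.g. size 3 -> [1, 2, 1] / 4."""
    row = [comb(size - 1, k) for k in range(size)]
    total = sum(row)
    return [v / total for v in row]

def reflect_pad(x, left, right):
    """Mirror the signal at both ends without repeating the edge sample."""
    return x[left:0:-1] + x + x[-2:-2 - right:-1]

def blur_pool_1d(x, filt_size=3, stride=2):
    """Low-pass filter the signal, then keep every `stride`-th sample."""
    x = list(x)
    kernel = binomial_kernel(filt_size)
    left = (filt_size - 1) // 2
    right = filt_size - 1 - left
    padded = reflect_pad(x, left, right)
    blurred = [
        sum(w * padded[i + j] for j, w in enumerate(kernel))
        for i in range(len(x))
    ]
    return blurred[::stride]

print(binomial_kernel(4))  # [0.125, 0.375, 0.375, 0.125]

feature = [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0]
print(blur_pool_1d(feature))  # [0.0, 0.75, 0.25, 0.75]
```

The output has half as many samples as the input, and each one is a weighted average of its neighborhood rather than a single picked value.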

<b>Methodology</b> The methodology is simple: first evaluate the layer densely, with stride 1, and then use our BlurPool layer to do antialiased downsampling. Make the following architectural changes.

import torch.nn as nn
import antialiased_cnns

Cin, C = 32, 64 # example input and output channel counts

# MaxPool --> MaxBlurPool
baseline = nn.MaxPool2d(kernel_size=2, stride=2)
antialiased = [nn.MaxPool2d(kernel_size=2, stride=1),
    antialiased_cnns.BlurPool(C, stride=2)]

# Conv --> ConvBlurPool
baseline = [nn.Conv2d(Cin, C, kernel_size=3, stride=2, padding=1),
    nn.ReLU(inplace=True)]
antialiased = [nn.Conv2d(Cin, C, kernel_size=3, stride=1, padding=1),
    nn.ReLU(inplace=True),
    antialiased_cnns.BlurPool(C, stride=2)]

# AvgPool --> BlurPool
baseline = nn.AvgPool2d(kernel_size=2, stride=2)
antialiased = antialiased_cnns.BlurPool(C, stride=2)

We assume the incoming tensor has C channels. Computing a layer at stride 1 instead of stride 2 adds memory and run-time. As such, we typically skip antialiasing at the highest resolution (early in the network) to prevent large increases.
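
To see why the MaxPool --> MaxBlurPool change helps, the following toy 1D sketch compares the two on a signal and on the same signal shifted by one sample. It assumes circular boundaries and a fixed [1, 2, 1]/4 blur for simplicity; the function names are illustrative and not taken from this repository.

```python
def dense_max(x, k=2):
    """Max over a window of k samples at every position (stride 1, circular)."""
    n = len(x)
    return [max(x[(i + j) % n] for j in range(k)) for i in range(n)]

def blur(x, kernel=(0.25, 0.5, 0.25)):
    """Circular convolution with a small low-pass filter."""
    n, half = len(x), len(kernel) // 2
    return [
        sum(w * x[(i + j - half) % n] for j, w in enumerate(kernel))
        for i in range(n)
    ]

def max_pool(x, stride=2):
    """Baseline: dense max followed by naive subsampling."""
    return dense_max(x)[::stride]

def max_blur_pool(x, stride=2):
    """Antialiased: dense max, then blur, then subsampling."""
    return blur(dense_max(x))[::stride]

def max_gap(a, b):
    """Largest elementwise difference between two outputs."""
    return max(abs(p - q) for p, q in zip(a, b))

signal = [0, 0, 1, 1, 0, 0, 1, 1]
shifted = signal[1:] + signal[:1]  # same signal, shifted by one sample

print(max_pool(signal), max_pool(shifted))
# [0, 1, 0, 1] [1, 1, 1, 1]
print(max_blur_pool(signal), max_blur_pool(shifted))
# [0.5, 1.0, 0.5, 1.0] [0.75, 0.75, 0.75, 0.75]
print(max_gap(max_pool(signal), max_pool(shifted)))            # 1
print(max_gap(max_blur_pool(signal), max_blur_pool(shifted)))  # 0.25
```

In this toy case a one-sample shift changes the baseline output by up to 1, while the blurred version changes by at most 0.25. The dense max step is also where the extra memory and run-time come from, since it produces as many samples as its input before subsampling.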

<b>Add antialiasing and then continue training</b> If you already trained a model, and then add antialiasing, you can fine-tune from that old model:

antialiased_cnns.copy_params_buffers(old_model, antialiased_model)

If this doesn't work, you can just copy the parameters (and not buffers). Adding antialiasing doesn't add any parameters, so the parameter lists are identical. (It does add buffers, so some heuristic is used to match the buffers, which may throw an error.)

antialiased_cnns.copy_params(old_model, antialiased_model)
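
The reason a parameter-only copy can work is that a blur layer carries a fixed filter as a buffer, not as a learnable parameter, so both models list the same parameters in the same order. The sketch below illustrates that idea with plain PyTorch. `ToyBlur` and `copy_params_by_position` are hypothetical stand-ins written for this example; they are not the library's implementation of BlurPool or copy_params.

```python
import torch
import torch.nn as nn

class ToyBlur(nn.Module):
    """Stand-in for a blur layer: holds a fixed buffer, no learnable parameters."""
    def __init__(self):
        super().__init__()
        self.register_buffer("filt", torch.ones(1))

    def forward(self, x):
        return x[:, :, ::2, ::2]  # subsample only; real blurring omitted here

def copy_params_by_position(src, dst):
    """Copy parameters from src to dst, matching them by order."""
    src_params = list(src.parameters())
    dst_params = list(dst.parameters())
    if len(src_params) != len(dst_params):
        raise ValueError("models have different numbers of parameters")
    with torch.no_grad():
        for s, d in zip(src_params, dst_params):
            if s.shape != d.shape:
                raise ValueError("parameter shapes do not match")
            d.copy_(s)
    return len(dst_params)

# Old (aliased) block: strided conv. New block: stride-1 conv, then blur.
old = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.ReLU())
new = nn.Sequential(nn.Conv2d(3, 8, 3, stride=1, padding=1), nn.ReLU(), ToyBlur())

copied = copy_params_by_position(old, new)
print(copied)                           # 2 (conv weight and bias)
print(len(list(old.buffers())))         # 0
print(len(list(new.buffers())))         # 1 (the blur filter)
```

The buffer lists differ between the two models while the parameter lists line up, which is why matching buffers needs a heuristic and matching parameters does not.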

<img src='https://richzhang.github.io/antialiased-cnns/resources/antialias_mod.jpg' width=800><br>

(3) ImageNet Evaluation, Results, and Training code

We observe improvements in both accuracy (how often the image is classified correctly) and consistency (how often two shifts of the same image are classified the same).

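<img src='plots/plots2_acc.png' align="left" width=750> <img src='plots/plots2_con.png' align="left" width=750>

| ACCURACY | Baseline | Antialiased | Delta |
| :--- | :---: | :---: | :---: |
| alexnet | 56.55 | 56.94 | +0.39 |
| vgg11 | 69.02 | 70.51 | +1.49 |
| vgg13 | 69.93 | 71.52 | +1.59 |
| vgg16 | 71.59 | 72.96 | +1.37 |
| vgg19 | 72.38 | 73.54 | +1.16 |
| vgg11_bn | 70.38 | 72.63 | +2.25 |
| vgg13_bn | 71.55 | 73.61 | +2.06 |
| vgg16_bn | 73.36 | 75.13 | +1.77 |
| vgg19_bn | 74.24 | 75.68 | +1.44 |
| resnet18 | 69.74 | 71.67 | +1.93 |
| resnet34 | 73.30 | 74.60 | +1.30 |
| resnet50 | 76.16 | 77.41 | +1.25 |
| resnet101 | 77.37 | 78.38 | +1.01 |
| resnet152 | 78.31 | 79.07 | +0.76 |
| resnext50_32x4d | 77.62 | 77.93 | +0.31 |
| resnext101_32x8d | 79.31 | 79.33 | +0.02 |
| wide_resnet50_2 | 78.47 | 78.70 | +0.23 |
| wide_resnet101_2 | 78.85 | 78.99 | +0.14 |
| densenet121 | 74.43 | 75.79 | +1.36 |
| densenet169 | 75.60 | 76.73 | +1.13 |
| densenet201 | 76.90 | 77.31 | +0.41 |
| densenet161 | 77.14 | 77.88 | +0.74 |
| mobilenet_v2 | 71.88 | 72.72 | +0.84 |

| CONSISTENCY | Baseline | Antialiased | Delta |
| :--- | :---: | :---: | :---: |
| alexnet | 78.18 | 83.31 | +5.13 |
| vgg11 | 86.58 | 90.09 | +3.51 |
| vgg13 | 86.92 | 90.31 | +3.39 |
| vgg16 | 88.52 | 90.91 | +2.39 |
| vgg19 | 89.17 | 91.08 | +1.91 |
| vgg11_bn | 87.16 | 90.67 | +3.51 |
| vgg13_bn | 88.03 | 91.09 | +3.06 |
| vgg16_bn | 89.24 | 91.58 | +2.34 |
| vgg19_bn | 89.59 | 91.60 | +2.01 |
| resnet18 | 85.11 | 88.36 | +3.25 |
| resnet34 | 87.56 | 89.77 | +2.21 |
| resnet50 | 89.20 | 91.32 | +2.12 |
| resnet101 | 89.81 | 91.97 | +2.16 |
| resnet152 | 90.92 | 92.42 | +1.50 |
| resnext50_32x4d | 90.17 | 91.48 | +1.31 |
| resnext101_32x8d | 91.33 | 92.67 | +1.34 |
| wide_resnet50_2 | 90.77 | 92.46 | +1.69 |
| wide_resnet101_2 | 90.93 | 92.10 | +1.17 |
| densenet121 | 88.81 | 90.35 | +1.54 |
| densenet169 | 89.68 | 90.61 | +0.93 |
| densenet201 | 90.36 | 91.32 | +0.96 |
| densenet161 | 90.82 | 91.66 | +0.84 |
| mobilenet_v2 | 86.50 | 87.73 | +1.23 |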
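
The two metrics reported above can be sketched as follows, assuming predictions are already available as plain lists of class indices. The helpers `accuracy` and `consistency` and the toy data are illustrative only and are not the evaluation code from this repository.

```python
def accuracy(preds, labels):
    """Percentage of images whose predicted class matches the label."""
    hits = sum(p == y for p, y in zip(preds, labels))
    return 100.0 * hits / len(labels)

def consistency(preds_shift_a, preds_shift_b):
    """Percentage of images classified the same under two different shifts."""
    agree = sum(a == b for a, b in zip(preds_shift_a, preds_shift_b))
    return 100.0 * agree / len(preds_shift_a)

# Toy data: four images, each classified under two different shifts.
labels        = [0, 1, 2, 3]
preds_shift_a = [0, 1, 2, 0]
preds_shift_b = [0, 1, 1, 0]

print(accuracy(preds_shift_a, labels))            # 75.0
print(consistency(preds_shift_a, preds_shift_b))  # 75.0
```

Consistency does not use the labels: the last image is misclassified under both shifts yet still counts as consistent, while the third image counts as inconsistent even though one of its predictions is correct.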

To reduce clutter, extended results (different filter sizes) are here. Help improve the results!

Licenses

<a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc-sa/4.0/80x15.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/">Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License</a>.

All material is made available under the Creative Commons BY-NC-SA 4.0 license by Adobe Inc. You can use, redistribute, and adapt the material for non-commercial purposes, as long as you give appropriate credit by citing our paper and indicating any changes that you've made.

The repository builds off the PyTorch examples repository and torchvision models repository. These are BSD-style licensed.

Citation, contact

If you find this useful for your research, please consider citing this bibtex. Please contact Richard Zhang <rizhang at adobe dot com> with any comments or feedback.