PyTorch-Deformable-Convolution-v2

Use Deformable Convolution v2 (DCNv2) in PyTorch without the pain.

If you are curious about how the offsets (red points) are visualized, refer to offset_visualization.py.

Usage


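import torch.nn as nn
from dcn import DeformableConv2d

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Used like nn.Conv2d, with the same core arguments
        self.conv = DeformableConv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.conv(x)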
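For intuition, the DCNv2 paper defines each output value as y(p) = Σ_k w_k · x(p + p_k + Δp_k) · Δm_k, where p_k is the regular kernel grid, Δp_k is a learned (usually fractional) offset and Δm_k is a learned modulation scalar in [0, 1]. The pure-Python sketch below evaluates that sum at a single output location of a one-channel image, using bilinear interpolation with zero padding for the fractional positions. It is an illustration of the formula only, not the code in dcn.py; the helper names `bilinear` and `deform_conv_at` are hypothetical.

```python
import math

def bilinear(img, y, x):
    """Sample img (a list of rows) at fractional (y, x); out-of-bounds pixels count as 0."""
    h, w = len(img), len(img[0])
    y0, x0 = math.floor(y), math.floor(x)
    total = 0.0
    for yy in (y0, y0 + 1):
        for xx in (x0, x0 + 1):
            if 0 <= yy < h and 0 <= xx < w:
                # Weight falls off linearly with the distance along each axis
                weight = (1 - abs(y - yy)) * (1 - abs(x - xx))
                total += weight * img[yy][xx]
    return total

def deform_conv_at(img, weights, p, offsets, masks, k=3):
    """Modulated deformable convolution (DCNv2) at output location p = (row, col).

    weights: k*k kernel weights w_k (row-major)
    offsets: k*k learned (dy, dx) pairs, i.e. the offsets Δp_k
    masks:   k*k modulation scalars Δm_k
    """
    r = k // 2
    # Regular sampling grid p_k, e.g. (-1, -1) ... (1, 1) for k = 3
    grid = [(dy, dx) for dy in range(-r, r + 1) for dx in range(-r, r + 1)]
    out = 0.0
    for w_k, (gy, gx), (oy, ox), m_k in zip(weights, grid, offsets, masks):
        sample = bilinear(img, p[0] + gy + oy, p[1] + gx + ox)
        out += w_k * sample * m_k
    return out

img = [[float(4 * i + j) for j in range(4)] for i in range(4)]
weights = [1.0] * 9
zero_offsets = [(0.0, 0.0)] * 9
unit_masks = [1.0] * 9
# Zero offsets and unit masks reduce DCNv2 to an ordinary 3x3 convolution
print(deform_conv_at(img, weights, (1, 1), zero_offsets, unit_masks))  # 45.0
```

With all offsets zero and all masks one, the result is the plain sum over the top-left 3x3 window; non-zero offsets move the sampling points (the red points in the visualization), and masks below one scale down individual samples.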

Experiment

You can simply reproduce the results of my experiment on Google Colab.

Refer to experiment.ipynb!

Task

Scaled-MNIST Handwritten Digit Classification

Model

A simple CNN with 5 convolutional layers. With deformable=True, conv4 and conv5 become DeformableConv2d layers.

import torch
import torch.nn as nn
from dcn import DeformableConv2d

class MNISTClassifier(nn.Module):
    def __init__(self, deformable=False):
        super(MNISTClassifier, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=True)
        # Only the last two conv layers are swapped for DCNv2 when deformable=True
        conv = DeformableConv2d if deformable else nn.Conv2d
        self.conv4 = conv(32, 32, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv5 = conv(32, 32, kernel_size=3, stride=1, padding=1, bias=True)

        self.pool = nn.MaxPool2d(2)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling
        self.fc = nn.Linear(32, 10)              # 10 digit classes

    def forward(self, x):
        # x: [N, 1, 28, 28] MNIST images
        x = torch.relu(self.conv1(x))
        x = self.pool(x)  # [14, 14]
        x = torch.relu(self.conv2(x))
        x = self.pool(x)  # [7, 7]
        x = torch.relu(self.conv3(x))
        x = torch.relu(self.conv4(x))
        x = torch.relu(self.conv5(x))
        x = self.gap(x)              # [N, 32, 1, 1]
        x = x.flatten(start_dim=1)   # [N, 32]
        x = self.fc(x)               # [N, 10] class logits
        return x
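The deformable variant is not parameter-free: besides its main kernel, each DCNv2 layer has to learn how to predict offsets and masks. The sketch below counts parameters for MNISTClassifier with and without deformable conv4/conv5. It assumes that DeformableConv2d predicts them with two extra k×k convolutions (2·k·k offset channels and k·k mask channels, each with a bias), which is the usual DCNv2 design but an assumption about dcn.py; the helper names are hypothetical.

```python
def conv_params(c_in, c_out, k, bias=True):
    # Weight tensor is c_out x c_in x k x k, plus one bias per output channel
    return c_out * c_in * k * k + (c_out if bias else 0)

def deform_conv_params(c_in, c_out, k, bias=True):
    # ASSUMPTION: main conv + offset conv (2*k*k outputs) + mask conv (k*k outputs)
    return (conv_params(c_in, c_out, k, bias)
            + conv_params(c_in, 2 * k * k, k)
            + conv_params(c_in, k * k, k))

def mnist_classifier_params(deformable=False):
    conv = deform_conv_params if deformable else conv_params
    total = conv_params(1, 32, 3)        # conv1
    total += 2 * conv_params(32, 32, 3)  # conv2, conv3
    total += 2 * conv(32, 32, 3)         # conv4, conv5
    total += 32 * 10 + 10                # fc (pooling layers have no parameters)
    return total

print(mnist_classifier_params(deformable=False))  # 37642
print(mnist_classifier_params(deformable=True))   # 53248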

Training

Test

In the paper, the authors state that DCNv2 considerably enhances the network's ability to model geometric transformations.

I tested this claim with scale augmentation.

Every image in the MNIST test set is augmented by rescaling it with the factors x0.5, x0.6, ..., x1.4, x1.5.
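The exact resizing code lives in experiment.ipynb and is not reproduced here. As a rough, hypothetical illustration, the sketch below builds the scale list x0.5 ... x1.5 (assuming 0.1 steps, as the listed factors suggest) and rescales a square image about its center with nearest-neighbour sampling while keeping the 28×28 canvas: shrunk digits are padded with zeros and enlarged digits are cropped. The function names and the nearest-neighbour choice are assumptions, not the notebook's method.

```python
def scale_factors(lo=0.5, hi=1.5, step=0.1):
    # round() hides floating-point noise such as 0.7000000000000001
    n = round((hi - lo) / step)
    return [round(lo + i * step, 1) for i in range(n + 1)]

def rescale_centered(img, s):
    """Scale a square image by factor s about its center; the output keeps the same size."""
    n = len(img)
    c = (n - 1) / 2
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            # Inverse mapping: which source pixel lands on output pixel (i, j)?
            si = round(c + (i - c) / s)
            sj = round(c + (j - c) / s)
            if 0 <= si < n and 0 <= sj < n:
                out[i][j] = img[si][sj]
    return out

# Hypothetical stand-in for a digit: a bright 10x10 square in a 28x28 image
digit = [[1.0 if 9 <= i < 19 and 9 <= j < 19 else 0.0 for j in range(28)]
         for i in range(28)]
scales = scale_factors()
print(scales)  # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5]
augmented = [rescale_centered(digit, s) for s in scales]
```

Inverse mapping (looping over output pixels and looking up their source) guarantees every output pixel is filled exactly once, which forward mapping would not.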

Results

Model        Top-1 Accuracy (%)
w/o DCNv2    90.03
w/ DCNv2     92.90
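Top-1 accuracy counts a test image as correct when the class with the highest score matches its label. The pure-Python sketch below (the helper name `top1_accuracy` is hypothetical) computes it from a list of logits, then restates the two reported accuracies as error rates: 9.97% without DCNv2 versus 7.10% with it.

```python
def top1_accuracy(logits, labels):
    """Percentage of samples whose highest-scoring class equals the label."""
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == y
        for row, y in zip(logits, labels)
    )
    return 100.0 * correct / len(labels)

# Toy check: 3 of the 4 predictions match their labels
logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 0.9], [0.3, 0.2, 0.1]]
labels = [1, 0, 2, 1]
print(top1_accuracy(logits, labels))  # 75.0

# Reported results expressed as error rates (in %)
err_plain = 100 - 90.03  # w/o DCNv2, about 9.97
err_dcn = 100 - 92.90    # w/ DCNv2, about 7.10
print(round(1 - err_dcn / err_plain, 3))  # 0.288, i.e. roughly 29% fewer errors
```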

References

mxnet implementation

To Do Lists