Home

Awesome

Implementing Attention Augmented Convolutional Networks using PyTorch

Update (2019.05.11)

import torch

from attention_augmented_conv import AugmentedConv

# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Random input batch laid out as (batch_size, channels, height, width).
tmp = torch.randn((16, 3, 32, 32)).to(device)
# The relative=True examples pass `shape`; here it equals the expected output
# height/width (32 for stride=1, per the printed shape below).
# NOTE(review): confirm the exact meaning of `shape` against AugmentedConv.
augmented_conv1 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=1, shape=32).to(device)
conv_out1 = augmented_conv1(tmp)
print(conv_out1.shape) # (16, 20, 32, 32)

# List the names of the parameters registered by the layer.
for name, param in augmented_conv1.named_parameters():
    print('parameter name: ', name)
import torch

from attention_augmented_conv import AugmentedConv

# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Random input batch laid out as (batch_size, channels, height, width).
tmp = torch.randn((16, 3, 32, 32)).to(device)
# With stride=2 the spatial size is halved, so `shape` is set to 16 to match
# the expected output height/width printed below.
# NOTE(review): confirm the exact meaning of `shape` against AugmentedConv.
augmented_conv1 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4, relative=True, stride=2, shape=16).to(device)
conv_out1 = augmented_conv1(tmp)
print(conv_out1.shape) # (16, 20, 16, 16)

Update (2019.05.02)

import torch

from attention_augmented_conv import AugmentedConv

# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Random input batch laid out as (batch_size, channels, height, width).
temp_input = torch.randn((16, 3, 32, 32)).to(device)
augmented_conv = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=1, relative=False, stride=1).to(device)
# Feed the tensor created above (the original referenced an undefined `tmp`).
conv_out = augmented_conv(temp_input)
print(conv_out.shape) # (16, 20, 32, 32), (batch_size, out_channels, height, width)
import torch

from attention_augmented_conv import AugmentedConv

# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Random input batch laid out as (batch_size, channels, height, width).
temp_input = torch.randn((16, 3, 32, 32)).to(device)
# stride=2 halves the spatial size (32 -> 16), as shown by the printed shape.
augmented_conv = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=1, relative=False, stride=2).to(device)
# Feed the tensor created above (the original referenced an undefined `tmp`).
conv_out = augmented_conv(temp_input)
print(conv_out.shape) # (16, 20, 16, 16), (batch_size, out_channels, height, width)
# Argument checks excerpted from the layer implementation.
# NOTE(review): presumably taken from AugmentedConv.__init__ (they reference
# self.Nh / self.dk / self.dv and `stride`); confirm against the source file.
# Nh must be non-zero because dk and dv are split evenly across Nh heads below.
assert self.Nh != 0, "integer division or modulo by zero, Nh >= 1"
# dk and dv must each divide evenly by the number of heads.
assert self.dk % self.Nh == 0, "dk should be divided by Nh. (example: out_channels: 20, dk: 40, Nh: 4)"
assert self.dv % self.Nh == 0, "dv should be divided by Nh. (example: out_channels: 20, dv: 4, Nh: 4)"
# Only strides of 1 and 2 are accepted.
assert stride in [1, 2], str(stride) + " Up to 2 strides are allowed."

I have posted two versions of the "Attention-Augmented Conv" layer.

Reference

Paper

Wide-ResNet

Method

image

Input Parameters

Experiments

| Datasets | Model | Accuracy | Epoch | Training Time |
| --- | --- | --- | --- | --- |
| CIFAR-10 | Wide-ResNet 28x10 (work in progress) | | | |
| CIFAR-100 | Wide-ResNet 28x10 (work in progress) | | | |
| CIFAR-100 | Just 3 Conv layers (channels: 64, 128, 192) | 61.6% | 100 | 22m |
| CIFAR-100 | Just 3 Attention-Augmented Conv layers (channels: 64, 128, 192) | 59.82% | 35 | 2h 23m |

Time complexity

Requirements