LambdaNetworks: Modeling Long-Range Interactions Without Attention

Experiments (CIFAR10)

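| Model | k | h | u | m | Params (M) | Acc (%) |
|---|---|---|---|---|---|---|
| ResNet18 baseline (ref) | | | | | 14 | 93.02 |
| LambdaResNet18 | 16 | 4 | 4 | 9 | 8.6 | 92.21 (70 epochs) |
| LambdaResNet18 | 16 | 4 | 4 | 7 | 8.6 | 94.20 (67 epochs) |
| LambdaResNet18 | 16 | 4 | 4 | 5 | 8.6 | 91.58 (70 epochs) |
| LambdaResNet18 | 16 | 4 | 1 | 23 | 8 | 91.36 (69 epochs) |
| ResNet50 baseline (ref) | | | | | 23.5 | 93.62 |
| LambdaResNet50 | 16 | 4 | 4 | 7 | 13 | 93.74 (70 epochs) |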

Usage

import torch

from model import LambdaConv, LambdaResNet50, LambdaResNet152

x = torch.randn([2, 3, 32, 32])
conv = LambdaConv(3, 128)
print(conv(x).size()) # [2, 128, 32, 32]

# reference
# https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325
def get_n_params(model):
    # Sum the element count of every parameter tensor in the model.
    return sum(p.numel() for p in model.parameters())

model = LambdaResNet50()
print(get_n_params(model)) # 14.9M (Ours) / 15M (Paper)

model = LambdaResNet152()
print(get_n_params(model)) # 32.8M (Ours) / 35M (Paper)
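
The calls above print raw integer counts, while the comments quote them in millions. A small formatting helper makes the printed value directly comparable. This is only a sketch: `format_millions` is a hypothetical name that is not part of this repository, and the count used below is an illustrative stand-in for the value `get_n_params(model)` would return.

```python
def format_millions(n_params: int, digits: int = 1) -> str:
    """Render a raw parameter count in millions, e.g. 14900000 -> '14.9M'."""
    return f"{n_params / 1e6:.{digits}f}M"


# Illustrative count only; in practice pass get_n_params(model) instead.
example_count = 14_900_000
print(format_millions(example_count))
```

Passing `digits=2` keeps a second decimal place, which matches how some of the counts in the tables below are reported.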

Parameters

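| Model | k | h | u | m | Params (M), Paper | Params (M), Ours |
|---|---|---|---|---|---|---|
| LambdaResNet50 | 16 | 4 | 1 | 23 | 15.0 | 14.9 |
| LambdaResNet50 | 16 | 4 | 4 | 7 | 16.0 | 16.0 |
| LambdaResNet152 | 16 | 4 | 1 | 23 | 35 | 32.8 |
| LambdaResNet200 | 16 | 4 | 1 | 23 | 42 | 35.29 |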
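
To see how far each count is from the paper's figure, the relative gap can be computed from the table above. The snippet below is a minimal sketch that uses only the numbers listed there; `relative_gap_percent` is a hypothetical helper name introduced for illustration.

```python
# (model, m, params reported in the paper [M], params measured here [M]),
# copied from the table above.
rows = [
    ("LambdaResNet50", 23, 15.0, 14.9),
    ("LambdaResNet50", 7, 16.0, 16.0),
    ("LambdaResNet152", 23, 35.0, 32.8),
    ("LambdaResNet200", 23, 42.0, 35.29),
]


def relative_gap_percent(paper: float, ours: float) -> float:
    """Signed difference of our count relative to the paper's, in percent."""
    return (ours - paper) / paper * 100.0


gaps = {
    (name, m): relative_gap_percent(paper, ours)
    for name, m, paper, ours in rows
}

for (name, m), gap in gaps.items():
    print(f"{name} (m={m}): {gap:+.2f}%")
```

A negative value means this implementation has fewer parameters than the paper reports. On these numbers the gap is under 1% for LambdaResNet50 and grows to roughly 16% for LambdaResNet200.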

Ablation Parameters

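| k | h | u | Params (M), Paper | Params (M), Ours |
|---|---|---|---|---|
| ResNet baseline | | | 25.6 | 25.5 |
| 8 | 2 | 1 | 14.8 | 15.0 |
| 8 | 16 | 1 | 15.6 | 14.9 |
| 2 | 4 | 1 | 14.7 | 14.6 |
| 4 | 4 | 1 | 14.7 | 14.66 |
| 8 | 4 | 1 | 14.8 | 14.66 |
| 16 | 4 | 1 | 15.0 | 14.99 |
| 32 | 4 | 1 | 15.4 | 15.4 |
| 2 | 8 | 1 | 14.7 | 14.5 |
| 4 | 8 | 1 | 14.7 | 14.57 |
| 8 | 8 | 1 | 14.7 | 14.74 |
| 16 | 8 | 1 | 15.1 | 14.1 |
| 32 | 8 | 1 | 15.7 | 15.76 |
| 8 | 8 | 4 | 15.3 | 15.26 |
| 8 | 8 | 8 | 16.0 | 16.0 |
| 16 | 4 | 4 | 16.0 | 16.0 |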