pytorch-randaugment

Unofficial PyTorch reimplementation of RandAugment. Most of the code is borrowed from Fast AutoAugment.

Introduction

Models can be trained with RandAugment directly on the dataset of interest, with no need for a separate proxy task. By tuning only two hyperparameters (N, the number of augmentation transformations to apply, and M, their magnitude), you can achieve performance competitive with AutoAugment.
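Conceptually, RandAugment samples N transformations uniformly at random from a fixed pool and applies each at a strength derived from M. The sketch below illustrates that idea only; it is not this repository's implementation, and the operation pool and magnitude scaling are simplified placeholders.

```python
import random
from PIL import Image, ImageOps, ImageEnhance

def rand_augment(img, n, m, num_magnitude_bins=30):
    """Toy RandAugment: apply n randomly chosen ops at magnitude m."""
    # Each op maps (image, strength in [0, 1]) -> image.
    # Real implementations use a larger pool (shear, translate, posterize, ...).
    ops = [
        lambda im, s: ImageOps.autocontrast(im),                  # magnitude-free op
        lambda im, s: ImageOps.solarize(im, int(256 * (1 - s))),  # invert bright pixels
        lambda im, s: ImageEnhance.Contrast(im).enhance(1 + s),
        lambda im, s: ImageEnhance.Sharpness(im).enhance(1 + s),
        lambda im, s: im.rotate(30 * s),                          # up to 30 degrees
    ]
    strength = m / num_magnitude_bins  # map integer magnitude to [0, 1]
    for op in random.choices(ops, k=n):  # sample with replacement, as in the paper
        img = op(img, strength)
    return img

img = Image.new("RGB", (32, 32), color=(128, 64, 32))
out = rand_augment(img, n=2, m=9)
print(out.size)  # (32, 32)
```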

Install

$ pip install git+https://github.com/ildoonet/pytorch-randaugment

Usage

from torchvision.transforms import transforms
from RandAugment import RandAugment

# Standard CIFAR-10 channel statistics
_CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])

# Prepend RandAugment with hyperparameters N (number of ops) and M (magnitude)
N, M = 2, 9  # example values; tune these for your dataset
transform_train.transforms.insert(0, RandAugment(N, M))

Experiment

We used the same hyperparameters as the paper and observed similar results to those reported.

You can run an experiment with:

$ python RandAugment/train.py -c confs/wresnet28x10_cifar10_b256.yaml --save cifar10_wres28x10.pth

CIFAR-10 Classification

| Model             | Paper's Result | Ours |
|-------------------|----------------|------|
| Wide-ResNet 28x10 | 97.3           | 97.4 |
| Shake26 2x96d     | 98.0           | 98.1 |
| Pyramid272        | 98.5           |      |

CIFAR-100 Classification

| Model             | Paper's Result | Ours |
|-------------------|----------------|------|
| Wide-ResNet 28x10 | 83.3           | 83.3 |

SVHN Classification

| Model             | Paper's Result | Ours |
|-------------------|----------------|------|
| Wide-ResNet 28x10 | 98.9           | 98.8 |

ImageNet Classification

We have experienced some difficulties while reproducing the paper's ImageNet results.

Issue: https://github.com/ildoonet/pytorch-randaugment/issues/9

| Model           | Paper's Result (Top-1 / Top-5) | Ours |
|-----------------|--------------------------------|------|
| ResNet-50       | 77.6 / 92.8                    | TODO |
| EfficientNet-B5 | 83.2 / 96.7                    | TODO |
| EfficientNet-B7 | 84.4 / 97.1                    | TODO |

References