Home

Awesome

Improving Model Generalization by Agreement of Learned Representations from Data Augmentation

Paper

Why it matters?

When data augmentation is applied on an input image, a model is forced to learn invariant features to improve model generalization (Figure 1).

<img src="https://github.com/roatienza/agmax/blob/master/figures/fig1_agmax.png" width="50%" height="50%">

Since data augmentation incurs little overhead, why not generate 2 data augmented images (also known as 2 positive samples) from a given input. Then, force the model to agree on the common invariant features to support the correct label (Figure 2). It turns out that maximizing this agreement further improves model model generalization. We call our method AgMax.

<img src="https://github.com/roatienza/agmax/blob/master/figures/fig2_agmax.png" width="50%" height="50%">

Unlike label smoothing, AgMax consistently improves model accuracy. For example on ImageNet1k for 90 epochs, the ResNet50 performance is as follows:

Data AugmentationBaselineLabel SmoothingAgMax (Ours)
Standard76.476.876.9
CutOut76.276.577.1
MixUp76.576.777.6
CutMix76.376.477.4
AutoAugment (AA)76.276.277.1
CutOut+AA75.775.776.6
MixUp+AA75.976.577.1
CutMix+AA75.575.577.0

The figure below demonstrates consistent improvement across different data augmnentation methods:

<img src="https://github.com/roatienza/agmax/blob/master/figures/ImageNet_ResNet50_90_epochs_Top-1.png" width="50%" height="50%">

Install requirements

pip3 install -r requirements.txt

Train

For example, train ResNet50 with AgMax on 2 GPUs for 90 epochs, SGD with lr=0.1 and multistep learning rate scheduler:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=ResNet50-standard-agmax --train \
--multisteplr --dataset=imagenet --epochs=90 --save

Compare the results without AgMax:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=ResNet50-standard --train \
--multisteplr --dataset=imagenet --epochs=90 --save

Test

Using a pre-trained model:

ResNet101 trained with CutMix, AutoAugment and AgMax:

mkdir checkpoints
cd checkpoints
wget https://github.com/roatienza/agmax/releases/download/agmax-0.1.0/imagenet-agmax-mi-ResNet101-cutmix-auto_augment-81.19-mlp-4096.pth
cd ..
python3 main.py --config=ResNet101-auto_augment-cutmix-agmax --eval \
--dataset=imagenet \
--resume imagenet-agmax-mi-ResNet101-cutmix-auto_augment-81.19-mlp-4096.pth

ResNet50 trained with CutMix, AutoAugment and AgMax:

python3 main.py --config=ResNet50-auto_augment-cutmix-agmax --eval --n-units=2048 \
--dataset=imagenet --resume imagenet-agmax-ResNet50-cutmix-auto_augment-79.12-mlp-2048.pth

Other pre-trained models (Baselines):

Citation

If you find this work useful, please cite:

@inproceedings{atienza2022improving,
  title={Improving Model Generalization by Agreement of Learned Representations from Data Augmentation},
  author={Atienza, Rowel},
  booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
  pages={372--381},
  year={2022}
}