Code for 'RecursiveMix: Mixed Learning with History'

<p align="center"> <img src="figs/RM.svg" width="100%"></p>

RecursiveMix (RM) uses the historical input-prediction-label triplet to enhance the generalization of deep vision models. Paper Link Here.
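
As a rough illustration of the label side of this idea, the sketch below assumes a CutMix-style rule: a resized historical image is pasted into the current one, and the labels are blended in proportion to the pasted area. That rule, the helper names `paste_box` and `mix_labels`, and the reading of `--aug_alpha` as an upper bound on the area ratio are assumptions made for illustration, not taken from this repository. The prediction part of the triplet is omitted.

```python
import random


def paste_box(height, width, lam, rng):
    """Pick a box covering roughly `lam` of the image, keeping its aspect ratio."""
    scale = lam ** 0.5
    box_h = max(1, round(height * scale))
    box_w = max(1, round(width * scale))
    top = rng.randint(0, height - box_h)
    left = rng.randint(0, width - box_w)
    return top, left, box_h, box_w


def mix_labels(current, historical, lam):
    """Blend two label distributions; `lam` is the area share of the historical patch."""
    return [lam * h + (1.0 - lam) * c for c, h in zip(current, historical)]


rng = random.Random(0)
alpha = 0.5          # assumed meaning of --aug_alpha: upper bound on the area ratio
height = width = 32  # CIFAR-sized images

# One-hot labels for three consecutive (hypothetical) training steps.
labels_per_step = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

history = None
for current in labels_per_step:
    if history is None:
        mixed = current  # first step: nothing to paste yet
    else:
        lam = rng.uniform(0.0, alpha)
        _, _, box_h, box_w = paste_box(height, width, lam, rng)
        lam_eff = box_h * box_w / (height * width)  # area actually pasted
        mixed = mix_labels(current, history, lam_eff)
    history = mixed  # the mixed label becomes the next step's history
    print(mixed)
```

Because each mixed label becomes the next step's history, earlier labels persist with shrinking weights, which is one way to read the "recursive" in the name.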

Requirements

Experiment Environment

Usage

1. Train the model

For example, to reproduce the results of RM on CIFAR-10 (97.65% Top-1 accuracy averaged over 3 runs; logs are provided in logs/):

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29500 main.py \
            --name 'your_experiment_log_path' \
            --model_file 'pyramidnet' \
            --model_name 'pyramidnet_200_240' \
            --data 'cifar10' \
            --data_dir '/path/to/CIFAR10' \
            --epoch 300 \
            --batch_size 64 \
            --lr 0.25 \
            --scheduler 'step' \
            --schedule 150 225 \
            --weight_decay 1e-4 \
            --nesterov \
            --num_workers 8 \
            --save_model \
            --aug 'recursive_mix' \
            --aug_alpha 0.5 \
            --aug_omega 0.1
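
To make the `--scheduler 'step'` and `--schedule 150 225` options concrete, here is a minimal sketch of a step schedule starting from `--lr 0.25`. The decay factor `gamma=0.1` and the function name `step_lr` are assumptions for illustration; the value actually used is defined in `main.py` and may differ.

```python
def step_lr(epoch, base_lr=0.25, milestones=(150, 225), gamma=0.1):
    """Learning rate at `epoch` for a step schedule (gamma is an assumed value)."""
    passed = sum(1 for m in milestones if epoch >= m)  # milestones already reached
    return base_lr * gamma ** passed


for epoch in (0, 149, 150, 224, 225, 299):
    print(epoch, step_lr(epoch))
```

Under these assumptions the rate stays constant until epoch 150 and is then reduced once at each milestone.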

To reproduce the results of RM on ImageNet (79.20% Top-1 accuracy):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port 29500 main.py \
            --name 'your_experiment_log_path' \
            --model_file 'resnet' \
            --model_name 'resnet50' \
            --data 'imagenet' \
            --epoch 300 \
            --batch_size 512 \
            --lr 0.2 \
            --warmup 5 \
            --weight_decay 1e-4 \
            --aug_plus \
            --num_workers 32 \
            --save_model \
            --aug 'recursive_mix' \
            --aug_alpha 0.5 \
            --aug_omega 0.5

2. Test the model

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29500 main.py \
            --name 'your_experiment_log_path' \
            --batch_size 64 \
            --model_file 'pyramidnet' \
            --model_name 'pyramidnet_200_240' \
            --data 'cifar10' \
            --data_dir '/path/to/CIFAR10' \
            --num_workers 8 \
            --evaluate \
            --resume 'best'

Model Zoo

Image Classification

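| Backbone | Size | Params (M) | Acc@1 | Log | Download |
|---|---|---|---|---|---|
| ResNet-50 | 224 | 25.56 | 76.32 | log | [Google] [GitHub] |
| + Mixup | 224 | 25.56 | 77.42 | log | [Google] [GitHub] |
| + CutMix | 224 | 25.56 | 78.60 | log | [Google] [GitHub] |
| + RecursiveMix | 224 | 25.56 | 79.20 | log | [Google] [GitHub] |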

Object Detection

ATSS

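| Backbone | Lr schd | Mem (GB) | Inf time (fps) | box AP | Log | Download |
|---|---|---|---|---|---|---|
| ResNet-50 | 1x | 3.7 | 19.7 | 39.4 | log | [Google] [GitHub] |
| + CutMix | 1x | 3.7 | 19.7 | 40.1 | log | [Google] [GitHub] |
| + RecursiveMix | 1x | 3.7 | 19.7 | 41.5 | log | [Google] [GitHub] |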

Semantic Segmentation

UPerNet

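| Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | Log | Download |
|---|---|---|---|---|---|---|---|
| ResNet-50 | 512x512 | 80000 | 8.1 | 23.40 | 40.40 | log | [Google] [GitHub] |
| + CutMix | 512x512 | 80000 | 8.1 | 23.40 | 41.24 | log | [Google] [GitHub] |
| + RecursiveMix | 512x512 | 80000 | 8.1 | 23.40 | 42.30 | log | [Google] [GitHub] |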