

Code for 'RecursiveMix: Mixed Learning with History'

<p align="center"> <img src="figs/RM.svg" width="100%"></p>

RecursiveMix (RM) uses historical input-prediction-label triplets to enhance the generalization of deep vision models. Paper Link Here.
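To make the "history" idea concrete, here is a minimal sketch of the label side only. It is not taken from this repository. It assumes that the previous iteration's mixed label is blended into the current one-hot label with a ratio `lam` drawn uniformly from `[0, alpha]`, where `alpha` corresponds to `--aug_alpha`. The function names are hypothetical, and the image mixing and the prediction-based term are left out.

```python
import random

def one_hot(index, num_classes):
    """Return a one-hot label as a plain list of floats."""
    return [1.0 if i == index else 0.0 for i in range(num_classes)]

def mix_labels(hist_label, cur_label, lam):
    """Blend the historical soft label into the current label with weight lam."""
    return [lam * h + (1.0 - lam) * c for h, c in zip(hist_label, cur_label)]

def recursive_label_stream(class_ids, num_classes, alpha=0.5, seed=0):
    """Yield one soft label per iteration; each folds in the previous mixed label."""
    rng = random.Random(seed)
    history = None
    for cid in class_ids:
        current = one_hot(cid, num_classes)
        if history is None:
            mixed = current  # first iteration: no history available yet
        else:
            lam = rng.uniform(0.0, alpha)  # assumed sampling rule, not confirmed by the repo
            mixed = mix_labels(history, current, lam)
        history = mixed  # the mixed label becomes the history for the next step
        yield mixed

labels = list(recursive_label_stream([0, 1, 2], num_classes=3))
```

Because each mixed label is reused as the next history, earlier classes keep a shrinking share of the label mass instead of disappearing after one step.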


Experiment Environment


1. Train the model

For example, to reproduce the results of RM on CIFAR-10 (97.65% Top-1 accuracy averaged over 3 runs; logs are provided in logs/):

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29500 main.py \
            --name 'your_experiment_log_path' \
            --model_file 'pyramidnet' \
            --model_name 'pyramidnet_200_240' \
            --data 'cifar10' \
            --data_dir '/path/to/CIFAR10' \
            --epoch 300 \
            --batch_size 64 \
            --lr 0.25 \
            --scheduler 'step' \
            --schedule 150 225 \
            --weight_decay 1e-4 \
            --nesterov \
            --num_workers 8 \
            --save_model \
            --aug 'recursive_mix' \
            --aug_alpha 0.5 \
            --aug_omega 0.1
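The `--scheduler 'step'` and `--schedule 150 225` flags suggest a multi-step learning-rate decay. The sketch below shows what such a schedule would look like for `--lr 0.25`, assuming the rate is multiplied by `gamma = 0.1` at each milestone. The decay factor is an assumption, since the command does not state it, and `step_lr` is a hypothetical helper rather than a function from this repository.

```python
def step_lr(epoch, base_lr=0.25, milestones=(150, 225), gamma=0.1):
    """Learning rate at a given epoch under a multi-step decay schedule."""
    # Count how many milestones have been reached by this epoch.
    drops = sum(1 for m in milestones if epoch >= m)
    # gamma=0.1 is an assumed decay factor, not read from the repository.
    return base_lr * gamma ** drops

for epoch in (0, 149, 150, 225, 299):
    print(epoch, step_lr(epoch))
```

Under this assumption the 300 epochs split into three phases: 0 to 149 at the base rate, 150 to 224 after one decay, and 225 to 299 after two.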

To reproduce the results of RM on ImageNet (79.20% Top-1 accuracy):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port 29500 main.py \
            --name 'your_experiment_log_path' \
            --model_file 'resnet' \
            --model_name 'resnet50' \
            --data 'imagenet' \
            --epoch 300 \
            --batch_size 512 \
            --lr 0.2 \
            --warmup 5 \
            --weight_decay 1e-4 \
            --aug_plus \
            --num_workers 32 \
            --save_model \
            --aug 'recursive_mix' \
            --aug_alpha 0.5 \
            --aug_omega 0.5

2. Test the model

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29500 main.py \
            --name 'your_experiment_log_path' \
            --batch_size 64 \
            --model_file 'pyramidnet' \
            --model_name 'pyramidnet_200_240' \
            --data 'cifar10' \
            --data_dir '/path/to/CIFAR10' \
            --num_workers 8 \
            --evaluate \
            --resume 'best'

Model Zoo

Image Classification

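| Backbone | Size | Params (M) | Acc@1 | Log | Download |
| --- | --- | --- | --- | --- | --- |
| ResNet-50 | 224 | 25.56 | 76.32 | log | [Google] [GitHub] |
| + Mixup | 224 | 25.56 | 77.42 | log | [Google] [GitHub] |
| + CutMix | 224 | 25.56 | 78.60 | log | [Google] [GitHub] |
| + RecursiveMix | 224 | 25.56 | 79.20 | log | [Google] [GitHub] |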

Object Detection


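| Backbone | Lr schd | Mem (GB) | Inf time (fps) | box AP | Log | Download |
| --- | --- | --- | --- | --- | --- | --- |
| ResNet-50 | 1x | 3.7 | 19.7 | 39.4 | log | [Google] [GitHub] |
| + CutMix | 1x | 3.7 | 19.7 | 40.1 | log | [Google] [GitHub] |
| + RecursiveMix | 1x | 3.7 | 19.7 | 41.5 | log | [Google] [GitHub] |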

Semantic Segmentation


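| Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | Log | Download |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ResNet-50 | 512x512 | 80000 | 8.1 | 23.40 | 40.40 | log | [Google] [GitHub] |
| + CutMix | 512x512 | 80000 | 8.1 | 23.40 | 41.24 | log | [Google] [GitHub] |
| + RecursiveMix | 512x512 | 80000 | 8.1 | 23.40 | 42.30 | log | [Google] [GitHub] |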
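
For a quick comparison across the three tasks, the snippet below computes each method's gain over the plain ResNet-50 baseline. The scores are copied from the tables in this Model Zoo; the dictionary layout and the `gains` helper are only illustrative.

```python
# Scores copied from the Model Zoo tables (ResNet-50 backbone).
results = {
    "Image Classification (Acc@1)": {"baseline": 76.32, "CutMix": 78.60, "RecursiveMix": 79.20},
    "Object Detection (box AP)": {"baseline": 39.4, "CutMix": 40.1, "RecursiveMix": 41.5},
    "Semantic Segmentation (mIoU)": {"baseline": 40.40, "CutMix": 41.24, "RecursiveMix": 42.30},
}

def gains(rows):
    """Return each method's absolute improvement over the baseline score."""
    base = rows["baseline"]
    return {name: round(score - base, 2) for name, score in rows.items() if name != "baseline"}

for task, rows in results.items():
    print(task, gains(rows))
```

With these numbers RecursiveMix improves on the baseline by 2.88 Acc@1, 2.1 box AP, and 1.90 mIoU, ahead of CutMix in each case.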