RDA4RobustSSL

This repo is the official PyTorch implementation of our paper:

RDA: Reciprocal Distribution Alignment for Robust Semi-supervised Learning
Authors: Yue Duan, Lei Qi, Lei Wang, Luping Zhou and Yinghuan Shi

Introduction

Reciprocal Distribution Alignment (RDA) is a semi-supervised learning (SSL) framework that works under both matched (conventional) and mismatched class distributions. Distribution mismatch is an often-overlooked but more general SSL scenario in which the labeled and unlabeled data do not follow the same class distribution. This can prevent the model from exploiting the labeled data reliably and drastically degrades the performance of SSL methods, which cannot be rescued by traditional distribution alignment. RDA achieves promising performance in SSL under a variety of mismatched-distribution scenarios, as well as in the conventional matched SSL setting.

<div align=center> <img width="750px" src="/figures/framework.jpg"> </div>
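
To make the alignment idea concrete, below is a minimal, generic sketch of the distribution alignment operation that this family of SSL methods builds on: per-sample predictions are rescaled by the ratio of a reference class distribution to the model's average predicted distribution, then renormalized. This is only an illustration under assumed names and tensors (`align_distribution`, the toy inputs); it is not the exact reciprocal formulation of RDA, for which see the paper and train_rda.py.

```python
import torch
import torch.nn.functional as F

def align_distribution(probs, target_dist, model_dist, eps=1e-6):
    # probs:       (B, C) softmax outputs for a batch of unlabeled samples
    # target_dist: (C,)   reference class distribution to align to
    # model_dist:  (C,)   running average of the model's predicted distribution
    aligned = probs * (target_dist + eps) / (model_dist + eps)
    return aligned / aligned.sum(dim=1, keepdim=True)  # renormalize each row

# Toy usage with made-up tensors.
probs = F.softmax(torch.randn(4, 10), dim=1)   # fake predictions: 4 samples, 10 classes
target = torch.full((10,), 0.1)                # e.g. a uniform reference distribution
running_avg = probs.mean(dim=0).detach()       # stand-in for a moving average of predictions
aligned = align_distribution(probs, target, running_avg)
print(aligned.sum(dim=1))                      # each row still sums to 1
```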

Requirements

How to Train

Important Args

Training with Single GPU

To better reproduce our experimental results, we recommend following our experimental environment and training with a single GPU.

python train_rda.py --world-size 1 --rank 0 --gpu [0/1/...] @@@other args@@@

Training with Multiple GPUs

python train_rda.py --world-size 1 --rank 0 @@@other args@@@
python train_rda.py --world-size 1 --rank 0 --multiprocessing-distributed @@@other args@@@

Examples of Running

By default, the model and dist&index.txt are saved under --save_dir/--save_name. The file dist&index.txt records the detailed settings of the mismatched distributions. The code assumes 1 epoch of training, with a total of 2**20 training iterations. For CIFAR-100, you need to set --widen_factor 8 to use WRN-28-8, whereas WRN-28-2 is used for CIFAR-10. Note that you need to set --net resnet18 for STL-10 and mini-ImageNet. Additionally, WRN-28-2 is used for all experiments under DARP's protocol.

Conventional Setting

Matched and balanced $C_x$, $C_u$ for Tab. 1 in Sec. 5.1

python train_rda.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 20  --gpu 0

Mismatched Distributions

Imbalanced $C_x$ and balanced $C_u$ for Tab. 2 in Sec. 5.2

python train_rda.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch rda --n0 10 --gpu 0
python train_rda.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar100 --dataset cifar100 --num_classes 100 --num_labels 400 --mismatch rda --n0 40 --gpu 0 --widen_factor 8
python train_rda.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name miniimage --dataset miniimage --num_classes 100 --num_labels 1000 --mismatch rda --n0 40 --gpu 0 --net resnet18 

Imbalanced and mismatched $C_x$, $C_u$ for Tab. 3 in Sec. 5.2

python train_rda.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch rda --n0 10 --gamma 5 --gpu 0

Balanced $C_x$ and imbalanced $C_u$ for Tab. 5 in Sec. 5.2

python train_rda.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch rda --gamma 200 --gpu 0

DARP's protocol for Tab. 5 in Sec. 5.2

python train_rda.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --mismatch darp --n0 100 --gamma 1 --gpu 0
python train_rda.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --mismatch darp_reversed --n0 100 --gamma 100 --gpu 0
python train_rda.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name stl10 --dataset stl10 --num_classes 10 --mismatch darp --n0 10 --gpu 0 --fold -1 

Resume Training and Evaluation

To resume training, use --resume --load_path @your_path_to_checkpoint. Each time training starts, the evaluation results of the current model are displayed first, so to evaluate a trained model you can simply resume training from its checkpoint.
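
For example, to resume the CIFAR-10 run from Tab. 2 above, one possible invocation is shown below (the flags mirror that run, with --overwrite omitted since training is being resumed; replace @your_path_to_checkpoint with the actual checkpoint file written under --save_dir/--save_name):

python train_rda.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch rda --n0 10 --gpu 0 --resume --load_path @your_path_to_checkpoint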

Results (e.g. seed=1)

| Dataset | Labels | $N_0$ / $\gamma_l$ | $\gamma$ / $\gamma_u$ | Acc (%) | Note |
|---|---|---|---|---|---|
| CIFAR-10 | 20 | - | - | 93.40 | Conventional setting |
| | 40 | - | - | 94.13 | |
| | 80 | - | - | 94.24 | |
| | 100 | - | - | 94.66 | |
| | 40 | 10 | - | 93.06 | Imbalanced $C_x$ and balanced $C_u$ |
| | 40 | 20 | - | 81.51 | |
| | 100 | 40 | - | 94.42 | |
| | 100 | 80 | - | 78.99 | |
| | 40 | 10 | 2 | 81.60 | Mismatched imbalanced $C_x$ and $C_u$ |
| | 40 | 10 | 5 | 80.68 | |
| | 100 | 40 | 5 | 79.54 | |
| | 40 | - | 100 | 47.68 | Balanced $C_x$ and imbalanced $C_u$ |
| | 40 | - | 200 | 45.57 | |
| | DARP | 100 | 1 | 93.11 | DARP's protocol |
| | DARP | 100 | 50 | 79.84 | |
| | DARP | 100 | 150 | 74.71 | |
| | DARP (reversed) | 100 | 100 | 78.53 | |
| CIFAR-100 | 400 | 40 | - | 33.54 | Imbalanced $C_x$ and balanced $C_u$ |
| | 1000 | 80 | - | 42.87 | |
| STL-10 | 1000 | - | - | 82.53 | Conventional setting |
| | DARP | 10 | - | 87.21 | DARP's protocol |
| | DARP | 20 | - | 83.71 | |
| mini-ImageNet | 1000 | - | - | 47.73 | Conventional setting |
| | 1000 | 40 | - | 43.59 | Imbalanced $C_x$ and balanced $C_u$ |
| | 1000 | 80 | - | 38.16 | |
| | 1000 | 40 | 10 | 25.91 | Mismatched imbalanced $C_x$ and $C_u$ |

Citation

Please cite our paper if you find RDA useful:

@inproceedings{duan2022rda,
  title={RDA: Reciprocal Distribution Alignment for Robust Semi-supervised Learning},
  author={Duan, Yue and Qi, Lei and Wang, Lei and Zhou, Luping and Shi, Yinghuan},
  booktitle={European Conference on Computer Vision},
  pages={533--549},
  year={2022},
  organization={Springer}
}

or

@article{duan2022rda,
  title={RDA: Reciprocal Distribution Alignment for Robust Semi-supervised Learning},
  author={Duan, Yue and Qi, Lei and Wang, Lei and Zhou, Luping and Shi, Yinghuan},
  journal={arXiv preprint arXiv:2208.04619},
  year={2022}
}

Acknowledgement

Our code is based on the open-source code of LeeDoYup/FixMatch-pytorch.