PRG4SSL-MNAR

This repo is the official PyTorch implementation of our paper:

Towards Semi-supervised Learning with Non-random Missing Labels
Authors: Yue Duan, Zhen Zhao, Lei Qi, Lei Wang, Luping Zhou and Yinghuan Shi

Introduction

Semi-supervised learning (SSL) tackles the label-missing problem by enabling the effective use of unlabeled data. While existing SSL methods focus on the traditional setting, a practical and challenging scenario called label Missing Not At Random (MNAR) is usually ignored. In MNAR, the labeled and unlabeled data fall into different class distributions, resulting in biased label imputation that deteriorates the performance of SSL models. In this work, we devise class transition tracking based Pseudo-Rectifying Guidance (PRG) for MNAR. We explore the class-level guidance information obtained by a Markov random walk, modeled on a dynamically created graph built over the class tracking matrix. PRG unifies the historical information of each class transition caused by the pseudo-rectifying procedure to activate the model's enthusiasm for neglected classes, so that the quality of pseudo-labels on both popular and rare classes under MNAR is improved.

<div align=center> <img width="750px" src="/figures/framework.Jpeg"> </div>
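To make the idea concrete, below is a minimal NumPy sketch of the pseudo-rectifying step described above: a class tracking matrix counts prediction flips between classes, a one-step Markov random walk over the row-normalized matrix yields class-level guidance, and the pseudo-label mass is redistributed accordingly. The function name and the exact update rule are illustrative assumptions, not the code of train_prg.py.

```python
import numpy as np

def rectify_pseudo_label(p, C, eps=1e-8):
    """Illustrative pseudo-rectifying step (a sketch, not the repo's exact rule).

    p: (K,) softmax pseudo-label of one unlabeled sample.
    C: (K, K) class tracking matrix; C[i, j] counts how often the
       model's prediction for a sample flipped from class i to class j.
    """
    # Row-normalize the tracking matrix into Markov transition
    # probabilities: one step of a random walk on the class graph.
    T = C / (C.sum(axis=1, keepdims=True) + eps)
    # Redistribute pseudo-label mass along historically likely class
    # transitions, lifting classes the model tends to neglect.
    p_rect = p @ T
    return p_rect / (p_rect.sum() + eps)

# Hypothetical usage: 10 classes, toy tracking matrix and pseudo-label.
K = 10
C = np.ones((K, K)) + np.eye(K)
p = np.random.dirichlet(np.ones(K))
print(rectify_pseudo_label(p, C).sum())  # ~1.0, still a distribution
```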

Requirements

How to Train

Important Args
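The detailed argument descriptions are not reproduced here; the summary below is inferred from the example commands in the rest of this README, so double-check it against the argument definitions in train_prg.py:

- --num_labels: total amount of labeled data.
- --mismatch [cadr/prg/darp/darp_reversed]: which MNAR protocol to use when building the labeled/unlabeled split; omit it for the conventional matched, balanced setting.
- --n0 / --gamma: protocol-specific knobs controlling how imbalanced the split is (the realized distribution is written to dist&index.txt).
- --net / --widen_factor: backbone choice; WRN-28-2 by default, --widen_factor 8 for WRN-28-8 on CIFAR-100, --net resnet18 for mini-ImageNet.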

Training with Single GPU

We recommend using a single GPU for training to better reproduce our results. Multi-GPU training is feasible, but all of our reported results were obtained with single-GPU training.

python train_prg.py --world-size 1 --rank 0 --gpu [0/1/...] @@@other args@@@

Training with Multi-GPUs

For DataParallel training over all visible GPUs (typically what this launch does when --gpu is omitted):

python train_prg.py --world-size 1 --rank 0 @@@other args@@@

For DistributedDataParallel training:

python train_prg.py --world-size 1 --rank 0 --multiprocessing-distributed @@@other args@@@

Examples of Running

By default, the model and dist&index.txt will be saved in [--save_dir]/[--save_name]. The file dist&index.txt records the detailed MNAR settings of the run. The code trains for a single epoch consisting of 2**20 (1,048,576) iterations. For CIFAR-100, you need to set --widen_factor 8 to use WRN-28-8, whereas the default WRN-28-2 is used for CIFAR-10. Note that you need to set --net resnet18 for mini-ImageNet.
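For example, a CIFAR-10 run saved under --save_name cifar10 might produce a layout like the one below; the checkpoint file name is an assumption for illustration, as only dist&index.txt is named by this README:

```
[--save_dir]/
└── cifar10/                 # --save_name
    ├── model_best.pth       # assumed checkpoint name
    └── dist&index.txt       # realized MNAR split settings
```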

MNAR Settings

CADR's protocol in Tab. 1

python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch cadr --gamma 20 --gpu 0
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar100 --dataset cifar100 --num_classes 100 --num_labels 400 --mismatch cadr --gamma 50 --gpu 0 --widen_factor 8
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name miniimage --dataset miniimage --num_classes 100 --num_labels 1000 --mismatch cadr --gamma 50 --gpu 0 --net resnet18 

Our protocol in Tab. 2

python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch prg --n0 10 --gpu 0
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar100 --dataset cifar100 --num_classes 100 --num_labels 400 --mismatch prg --n0 40 --gpu 0 --widen_factor 8
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name miniimage --dataset miniimage --num_classes 100 --num_labels 1000 --mismatch prg --n0 40 --gpu 0 --net resnet18 

Our protocol in Fig. 6(a)

python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch prg --n0 10 --gamma 5 --gpu 0

Our protocol in Tab. 10

python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch prg --gamma 20 --gpu 0

DARP's protocol in Fig. 6(a)

python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --mismatch darp --n0 100 --gamma 1 --gpu 0
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --mismatch darp_reversed --n0 100 --gamma 100 --gpu 0

Conventional Setting

Matched and balanced distribution in Tab. 11

python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40  --gpu 0

Resume Training and Evaluation

To resume training, use --resume --load_path @your_weight_path.

For evaluation, run

python eval_prg.py --load_path @your_weight_path --dataset [cifar10/cifar100/miniimage] --data_dir @your_dataset_path --num_classes @number_of_classes

By default, the WideResNet-28-2 backbone is used for CIFAR-10. Use --widen_factor 8 (i.e., WideResNet-28-8) for CIFAR-100 and --net resnet18 for mini-ImageNet.

Results (seed = 1)

| Dataset | Labels | N0 | gamma | Acc | Setting | Method | Weight |
| :------ | :----: | :-: | :---: | :-: | :------ | :----- | :----: |
| CIFAR-10 | 40 | - | - | 94.05 | Conventional settings | PRG | here |
| | 250 | - | - | 94.36 | | | here |
| | 4000 | - | - | 95.48 | | | here |
| | 40 | - | - | 93.79 | Conventional settings | PRG^Last | here |
| | 250 | - | - | 94.76 | | | here |
| | 4000 | - | - | 95.75 | | | here |
| | - | - | 20 | 94.04 | CADR's protocol | PRG | here |
| | - | - | 50 | 93.78 | | | here |
| | - | - | 100 | 94.51 | | | here |
| | - | - | 20 | 94.74 | CADR's protocol | PRG^Last | here |
| | - | - | 50 | 94.74 | | | here |
| | - | - | 100 | 94.75 | | | here |
| | 40 | 10 | - | 93.81 | Our protocol | PRG | here |
| | 40 | 20 | - | 93.39 | | | here |
| | 40 | 10 | 2 | 90.25 | | | here |
| | 40 | 10 | 5 | 82.84 | | | here |
| | 100 | 40 | 5 | 79.58 | | | here |
| | 100 | 40 | 10 | 78.61 | | | here |
| | 250 | 100 | - | 93.76 | | | here |
| | 250 | 200 | - | 91.65 | | | here |
| | 40 | 10 | - | 91.59 | Our protocol | PRG^Last | here |
| | 40 | 20 | - | 80.31 | | | here |
| | 250 | 100 | - | 91.36 | | | here |
| | 250 | 200 | - | 62.16 | | | here |
| | DARP | 100 | 1 | 94.41 | DARP's protocol | PRG | here |
| | DARP | 100 | 50 | 78.28 | | | here |
| | DARP | 100 | 150 | 75.21 | | | here |
| | DARP (reversed) | 100 | 100 | 80.86 | | | here |
| CIFAR-100 | 400 | - | - | 48.70 | Conventional settings | PRG | here |
| | 2500 | - | - | 69.81 | | | here |
| | 10000 | - | - | 76.91 | | | here |
| | 400 | - | - | 48.66 | Conventional settings | PRG^Last | here |
| | 2500 | - | - | 70.03 | | | here |
| | 10000 | - | - | 76.93 | | | here |
| | - | - | 50 | 58.57 | CADR's protocol | PRG | here |
| | - | - | 100 | 62.28 | | | here |
| | - | - | 200 | 59.33 | | | here |
| | - | - | 50 | 60.32 | CADR's protocol | PRG^Last | here |
| | - | - | 100 | 62.13 | | | here |
| | - | - | 200 | 58.70 | | | here |
| | 2500 | 100 | - | 57.56 | Our protocol | PRG | here |
| | 2500 | 200 | - | 51.21 | | | here |
| | 2500 | 100 | - | 59.40 | Our protocol | PRG^Last | here |
| | 2500 | 200 | - | 42.09 | | | here |
| mini-ImageNet | 1000 | - | - | 45.74 | Conventional settings | PRG | here |
| | 1000 | - | - | 48.63 | Conventional settings | PRG^Last | here |
| | - | - | 50 | 43.74 | CADR's protocol | PRG | here |
| | - | - | 100 | 43.74 | | | here |
| | - | - | 50 | 42.22 | CADR's protocol | PRG^Last | here |
| | - | - | 100 | 43.74 | | | here |
| | 1000 | 40 | - | 40.75 | Our protocol | PRG | here |
| | 1000 | 80 | - | 35.86 | | | here |
| | 1000 | 40 | - | 39.79 | Our protocol | PRG^Last | here |
| | 1000 | 80 | - | 32.64 | | | here |

Citation

Please cite our paper if you find PRG useful:

@inproceedings{duan2023towards,
  title={Towards Semi-supervised Learning with Non-random Missing Labels},
  author={Duan, Yue and Zhao, Zhen and Qi, Lei and Zhou, Luping and Wang, Lei and Shi, Yinghuan},
  booktitle={IEEE/CVF International Conference on Computer Vision},
  year={2023}
}

or

@article{duan2023towards,
  title={Towards Semi-supervised Learning with Non-random Missing Labels},
  author={Duan, Yue and Zhao, Zhen and Qi, Lei and Zhou, Luping and Wang, Lei and Shi, Yinghuan},
  journal={arXiv preprint arXiv:2308.08872},
  year={2023}
}