# PRG4SSL-MNAR
This repo is the official PyTorch implementation of our paper:

**Towards Semi-supervised Learning with Non-random Missing Labels**
Authors: Yue Duan, Zhen Zhao, Lei Qi, Lei Wang, Luping Zhou and Yinghuan Shi
- Quick links:
  - Code Download
  - [PDF/Abs-arXiv | PDF/Abs-Published | Poster | Article Explanation-Zhihu]
- Latest news:
  - We wrote a detailed explanation (in Chinese) of this work on Zhihu.
  - Our paper has been accepted by the IEEE/CVF International Conference on Computer Vision (ICCV) 2023. Thanks to everyone for the support!
- Related works:
  - [MOST RELEVANT] Interested in robust SSL in the MNAR setting with mismatched distributions? Check out our ECCV'22 paper RDA [PDF-arXiv | Code].
  - [LATEST] Interested in cross-modal retrieval with noisy correspondence? Check out our ACMMM'24 paper PC2 [PDF-arXiv | Code].
  - [SSL] Interested in SSL for fine-grained visual classification (SS-FGVC)? Check out our AAAI'24 paper SoC [PDF-arXiv | Code].
  - [SSL] Interested in conventional SSL or more applications of complementary labels in SSL? Check out our TNNLS paper MutexMatch [PDF-arXiv | Code].
## Introduction
Semi-supervised learning (SSL) tackles the label missing problem by enabling the effective use of unlabeled data. While existing SSL methods focus on the traditional setting, a practical and challenging scenario called label Missing Not At Random (MNAR) is usually ignored. In MNAR, the labeled and unlabeled data fall into different class distributions, resulting in biased label imputation, which deteriorates the performance of SSL models. In this work, class transition tracking based Pseudo-Rectifying Guidance (PRG) is devised for MNAR. We explore the class-level guidance information obtained by the Markov random walk, which is modeled on a dynamically created graph built over the class tracking matrix. PRG unifies the history information of each class transition caused by the pseudo-rectifying procedure to activate the model's enthusiasm for neglected classes, so that the quality of pseudo-labels on both popular classes and rare classes in MNAR can be improved.
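To make the mechanism concrete, below is a minimal sketch of class transition tracking followed by a Markov random walk. It is only an illustration of the idea, not the repository's implementation; the matrix values and the helper name `class_guidance` are invented for this example.

```python
import numpy as np

def class_guidance(class_tracking: np.ndarray, steps: int = 1) -> np.ndarray:
    """Row-normalize a class tracking matrix C, where C[i, j] counts how often
    pseudo-rectifying changed a pseudo-label from class i to class j, then
    take a `steps`-step Markov random walk on the resulting transition graph."""
    C = class_tracking + 1e-8                 # keep rows of unseen classes valid
    H = C / C.sum(axis=1, keepdims=True)      # one-step transition probabilities
    return np.linalg.matrix_power(H, steps)   # H^steps: multi-step random walk

# Toy example with 3 classes; class 2 is rarely transitioned into,
# so the walk exposes how neglected it is by the pseudo-labeling.
C = np.array([[50., 10., 1.],
              [8., 40., 2.],
              [1., 1., 5.]])
print(class_guidance(C))  # row i: class-level guidance for samples predicted as i
```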
<div align=center> <img width="750px" src="/figures/framework.Jpeg"> </div>

## Requirements
- numpy==1.21.6
- pandas==1.3.2
- Pillow==10.0.0
- scikit_learn==1.3.0
- torch==1.8.0
- torchvision==0.9.0
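The following snippet is a small convenience check (not part of the repo) that the pinned versions above are what your environment actually provides:

```python
# Sanity-check installed versions against the pins above (assumes the
# packages are already installed, e.g. via pip).
import numpy, pandas, PIL, sklearn, torch, torchvision

pins = [(numpy, "1.21.6"), (pandas, "1.3.2"), (PIL, "10.0.0"),
        (sklearn, "1.3.0"), (torch, "1.8.0"), (torchvision, "0.9.0")]
for mod, pinned in pins:
    print(f"{mod.__name__}: installed {mod.__version__}, pinned {pinned}")
```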
## How to Train
### Important Args
- `--last`: set this flag to use the model of $\textrm{PRG}^{\textrm{Last}}$.
- `--alpha`: class invariance coefficient. By default, `--alpha 1` is set. When `--last` is set, please use `--alpha 3`.
- `--nb`: number of tracked batches.
- `--mismatch [none/prg/cadr/darp/darp_reversed]`: select the MNAR protocol. `none` means the conventional balanced setting. See Sec. 4 in our paper for the details of the MNAR protocols.
- `--n0`: when `--mismatch prg`, this arg means the imbalance ratio $N_0$ for labeled data; when `--mismatch [darp/darp_reversed]`, it means the imbalance ratio $\gamma_l$ for labeled data.
- `--gamma`: when `--mismatch cadr`, this arg means the imbalance ratio $\gamma$ for labeled data; when `--mismatch prg`, it means the imbalance ratio $\gamma$ for unlabeled data; when `--mismatch [darp/darp_reversed]`, it means the imbalance ratio $\gamma_u$ for unlabeled data (see the sketch after this list).
- `--num_labels`: amount of labeled data used in the conventional balanced setting.
- `--net`: by default, Wide ResNet (WRN-28-2) is used for experiments. To use another backbone for training, set `--net [resnet18/preresnet/cnn13]` for ResNet-18, PreAct ResNet, or CNN-13, respectively.
- `--dataset [cifar10/cifar100/miniimage]` and `--data_dir`: your dataset name and path.
- `--num_eval_iter`: the model is evaluated every `--num_eval_iter` iterations. Note that although the evaluation reports the accuracy of pseudo-labels on unlabeled data, this is only to show the training progress; no label information of unlabeled data is used in training.

For example, to train $\textrm{PRG}^{\textrm{Last}}$, add `--last --alpha 3` to any of the training commands below.
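To build intuition for `--n0` and `--gamma`, here is a sketch of one common long-tailed parameterization used in imbalanced SSL benchmarks, where class $k$ receives $N_0 \cdot \gamma^{-k/(K-1)}$ samples. The helper name and the exact formula are illustrative assumptions; the precise construction of each MNAR protocol is defined in Sec. 4 of the paper.

```python
import numpy as np

def long_tailed_counts(n0: int, gamma: float, num_classes: int) -> np.ndarray:
    """Illustrative long-tailed class sizes: the head class gets n0 samples
    and sizes decay exponentially so the tail class gets n0 / gamma."""
    ks = np.arange(num_classes)
    return np.round(n0 * gamma ** (-ks / (num_classes - 1))).astype(int)

# Loosely mimicking --mismatch darp --n0 100 --gamma 1 on CIFAR-10
# (the head size of 100 is an arbitrary choice for this illustration):
print(long_tailed_counts(100, gamma=100, num_classes=10))  # labeled: ratio 100
print(long_tailed_counts(100, gamma=1, num_classes=10))    # gamma=1 -> uniform
```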
### Training with a Single GPU
We recommend using a single GPU to better reproduce our results. Multi-GPU training is feasible, but all of our reported results were obtained with single-GPU training.
```
python train_prg.py --world-size 1 --rank 0 --gpu [0/1/...] @@@other args@@@
```
### Training with Multiple GPUs

- Using `DataParallel`:

```
python train_prg.py --world-size 1 --rank 0 @@@other args@@@
```

- Using `DistributedDataParallel` with a single node:

```
python train_prg.py --world-size 1 --rank 0 --multiprocessing-distributed @@@other args@@@
```
## Examples of Running
By default, the model and `dist&index.txt` will be saved in `--save_dir/--save_name`. The file `dist&index.txt` records the detailed settings of MNAR. The code trains for one epoch, but the length of that epoch is set to 2**20 (1,048,576) iterations. For CIFAR-100, you need to set `--widen_factor 8` for WRN-28-8, whereas WRN-28-2 is used for CIFAR-10. Note that you need to set `--net resnet18` for mini-ImageNet.
### MNAR Settings
#### CADR's protocol in Tab. 1

- CIFAR-10 with $\gamma=20$

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch cadr --gamma 20 --gpu 0
```

- CIFAR-100 with $\gamma=50$

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar100 --dataset cifar100 --num_classes 100 --num_labels 400 --mismatch cadr --gamma 50 --gpu 0 --widen_factor 8
```

- mini-ImageNet with $\gamma=50$

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name miniimage --dataset miniimage --num_classes 100 --num_labels 1000 --mismatch cadr --gamma 50 --gpu 0 --net resnet18
```
#### Our protocol in Tab. 2

- CIFAR-10 with 40 labels and $N_0=10$

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch prg --n0 10 --gpu 0
```

- CIFAR-100 with 400 labels and $N_0=40$

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar100 --dataset cifar100 --num_classes 100 --num_labels 400 --mismatch prg --n0 40 --gpu 0 --widen_factor 8
```

- mini-ImageNet with 1000 labels and $N_0=40$

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name miniimage --dataset miniimage --num_classes 100 --num_labels 1000 --mismatch prg --n0 40 --gpu 0 --net resnet18
```
#### Our protocol in Fig. 6(a)

- CIFAR-10 with 40 labels, $N_0=10$ and $\gamma=5$

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch prg --n0 10 --gamma 5 --gpu 0
```
#### Our protocol in Tab. 10

- CIFAR-10 with 40 labels and $\gamma=20$

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --mismatch prg --gamma 20 --gpu 0
```
#### DARP's protocol in Fig. 6(a)

- CIFAR-10 with $\gamma_l=100$ and $\gamma_u=1$

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --mismatch darp --n0 100 --gamma 1 --gpu 0
```

- CIFAR-10 with $\gamma_l=100$ and $\gamma_u=100$ (reversed)

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --mismatch darp_reversed --n0 100 --gamma 100 --gpu 0
```
### Conventional Setting

#### Matched and balanced distribution in Tab. 11

- CIFAR-10 with 40 labels

```
python train_prg.py --world-size 1 --rank 0 --lr_decay cos --seed 1 --num_eval_iter 1000 --overwrite --save_name cifar10 --dataset cifar10 --num_classes 10 --num_labels 40 --gpu 0
```
## Resume Training and Evaluation

If you restart the training, please use `--resume --load_path @your_weight_path`.

For evaluation, run:

```
python eval_prg.py --load_path @your_weight_path --dataset [cifar10/cifar100/miniimage] --data_dir @your_dataset_path --num_classes @number_of_classes
```
By default, the WideResNet-28-2 backbone is used for CIFAR-10. Use `--widen_factor 8` (i.e., WideResNet-28-8) for CIFAR-100 and `--net resnet18` for mini-ImageNet.
## Results (with seed=1)
| Dataset | Labels | $N_0$ | $\gamma$ | Acc (%) | Setting | Method | Weight |
| --- | --- | --- | --- | --- | --- | --- | --- |
| CIFAR-10 | 40 | - | - | 94.05 | Conventional settings | PRG | here |
| | 250 | - | - | 94.36 | | | here |
| | 4000 | - | - | 95.48 | | | here |
| | 40 | - | - | 93.79 | Conventional settings | PRG^Last | here |
| | 250 | - | - | 94.76 | | | here |
| | 4000 | - | - | 95.75 | | | here |
| | - | - | 20 | 94.04 | CADR's protocol | PRG | here |
| | - | - | 50 | 93.78 | | | here |
| | - | - | 100 | 94.51 | | | here |
| | - | - | 20 | 94.74 | CADR's protocol | PRG^Last | here |
| | - | - | 50 | 94.74 | | | here |
| | - | - | 100 | 94.75 | | | here |
| | 40 | 10 | - | 93.81 | Our protocol | PRG | here |
| | 40 | 20 | - | 93.39 | | | here |
| | 40 | 10 | 2 | 90.25 | | | here |
| | 40 | 10 | 5 | 82.84 | | | here |
| | 100 | 40 | 5 | 79.58 | | | here |
| | 100 | 40 | 10 | 78.61 | | | here |
| | 250 | 100 | - | 93.76 | | | here |
| | 250 | 200 | - | 91.65 | | | here |
| | 40 | 10 | - | 91.59 | Our protocol | PRG^Last | here |
| | 40 | 20 | - | 80.31 | | | here |
| | 250 | 100 | - | 91.36 | | | here |
| | 250 | 200 | - | 62.16 | | | here |
| | DARP | 100 | 1 | 94.41 | DARP's protocol | PRG | here |
| | DARP | 100 | 50 | 78.28 | | | here |
| | DARP | 100 | 150 | 75.21 | | | here |
| | DARP (reversed) | 100 | 100 | 80.86 | | | here |
| CIFAR-100 | 400 | - | - | 48.70 | Conventional settings | PRG | here |
| | 2500 | - | - | 69.81 | | | here |
| | 10000 | - | - | 76.91 | | | here |
| | 400 | - | - | 48.66 | Conventional settings | PRG^Last | here |
| | 2500 | - | - | 70.03 | | | here |
| | 10000 | - | - | 76.93 | | | here |
| | - | - | 50 | 58.57 | CADR's protocol | PRG | here |
| | - | - | 100 | 62.28 | | | here |
| | - | - | 200 | 59.33 | | | here |
| | - | - | 50 | 60.32 | CADR's protocol | PRG^Last | here |
| | - | - | 100 | 62.13 | | | here |
| | - | - | 200 | 58.70 | | | here |
| | 2500 | 100 | - | 57.56 | Our protocol | PRG | here |
| | 2500 | 200 | - | 51.21 | | | here |
| | 2500 | 100 | - | 59.40 | Our protocol | PRG^Last | here |
| | 2500 | 200 | - | 42.09 | | | here |
| mini-ImageNet | 1000 | - | - | 45.74 | Conventional settings | PRG | here |
| | 1000 | - | - | 48.63 | Conventional settings | PRG^Last | here |
| | - | - | 50 | 43.74 | CADR's protocol | PRG | here |
| | - | - | 100 | 43.74 | | | here |
| | - | - | 50 | 42.22 | CADR's protocol | PRG^Last | here |
| | - | - | 100 | 43.74 | | | here |
| | 1000 | 40 | - | 40.75 | Our protocol | PRG | here |
| | 1000 | 80 | - | 35.86 | | | here |
| | 1000 | 40 | - | 39.79 | Our protocol | PRG^Last | here |
| | 1000 | 80 | - | 32.64 | | | here |
## Citation
Please cite our paper if you find PRG useful:
```
@inproceedings{duan2023towards,
  title={Towards Semi-supervised Learning with Non-random Missing Labels},
  author={Duan, Yue and Zhao, Zhen and Qi, Lei and Zhou, Luping and Wang, Lei and Shi, Yinghuan},
  booktitle={IEEE/CVF International Conference on Computer Vision},
  year={2023}
}
```
or
```
@article{duan2023towards,
  title={Towards Semi-supervised Learning with Non-random Missing Labels},
  author={Duan, Yue and Zhao, Zhen and Qi, Lei and Zhou, Luping and Wang, Lei and Shi, Yinghuan},
  journal={arXiv preprint arXiv:2308.08872},
  year={2023}
}
```