[WIP] Consistency Regularization for Semi-supervised Learning with PyTorch

This repository includes consistency regularization algorithms for semi-supervised learning.

The training and evaluation settings follow Oliver+ 2018 and FixMatch.

Requirements

scikit-learn (sklearn) is used for moon_data_exp.py (the two moons dataset experiment).

Usage

Run sh ./scripts/SETUP_NAME/DATASET_NAME/ALGORITHM.sh /PATH/TO/OUTPUT_DIR NUM_LABELS. For example, to reproduce the FixMatch results on CIFAR-10 with 250 labels, run

sh ./scripts/fixmatch-setup/cifar10/fixmatch.sh ./results/cifar10-fixmatch-250labeles 250

The scripts in scripts/fixmatch-setup train and evaluate a model with the FixMatch setting, and the scripts in scripts/realistic-evaluation-setup train and evaluate a model with the Oliver+ 2018 setting.
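To run the same script for several label budgets, the command line can be assembled programmatically. The sketch below is only an illustration: build_command is a hypothetical helper that is not part of this repository, and the path layout is inferred from the example command above. Only the fixmatch-setup/cifar10/fixmatch.sh path is confirmed by this README.

```python
import shlex


def build_command(setup, dataset, algorithm, output_dir, num_labels):
    """Return the argv list for one training run (hypothetical helper).

    Assumes the layout ./scripts/<setup>/<dataset>/<algorithm>.sh, inferred
    from the example command in this README.
    """
    script = f"./scripts/{setup}/{dataset}/{algorithm}.sh"
    return ["sh", script, output_dir, str(num_labels)]


# Label budgets taken from the CIFAR-10 columns of the Sohn+ 2020 table.
for num_labels in (250, 4000):
    out_dir = f"./results/cifar10-fixmatch-{num_labels}labels"
    cmd = build_command("fixmatch-setup", "cifar10", "fixmatch", out_dir, num_labels)
    # Print the command; to launch it, use subprocess.run(cmd, check=True).
    print(shlex.join(cmd))
```

Printing the command first makes it easy to check the output directory and label count before starting a long training run.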

If you would like to train a model with your own settings, please see parser.py.

NOTE: train_test.py reports model performance as the median accuracy over the last [1, 10, 20, 50] checkpoints (FixMatch setting), whereas train_val_test.py reports the test accuracy of the model that performs best on the validation data (Oliver+ 2018 setting).

Performance

WIP

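| | Oliver+ 2018: CIFAR-10 4000 labels | Oliver+ 2018: SVHN 1000 labels | this repo: CIFAR-10 4000 labels | this repo: SVHN 1000 labels |
|---|---|---|---|---|
| Supervised | 20.26 ±0.38 | 12.83 ±0.47 | 19.85 | 11.03 |
| Pi-Model | 16.37 ±0.63 | 7.19 ±0.27 | 14.84 | 7.87 |
| Mean Teacher | 15.87 ±0.28 | 5.65 ±0.47 | 14.28 | 5.83 |
| VAT | 13.13 ±0.39 | 5.35 ±0.19 | 12.15 | 6.38 |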

NOTE: Our implementation differs from Oliver+ 2018 in the following ways:

  1. We use labeled data as unlabeled data too, in addition to the purely unlabeled data (following Sohn+ 2020).
  2. Our VAT implementation follows Miyato+, whereas Oliver+ use a KLD with a different direction as the loss function. See issue.
  3. The parameter initialization of WRN-28 follows Sohn+ 2020.
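The direction of the KLD in item 2 matters because the KL divergence is not symmetric. The sketch below shows this in plain Python with two made-up two-class predictions. It does not reproduce either VAT implementation, and this README does not state which argument order each implementation uses.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Hypothetical class probabilities, for illustration only.
clean = [0.9, 0.1]      # prediction on the clean input
perturbed = [0.6, 0.4]  # prediction on the perturbed input

forward = kl_divergence(clean, perturbed)
reverse = kl_divergence(perturbed, clean)

# The two directions give different loss values for the same pair of predictions.
print(f"KL(clean || perturbed) = {forward:.4f}")
print(f"KL(perturbed || clean) = {reverse:.4f}")
```

Since the two directions give different loss values and therefore different gradients, results from the two variants are not directly comparable.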

If you would like to evaluate the model with the same conditions as Oliver+ 2018, please see this repo.

| | Sohn+ 2020: CIFAR-10 250 labels | Sohn+ 2020: CIFAR-10 4000 labels | this repo: CIFAR-10 250 labels | this repo: CIFAR-10 4000 labels |
|---|---|---|---|---|
| UDA | 8.82±1.08 | 4.88±0.18 | 10.08 | 6.32 |
| FixMatch | 5.07±0.65 | 4.26±0.05 | 9.88 | 6.84 |

Reported error rates are the median over the last 20 checkpoints.

Citation

@misc{suzuki2020consistency,
    author = {Teppei Suzuki},
    title = {Consistency Regularization for Semi-supervised Learning with PyTorch},
    year = {2020},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/perrying/pytorch-consistency-regularization}},
}

References