
Self-Adaptive Training

This is the PyTorch implementation of the NeurIPS 2020 paper Self-Adaptive Training: beyond Empirical Risk Minimization and its journal version.

Self-adaptive training significantly improves the generalization of deep networks under noise and enhances self-supervised representation learning. It also advances the state of the art in learning with noisy labels, adversarial training, and linear evaluation of learned representations.

News

Requirements

Usage

Standard training

main.py contains the training and evaluation functions for the standard training setting.
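
For reference, below is a minimal sketch of the core idea: training against an exponential moving average of soft targets, with the target's confidence used as a per-sample weight. The names (`SATLoss`, `alpha`, `es`) and defaults are illustrative and may differ from the actual code in main.py.

```python
import torch
import torch.nn.functional as F

class SATLoss:
    """Sketch of the self-adaptive training loss (illustrative names)."""

    def __init__(self, labels, num_classes, alpha=0.9, es=60):
        # Soft targets start as the (possibly noisy) one-hot labels.
        self.targets = F.one_hot(labels, num_classes).float()
        self.alpha = alpha  # EMA momentum of the target update
        self.es = es        # warm-up epochs before targets start moving

    def __call__(self, logits, indices, epoch):
        if epoch >= self.es:
            # t_i <- alpha * t_i + (1 - alpha) * softmax(f(x_i))
            probs = F.softmax(logits.detach(), dim=1)
            self.targets[indices] = (
                self.alpha * self.targets[indices] + (1 - self.alpha) * probs
            )
        t = self.targets[indices]
        # The target's confidence acts as the per-sample weight.
        weights = t.max(dim=1).values
        ce = -(t * F.log_softmax(logits, dim=1)).sum(dim=1)
        return (weights * ce).sum() / weights.sum()
```

Here `indices` are the dataset indices of the samples in the batch, so each sample keeps its own moving-average target across epochs.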

Runnable scripts

Results on CIFAR datasets under uniform label noise

Test accuracy (%) on CIFAR10:

| Noise Rate | 0.2   | 0.4   | 0.6   | 0.8   |
|------------|-------|-------|-------|-------|
| ResNet-34  | 94.14 | 92.64 | 89.23 | 78.58 |
| WRN-28-10  | 94.84 | 93.23 | 89.42 | 80.13 |

Test accuracy (%) on CIFAR100:

| Noise Rate | 0.2   | 0.4   | 0.6   | 0.8   |
|------------|-------|-------|-------|-------|
| ResNet-34  | 75.77 | 71.38 | 62.69 | 38.72 |
| WRN-28-10  | 77.71 | 72.60 | 64.87 | 44.17 |

Runnable scripts for reproducing the double-descent phenomenon

You can use the commands below to train the default model (i.e., ResNet-18) on the CIFAR10 dataset with 16.67% uniform label noise injected (i.e., a 15% label error rate: under uniform noise over 10 classes, the random label coincides with the true label with probability 1/10, so the effective error rate is 16.67% × 9/10 = 15%):

$ bash scripts/cifar10/run_sat_dd_parallel.sh [TRIAL_NAME]
$ bash scripts/cifar10/run_ce_dd_parallel.sh [TRIAL_NAME]

Double-descent ERM vs. single-descent self-adaptive training

<p align="center"> <img src="images/model_dd.png" width="450"\> </p> <p align="center"> Double-descent ERM vs. single-descent self-adaptive training on the error-capacity curve. The vertical dashed line represents the interpolation threshold. </p> <p align="center"> <img src="images/epoch_dd.png" width="450"\> </p> <p align="center"> Double-descent ERM vs. single-descent self-adaptive training on the epoch-capacity curve. The dashed vertical line represents the initial epoch E_s of our approach. </p>

Adversarial training

We use the state-of-the-art adversarial training algorithm TRADES as our baseline. main_adv.py contains the training and evaluation functions for the adversarial training setting on the CIFAR10 dataset.
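
The TRADES objective combines a natural cross-entropy term with a KL-divergence robustness term computed on an adversarial example that maximizes that KL divergence. Below is a minimal, self-contained sketch of the objective; the hyperparameter values follow common CIFAR10 settings and are not necessarily those used by main_adv.py.

```python
import torch
import torch.nn.functional as F

def trades_loss(model, x, y, step_size=0.007, epsilon=0.031,
                num_steps=10, beta=6.0):
    """Sketch of the TRADES objective: CE(f(x), y) + beta * KL(f(x') || f(x)),
    where x' is found by PGD maximizing the KL term within an eps-ball."""
    model.eval()
    # Inner maximization: PGD on the KL divergence around x.
    x_adv = x.detach() + 0.001 * torch.randn_like(x)
    p_natural = F.softmax(model(x), dim=1).detach()
    for _ in range(num_steps):
        x_adv.requires_grad_()
        kl = F.kl_div(F.log_softmax(model(x_adv), dim=1),
                      p_natural, reduction='sum')
        grad = torch.autograd.grad(kl, x_adv)[0]
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    model.train()
    # Outer minimization: natural loss + beta * robust KL term.
    logits = model(x)
    loss_natural = F.cross_entropy(logits, y)
    loss_robust = F.kl_div(F.log_softmax(model(x_adv), dim=1),
                           F.softmax(logits, dim=1), reduction='batchmean')
    return loss_natural + beta * loss_robust
```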

Training scripts

Robust evaluation script

Evaluate robust WRN-34-10 models on CIFAR10 under the PGD-20 attack:

  $ python pgd_attack.py --model-dir "/path/to/checkpoints"

This command evaluates the 71st through 100th checkpoints in the specified path.
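
For reference, below is a minimal sketch of an L-infinity PGD attack of the kind used in this evaluation. The actual pgd_attack.py also handles checkpoint loading and dataset iteration, and its hyperparameters may differ from the common CIFAR10 values assumed here.

```python
import torch
import torch.nn.functional as F

def pgd_attack(model, x, y, epsilon=0.031, step_size=0.003, num_steps=20):
    """Sketch of an L-infinity PGD-20 attack with a random start."""
    x_adv = x.detach() + epsilon * (2 * torch.rand_like(x) - 1)
    x_adv = torch.clamp(x_adv, 0.0, 1.0)
    for _ in range(num_steps):
        x_adv.requires_grad_()
        loss = F.cross_entropy(model(x_adv), y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        # Ascend the loss, then project back into the epsilon-ball.
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    return x_adv

# Usage: robust accuracy is the fraction of adversarial examples
# that the model still classifies correctly.
# x_adv = pgd_attack(model, x, y)
# robust_acc = (model(x_adv).argmax(dim=1) == y).float().mean()
```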

Results

<p align="center"> <img src="images/robust_acc.png" width="450"\> </p> <p align="center"> Self-Adaptive Training mitigates the overfitting issue and consistently improves TRADES. </p>

Attack TRADES+SAT

We provide the checkpoint of our best-performing model on Google Drive and compare its natural and robust accuracy with that of TRADES below.

| Attack (submitted by) \ Method     | TRADES | TRADES + SAT |
|------------------------------------|--------|--------------|
| None (initial entry)               | 84.92  | 83.48        |
| PGD-20 (initial entry)             | 56.68  | 58.03        |
| MultiTargeted-2000 (initial entry) | 53.24  | 53.46        |
| Auto-Attack+ (Francesco Croce)     | 53.08  | 53.29        |

Reference

For technical details, please check the conference version or the journal version of our paper.

@inproceedings{huang2020self,
  title={Self-Adaptive Training: beyond Empirical Risk Minimization},
  author={Huang, Lang and Zhang, Chao and Zhang, Hongyang},
  booktitle={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}

@article{huang2021self,
  title={Self-Adaptive Training: Bridging the Supervised and Self-Supervised Learning},
  author={Huang, Lang and Zhang, Chao and Zhang, Hongyang},
  journal={arXiv preprint arXiv:2101.08732},
  year={2021}
}

Contact

If you have any questions about this code, feel free to open an issue or contact laynehuang@pku.edu.cn.