AUM

PyTorch library for Area Under the Margin (AUM) ranking, as proposed in the paper "Identifying Mislabeled Data using the Area Under the Margin Ranking".

Install

pip install -U aum

Usage

Instantiate an AUMCalculator object:

import os

from aum import AUMCalculator

save_dir = os.path.expanduser('~/Desktop')  # expand '~' to an absolute path
aum_calculator = AUMCalculator(save_dir, compressed=True)

Note: set compressed to False if you want to store the AUM metrics from every call to the update method. However, this requires considerably more space.

You can then update AUM rankings on batches of data during training with:

model.train()
for batch in loader:
    inputs, targets, sample_ids = batch

    logits = model(inputs)

    records = aum_calculator.update(logits, targets, sample_ids)

    ...
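Conceptually, each update compares a sample's logit for its assigned label with its largest other logit. The pure-Python sketch below illustrates that per-batch margin computation. `batch_margins` is a hypothetical name, not part of the aum API, and plain nested lists stand in for the PyTorch tensors the library actually receives.

```python
from typing import List, Sequence, Tuple


def batch_margins(
    logits: Sequence[Sequence[float]], targets: Sequence[int]
) -> List[Tuple[int, float, int, float, float]]:
    """Hypothetical helper: per-sample margin for one batch.

    Returns (target_logit, target_val, other_logit, other_val, margin)
    tuples, mirroring the corresponding AUMRecord fields.
    Each row must contain at least two classes.
    """
    results = []
    for row, target in zip(logits, targets):
        target_val = row[target]
        # Index of the largest logit among all classes except the assigned label.
        other_logit = max(
            (i for i in range(len(row)) if i != target), key=lambda i: row[i]
        )
        other_val = row[other_logit]
        margin = target_val - other_val
        results.append((target, target_val, other_logit, other_val, margin))
    return results
```

For example, `batch_margins([[3.74, 2.48]], [0])` gives a margin of about 1.26, the same value as sample_1's first row in the full_aum_records.csv example. A positive margin means the model currently favors the assigned label; a negative one means some other class is scored higher.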

records is a dictionary mapping each sample_id to an AUMRecord, which holds the information below, including the sample's AUM at this point in training.

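from dataclasses import dataclass
from typing import Union

@dataclass
class AUMRecord:
    """
    Class for holding info around an aum update for a single sample
    """
    sample_id: Union[int, str]
    num_measurements: int  # number of updates recorded for this sample so far
    target_logit: int      # index of the assigned (target) class
    target_val: float      # logit value of the target class
    other_logit: int       # index of the largest non-target logit
    other_val: float       # value of the largest non-target logit
    margin: float          # target_val - other_val for this update
    aum: float             # mean margin over all num_measurements updates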
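The aum value is the margin averaged over every measurement taken so far, so it changes with each update. The sketch below shows that running average using a hypothetical `ToyRecord` dataclass, not the library's AUMRecord. Keeping only the count and the current mean is enough to fold in each new margin without storing the earlier ones.

```python
from dataclasses import dataclass


@dataclass
class ToyRecord:
    """Hypothetical stand-in for the running state behind an AUMRecord."""
    num_measurements: int = 0
    aum: float = 0.0

    def add_margin(self, margin: float) -> "ToyRecord":
        # Incremental mean: new_mean = old_mean + (x - old_mean) / n
        self.num_measurements += 1
        self.aum += (margin - self.aum) / self.num_measurements
        return self


record = ToyRecord()
record.add_margin(1.26).add_margin(1.15)
print(record.num_measurements, round(record.aum, 3))  # 2 1.205
```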

Once training is finished, you can generate a CSV of samples ranked by AUM score with:

aum_calculator.finalize()

If your dataset does not return sample_ids, you can wrap it in DatasetWithIndex. The last element of each tuple the wrapper returns will then be that sample's sample_id.

from aum import DatasetWithIndex
from torch.utils.data import Dataset

my_dataset = Dataset(...)  # placeholder for your own Dataset subclass
my_dataset_with_index = DatasetWithIndex(my_dataset)
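To make the wrapping idea concrete, here is a minimal sketch of an index-appending wrapper. `IndexedDataset` is a hypothetical class written for illustration, not the library's DatasetWithIndex, and it assumes the wrapped dataset's items are tuples. It only needs `__len__` and `__getitem__`, which is also all a map-style PyTorch Dataset exposes to a DataLoader.

```python
from typing import Any, Sequence, Tuple


class IndexedDataset:
    """Hypothetical index-appending wrapper (not the library's DatasetWithIndex).

    The wrapped object must support len() and integer indexing, and each
    item must be a tuple such as (input, target).
    """

    def __init__(self, dataset: Sequence[Tuple[Any, ...]]) -> None:
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        # Append the item's position so it can serve as a sample_id.
        return (*self.dataset[index], index)


ds = IndexedDataset([("x0", 0), ("x1", 1)])
print(ds[1])  # ('x1', 1, 1)
```

Because the index is the last element, a batch drawn from such a wrapper unpacks as `inputs, targets, sample_ids = batch`, matching the training loop above.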

Example Outputs

Calling finalize() on an AUMCalculator creates one or two CSV files, depending on whether compressed was set to True or False.

If AUMCalculator was instantiated with compressed=True, you will find a CSV file named aum_values.csv in the following format:

sample_id,aum
sample_1,1.205
sample_3,1.145
sample_2,-3.785
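Since samples with lower AUM are the stronger candidates for being mislabeled, a natural next step is to pull the lowest-ranked rows out of aum_values.csv. The sketch below assumes the comma-delimited layout and the sample_id/aum column names shown above; `lowest_aum` is a hypothetical helper, and an in-memory string stands in for the real file.

```python
import csv
import io
from typing import List, TextIO, Tuple


def lowest_aum(csv_file: TextIO, k: int) -> List[Tuple[str, float]]:
    """Hypothetical helper: the k (sample_id, aum) pairs with the smallest AUM."""
    rows = [(row["sample_id"], float(row["aum"])) for row in csv.DictReader(csv_file)]
    # Ascending sort puts the most suspicious (lowest-AUM) samples first.
    rows.sort(key=lambda pair: pair[1])
    return rows[:k]


example = io.StringIO("sample_id,aum\nsample_1,1.205\nsample_3,1.145\nsample_2,-3.785\n")
print(lowest_aum(example, k=1))  # [('sample_2', -3.785)]
```

For a real run, pass a handle opened with `open(path, newline='')`, where `path` points at wherever finalize() wrote the file (presumably inside save_dir).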

If AUMCalculator was instantiated with compressed=False, you will also find a CSV file named full_aum_records.csv alongside aum_values.csv. full_aum_records.csv is in the following format:

sample_id,num_measurements,target_logit,target_val,other_logit,other_val,margin,aum
sample_1,1,0,3.74,1,2.48,1.26,1.26
sample_1,2,0,4.59,1,3.44,1.15,1.205
sample_2,1,1,-0.09,0,3.11,-3.20,-3.20
sample_2,2,1,-1.12,0,3.25,-4.37,-3.785
sample_3,1,6,3.39,10,1.62,1.77,1.77
sample_3,2,6,2.63,2,2.11,0.52,1.145
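The non-compressed file makes it possible to look at how each sample's margin evolves during training, for example to plot it. The sketch below groups rows by sample_id and orders them by num_measurements; `margin_trajectories` is a hypothetical helper, and it assumes the comma-delimited layout and column names shown above.

```python
import csv
import io
from collections import defaultdict
from typing import Dict, List, TextIO, Tuple


def margin_trajectories(csv_file: TextIO) -> Dict[str, List[float]]:
    """Hypothetical helper: per-sample margins, ordered by num_measurements."""
    by_sample: Dict[str, List[Tuple[int, float]]] = defaultdict(list)
    for row in csv.DictReader(csv_file):
        by_sample[row["sample_id"]].append(
            (int(row["num_measurements"]), float(row["margin"]))
        )
    # Sort each sample's (measurement number, margin) pairs, then keep the margins.
    return {sid: [m for _, m in sorted(pairs)] for sid, pairs in by_sample.items()}


example = io.StringIO(
    "sample_id,num_measurements,target_logit,target_val,other_logit,other_val,margin,aum\n"
    "sample_1,1,0,3.74,1,2.48,1.26,1.26\n"
    "sample_1,2,0,4.59,1,3.44,1.15,1.205\n"
)
print(margin_trajectories(example))  # {'sample_1': [1.26, 1.15]}
```

The mean of a trajectory is that sample's AUM after its last measurement: for sample_1, (1.26 + 1.15) / 2 = 1.205.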

Replicate results from the paper

To replicate the results, please refer to the examples/paper_replication directory.

Example usage

For a more basic example of using AUMCalculator and DatasetWithIndex in a training script, please refer to the examples/cifar100 directory.

Cite

@article{pleiss2020identifying,
  title={Identifying Mislabeled Data using the Area Under the Margin Ranking},
  author={Geoff Pleiss and Tianyi Zhang and Ethan R. Elenberg and Kilian Q. Weinberger},
  journal={arXiv preprint arXiv:2001.10528},
  year={2020}
}