DABS: A Domain Agnostic Benchmark for Self-Supervised Learning

This repository contains the code for DABS, a benchmark for domain-agnostic self-supervised learning algorithms. The basic components of the benchmark can be found in datasets, encoders, and algorithms. Training is implemented with the PyTorch Lightning framework, logging with Weights and Biases, and configuration management with Hydra.

Usage

We provide support for Python >= 3.8. Install requirements with

```bash
python -m pip install -r requirements.txt
```

For instructions on installing a PyTorch build compatible with your CUDA version, see pytorch.org. We support PyTorch 1.6.0, but later versions may work as well.

Datasets

We provide a set of dataset implementations (in src/datasets) from the image, text, speech, sensor, medical imaging, and image-text domains. Preprocessing on these datasets is minimal and hard-coded: simple resizing (of images) and truncation (of text and audio). These operations should not be changed, in order to maintain fair comparisons across users of the benchmark.

See conf/datasets/*.yaml for all dataset configs, including the loss, metrics, and batch size used for each dataset.

Almost all datasets download automatically when the dataset class is instantiated. The exceptions are the CheXpert, ImageNet, and CU Birds datasets, which require manual registration or download; see the respective dataset files for specific instructions.
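For illustration, the download-on-instantiation pattern looks roughly like the sketch below. The class name, URL, and file layout are hypothetical stand-ins, not the repo's actual dataset API.

```python
# A minimal sketch of the download-on-instantiation pattern. The class name,
# URL, and file layout are hypothetical, not the repo's actual dataset API.
import os
import urllib.request

from torch.utils.data import Dataset


class ToyAutoDownloadDataset(Dataset):
    URL = "https://example.com/data.bin"  # placeholder URL

    def __init__(self, root: str = "./data", download: bool = True):
        self.path = os.path.join(root, "data.bin")
        if download and not os.path.exists(self.path):
            os.makedirs(root, exist_ok=True)
            urllib.request.urlretrieve(self.URL, self.path)  # fetch once
        # ...load and index the downloaded file here...
```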

| Pretraining Dataset (unlabeled) | Transfer Dataset (labeled) |
| --- | --- |
| ImageNet | Aircraft, CIFAR10, CU Birds, DTD, Traffic Sign, VGG Flower |
| PAMAP2 | PAMAP2 |
| MSCOCO | MSCOCO (mismatched detection), VQA (binary classification) |
| WikiText-103 | GLUE (10 tasks) |
| mC4 | PAWS-X (7 tasks) |
| CheXpert | CheXpert (atelectasis, cardiomegaly, consolidation, edema, and pleural effusion), ChestX-ray8 (atelectasis, cardiomegaly, effusion, infiltration, mass, nodule, pneumonia, pneumothorax) |
| LibriSpeech | Audio MNIST, Fluent Speech (Action, Object, Location), Google Speech Commands, LibriSpeech, VoxCeleb1 |

Pretraining

During the pretraining phase, self-supervised encoders are trained to learn good representations from unlabeled data. We currently support seven datasets for pretraining, one per domain: MS COCO, ImageNet, CheXpert, PAMAP2, mC4, WikiText-103, and LibriSpeech. If the pretraining dataset has associated labels, an online linear evaluator is trained jointly with the encoder to provide a heuristic estimate of transfer performance.
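The key property of an online linear evaluator is that it trains on detached features, so probe gradients never flow back into the encoder. The sketch below illustrates this; the module shapes, the 10-class setup, and the placeholder SSL loss are assumptions, not the repo's actual implementation.

```python
# Hedged sketch of an online linear evaluator: the probe trains on detached
# encoder features, so probe gradients never reach the encoder. All shapes
# and the 10-class setup are illustrative assumptions.
import torch
import torch.nn.functional as F

encoder = torch.nn.Linear(32, 128)   # stand-in for the real encoder
probe = torch.nn.Linear(128, 10)     # online linear evaluator
opt = torch.optim.Adam(list(encoder.parameters()) + list(probe.parameters()))

x = torch.randn(8, 32)
y = torch.randint(0, 10, (8,))

features = encoder(x)
ssl_loss = features.pow(2).mean()                          # placeholder SSL objective
probe_loss = F.cross_entropy(probe(features.detach()), y)  # detach: no grad to encoder

opt.zero_grad()
(ssl_loss + probe_loss).backward()
opt.step()
```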

Run pretraining with commands like

```bash
python pretrain.py exp.name=<experiment-name> dataset=<dataset> algorithm=<algorithm>
```

Each dataset and encoder has its own config file, so to train a Transformer on the CheXpert dataset with the e-mix algorithm, run

```bash
python pretrain.py exp.name=emix-chexpert encoder=transformer dataset=chexpert algorithm=emix
```

See conf/pretrain.yaml for all pretraining configuration fields.
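Since configuration is managed with Hydra, one hedged way to see how these overrides merge into the final config, without launching a run, is Hydra's compose API. This assumes Hydra >= 1.1; the config path and override names mirror the commands above.

```python
# Hedged sketch: inspect the merged Hydra config without launching a run.
# Assumes Hydra >= 1.1; config path and overrides mirror the commands above.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="conf"):
    cfg = compose(
        config_name="pretrain",
        overrides=["dataset=chexpert", "algorithm=emix"],
    )
print(OmegaConf.to_yaml(cfg))  # dump the resolved configuration
```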

For more information on the datasets, encoders, and algorithms, see the sections below.

| Pretraining Dataset | Modality | Label type (unused) | Input Type |
| --- | --- | --- | --- |
| ImageNet | Natural images | Single label | 2d |
| PAMAP2 | Sensor | Single label | 2d |
| MSCOCO | Captioned images | Single label | 2d + tokens |
| WikiText-103 | English text | No label | tokens |
| mC4 | Multilingual text | No label | tokens |
| CheXpert | Medical images | Multi label | 2d |
| LibriSpeech | Speech | No label | 2d |

Transfer Learning

After pretraining, a small linear classifier is trained on top of the frozen encoder. Run transfer learning from a randomly initialized encoder with

```bash
python transfer.py exp.name=<experiment-name> dataset=<dataset> ckpt=null
```

To evaluate a pretrained encoder instead, replace null with the path to its checkpoint. See conf/transfer.yaml for all transfer learning configuration fields.
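Conceptually, linear transfer amounts to freezing the encoder and optimizing only a linear head on its features, roughly as sketched below. The shapes are illustrative, and the checkpoint-loading line is commented out because the repo's checkpoint format is not shown here.

```python
# Hedged sketch of linear transfer: freeze a pretrained encoder and train
# only a linear classifier on its features. Shapes are illustrative.
import torch
import torch.nn.functional as F

encoder = torch.nn.Linear(32, 128)  # stand-in encoder
# encoder.load_state_dict(torch.load("path/to/ckpt"))  # hypothetical loading
for p in encoder.parameters():
    p.requires_grad = False         # frozen during transfer

classifier = torch.nn.Linear(128, 10)  # the only trainable module
opt = torch.optim.SGD(classifier.parameters(), lr=1e-2)

x = torch.randn(8, 32)
y = torch.randint(0, 10, (8,))
with torch.no_grad():
    feats = encoder(x)              # no encoder gradients
loss = F.cross_entropy(classifier(feats), y)

opt.zero_grad()
loss.backward()
opt.step()
```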

| Dataset | Modality | Label type | Evaluation metric | Input Type |
| --- | --- | --- | --- | --- |
| Aircraft | Natural images | Single label | Accuracy | 2d |
| CU Birds | Natural images | Single label | Accuracy | 2d |
| DTD | Natural images | Single label | Accuracy | 2d |
| Traffic Sign | Natural images | Single label | Accuracy | 2d |
| VGG Flower | Natural images | Single label | Accuracy | 2d |
| PAMAP2 | Sensor | Single label | Accuracy | 2d |
| MS COCO | Captioned images | Binary label | Accuracy | 2d + tokens |
| VQA | Captioned images | Binary label | Accuracy | 2d + tokens |
| CheXpert | Medical images | Multi label | AUROC | 2d |
| ChestX-ray8 | Medical images | Multi label | AUROC | 2d |
| PAWS-X | Multilingual text | Binary label | Accuracy | tokens |
| CoLA | English text | Binary label | Pearson correlation | tokens |
| MNLI Matched | English text | Single label | Accuracy | tokens |
| MNLI Mismatched | English text | Single label | Accuracy | tokens |
| MRPC | English text | Binary label | Accuracy | tokens |
| QNLI | English text | Binary label | Accuracy | tokens |
| QQP | English text | Binary label | Accuracy | tokens |
| RTE | English text | Binary label | Accuracy | tokens |
| SST-2 | English text | Binary label | Accuracy | tokens |
| STS-B | English text | Regression | Spearman correlation | tokens |
| WNLI | English text | Binary label | Accuracy | tokens |
| Audio MNIST | Speech | Single label | Accuracy | 2d |
| Fluent Speech | Speech | Single label | Accuracy | 2d |
| Google Speech Commands | Speech | Single label | Accuracy | 2d |
| LibriSpeech | Speech | Single label | Accuracy | 2d |
| VoxCeleb1 | Speech | Single label | Accuracy | 2d |

Encoders

A domain-agnostic SSL method should have an encoder that remains as constant as possible across domains. We provide a general transformer encoder baseline (in src/encoders). The transformer operates on a sequence of vectors produced by a small set of embedding modules (e.g. patch or token embeddings), as sketched below.
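The sketch below illustrates this design: tiny per-domain embedding modules map raw inputs into a common sequence-of-vectors format, and a single transformer is reused across domains. All sizes are illustrative assumptions, and batch_first=True requires PyTorch >= 1.9.

```python
# Hedged sketch of the shared-encoder design: small per-domain embedding
# modules feed one transformer that is reused across domains. All sizes are
# illustrative; batch_first=True assumes PyTorch >= 1.9.
import torch
import torch.nn as nn

d_model = 128
shared_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True),
    num_layers=2,
)

patch_embed = nn.Conv2d(3, d_model, kernel_size=16, stride=16)  # images -> vectors
token_embed = nn.Embedding(30000, d_model)                      # text -> vectors

images = torch.randn(2, 3, 224, 224)
image_seq = patch_embed(images).flatten(2).transpose(1, 2)      # (2, 196, 128)
tokens = torch.randint(0, 30000, (2, 64))
token_seq = token_embed(tokens)                                 # (2, 64, 128)

image_repr = shared_encoder(image_seq)  # the same weights process both domains
token_repr = shared_encoder(token_seq)
```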

Pretraining algorithms

The pretraining algorithm is the framework and objective with which the encoder is trained. Examples of domain-specific algorithms include SimCLR, BYOL, and MoCo; these are not domain-agnostic, as they depend on vision-specific augmentations. We provide our own domain-agnostic implementations of recent algorithms, including e-mix (a generalization of i-mix) and Shuffled Embedding Detection (ShED; a generalization of ELECTRA). ShED randomly permutes a subset of the input embeddings and trains the model to identify which embeddings were permuted.
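Based on that description, a ShED-style objective can be sketched as follows. The corruption ratio, shapes, and the single linear layer standing in for the encoder plus detection head are all assumptions for illustration, not the repo's implementation.

```python
# Hedged sketch of the ShED idea: shuffle a random subset of each sequence's
# embeddings, then train a per-position binary classifier to flag which
# positions were permuted. Ratio and shapes are illustrative.
import torch
import torch.nn.functional as F

def shed_corrupt(embeds: torch.Tensor, ratio: float = 0.15):
    """Permute `ratio` of positions within each sequence; return labels."""
    B, L, D = embeds.shape
    corrupted = embeds.clone()
    labels = torch.zeros(B, L)
    k = max(2, int(ratio * L))              # need >= 2 positions to permute
    for b in range(B):
        idx = torch.randperm(L)[:k]         # positions to shuffle
        perm = idx[torch.randperm(k)]       # a permutation of those positions
        corrupted[b, idx] = embeds[b, perm]
        labels[b, idx] = 1.0                # 1 = this position was permuted
    return corrupted, labels

embeds = torch.randn(4, 32, 128)
corrupted, labels = shed_corrupt(embeds)
head = torch.nn.Linear(128, 1)              # stand-in for encoder + detection head
logits = head(corrupted).squeeze(-1)        # (batch, seq_len)
loss = F.binary_cross_entropy_with_logits(logits, labels)
```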

Results

Below are results for algorithms trained on each dataset in DABS. The baseline performance is obtained via a randomly initialized encoder.

| Pretrain Dataset | Transfer Dataset | Encoder | Baseline Performance | e-mix Performance | ShED Performance |
| --- | --- | --- | --- | --- | --- |
| ImageNet | CIFAR10 | Transformer | 24.20% | 39.43% | 39.63% |
| ImageNet | CU Birds | Transformer | 1.62% | 3.86% | 2.95% |
| ImageNet | VGG Flower | Transformer | 9.03% | 25.96% | 13.03% |
| ImageNet | DTD | Transformer | 7.39% | 8.83% | 18.35% |
| ImageNet | Traffic Sign | Transformer | 14.33% | 65.07% | 27.51% |
| ImageNet | Aircraft | Transformer | 2.70% | 10.15% | 5.60% |
| PAMAP2 | PAMAP2 | Transformer | 69.81% | 79.48% | 88.69% |
| MSCOCO | VQA | Transformer | 53.38% | 58.77% | 54.25% |
| MSCOCO | Mismatched Caption | Transformer | 49.41% | 49.86% | 52.60% |
| CheXpert | CheXpert | Transformer | 68.14% | 72.40% | 72.40% |
| CheXpert | ChestX-ray8 | Transformer | 57.00% | 63.00% | 63.70% |
| WikiText-103 | GLUE (average) | Transformer | 42.29% | 44.08% | 48.37% |
| mC4 | PAWS-X (average) | Transformer | 58.11% | 56.16% | 59.91% |
| LibriSpeech | Audio MNIST | Transformer | 33.13% | 80.35% | 67.33% |
| LibriSpeech | Fluent Locations | Transformer | 62.09% | 60.93% | 60.24% |
| LibriSpeech | Fluent Actions | Transformer | 26.15% | 29.87% | 30.53% |
| LibriSpeech | Fluent Objects | Transformer | 30.13% | 39.89% | 39.36% |
| LibriSpeech | Google Speech Commands | Transformer | 4.87% | 19.22% | 20.73% |
| LibriSpeech | LibriSpeech | Transformer | 17.12% | 60.18% | 34.77% |
| LibriSpeech | VoxCeleb1 | Transformer | 0.59% | 2.43% | 2.81% |