Home

Awesome

AutoWS-Bench-101 🪑

Introduction

AutoWS-Bench-101 is a framework for evaluating automated WS (AutoWS) techniques in challenging WS settings---a set of diverse application domains on which it has been previously difficult or impossible to apply traditional WS techniques.

Installation

Install anaconda: Instructions here: https://www.anaconda.com/download/

Clone the repository:

git clone https://github.com/Kaylee0501/AutoWS-Bench-101.git
cd AutoWS-Bench-101

Create virtual environment:

conda env create -f environment.yml
source activate AutoWS-Bench-101

Install CLIP:

pip install git+https://github.com/openai/CLIP.git

Datasets

Our benchmark automatically downloads the datasets for you. Please run data_settings.py to download the specific dataset you need.

| Name           | # class | # train | # valid | # test  |
|----------------|---------|---------|---------|---------|
| MNIST          | 10      | 57000   | 3000    | 10000   |
| FashionMNIST   | 10      | 57000   | 3000    | 10000   |
| KMNIST         | 10      | 57000   | 3000    | 10000   |
| CIFAR10        | 10      | 47500   | 2500    | 10000   |
| SphericalMNIST | 10      | 57000   | 3000    | 10000   |
| PermutedMNIST  | 10      | 57000   | 3000    | 10000   |
| ECG            | 4       | 280269  | 14752   | 33494   |
| EMBER          | 2       | 285000  | 15000   | 100000  |
| NavierStokes   | 2       | 100     | 100     | 1900    |

Run the Experiment

To run the whole implementation, we provide a pipeline. This pipeline will walk you through a full example of our framework. It allows you to choose the datasets (which are downloaded automatically for you), select the embeddings, and generate a set of labeling functions (LFs) with the LF selectors you prefer. It then trains a Snorkel label model and reports the accuracy and coverage.

Please run (feel free to change the argument):

python fwrench/applications/pipeline.py --dataset='cifar10' --embedding='openai' --lf_selector='iws'
python fwrench/applications/pipeline.py --dataset='ember' --embedding='pca' --lf_selector='snuba'

Examples

Training MNIST with pca embedding and snuba selector

import logging
import random
import copy

import fire
import fwrench.embeddings as feats
import fwrench.utils.autows as autows
import fwrench.utils.data_settings as settings
import numpy as np
import torch
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
from wrench.logging import LoggingHandler

def main(
    dataset="mnist",
    dataset_home="./datasets",
    embedding="pca",  # raw | pca | resnet18 | vae

    lf_selector="snuba",  # snuba | interactive | goggles
    em_hard_labels=False,  # Use hard or soft labels for end model training
    n_labeled_points=100,  # Number of points used to train lf_selector
    #
    # Snuba options
    snuba_combo_samples=-1,  # -1 uses all feat. combos
    # TODO this needs to work for Snuba and IWS
    snuba_cardinality=1,  # Only used if lf_selector='snuba'
    snuba_iterations=23,
    lf_class_options="default",  # default | comma separated list of lf classes to use in the selection procedure. Example: 'DecisionTreeClassifier,LogisticRegression'
    seed=123,
    prompt=None,
):
    """Train a label model on MNIST with a PCA embedding and the Snuba LF selector.

    Pipeline: load MNIST -> fit a 100-component PCA embedding on all splits ->
    run Snuba to generate/select labeling functions -> score the resulting
    hard labels on the covered portion of the test set.

    Returns:
        float: label-model accuracy on the test points covered by at least
        one selected labeling function.
    """
    ################ HOUSEKEEPING/SELF-CARE 😊 ################################
    # Seed every RNG the pipeline may draw from (stdlib, numpy, torch) — not
    # just `random` — so that results are reproducible across runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[LoggingHandler()],
    )
    logger = logging.getLogger(__name__)

    ################ LOAD DATASET 😊  ##########################################
    # get_mnist also returns the number of classes (k_cls) and a default
    # end model (unused in this example).
    train_data, valid_data, test_data, k_cls, model = settings.get_mnist(
            n_labeled_points, dataset_home
    )
    # Fit PCA on all three splits jointly, then embed each split.
    emb = PCA(n_components=100)
    embedder = feats.SklearnEmbedding(emb)

    embedder.fit(train_data, valid_data, test_data)
    train_data_embed = embedder.transform(train_data)
    valid_data_embed = embedder.transform(valid_data)
    test_data_embed = embedder.transform(test_data)

    ################ AUTOMATED WEAK SUPERVISION ###############################
    # run_snuba returns the subset of the test set covered by the selected
    # LFs, along with the label model's hard and soft labels for it.
    test_covered, hard_labels, soft_labels = autows.run_snuba(
        valid_data, train_data, test_data, valid_data_embed,
        train_data_embed, test_data_embed, snuba_cardinality,
        snuba_combo_samples, snuba_iterations, lf_class_options,
        k_cls, logger,
    )
    # Accuracy is measured on the covered *test* points; coverage is the
    # fraction of the full test set those points represent.
    acc = accuracy_score(test_covered.labels, hard_labels)
    cov = float(len(test_covered.labels)) / float(len(test_data.labels))
    logger.info(f"label model test acc:     {acc}")
    logger.info(f"label model coverage:     {cov}")
    return acc

if __name__ == "__main__":
    fire.Fire(main)

Training ECG with resnet18 embedding and goggles selector

import logging
import random
import copy

import fire
import fwrench.embeddings as feats
import fwrench.utils.autows as autows
import fwrench.utils.data_settings as settings
import numpy as np
import torch
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
from wrench.logging import LoggingHandler

def main(
    dataset="ecg",
    dataset_home="./datasets",
    embedding="resnet18",  # raw | pca | resnet18 | vae

    lf_selector="goggles",  # snuba | interactive | goggles
    em_hard_labels=False,  # Use hard or soft labels for end model training
    n_labeled_points=100,  # Number of points used to train lf_selector
    #
    lf_class_options="default",  # default | comma separated list of lf classes to use in the selection procedure. Example: 'DecisionTreeClassifier,LogisticRegression'
    seed=123,
    prompt=None,
):
    """Train a label model on ECG with a ResNet18 embedding and the GOGGLES selector.

    Pipeline: load ECG -> embed the valid/test splits with a ResNet18
    feature extractor -> run GOGGLES to produce labels -> score the
    resulting hard labels on the covered portion of the test set.

    Returns:
        float: label-model accuracy on the covered test points.
    """
    ################ HOUSEKEEPING/SELF-CARE 😊 ################################
    # Seed every RNG the pipeline may draw from (stdlib, numpy, torch) — not
    # just `random` — so that results are reproducible across runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[LoggingHandler()],
    )
    logger = logging.getLogger(__name__)

    ################ LOAD DATASET #############################################
    # get_ecg also returns the number of classes (k_cls) and a default
    # end model (both unused in this example).
    train_data, valid_data, test_data, k_cls, model = settings.get_ecg(
            n_labeled_points, dataset_home
    )
    embedder = feats.ResNet18Embedding(dataset)

    # Only the valid/test splits are embedded; the (labeled) valid split is
    # then duplicated to serve as the train split for GOGGLES.
    embedder.fit(valid_data, test_data)
    valid_data_embed = embedder.transform(valid_data)
    test_data_embed = embedder.transform(test_data)
    train_data_embed = copy.deepcopy(valid_data_embed)
    train_data = copy.deepcopy(valid_data)

    ################ AUTOMATED WEAK SUPERVISION ###############################
    # run_goggles returns the subset of the test set covered by the produced
    # labeling, along with the label model's hard and soft labels for it.
    test_covered, hard_labels, soft_labels = autows.run_goggles(
        valid_data, train_data, test_data, valid_data_embed,
        train_data_embed, test_data_embed, logger,
    )
    # Accuracy is measured on the covered *test* points; coverage is the
    # fraction of the full test set those points represent.
    acc = accuracy_score(test_covered.labels, hard_labels)
    cov = float(len(test_covered.labels)) / float(len(test_data.labels))
    logger.info(f"label model test acc:     {acc}")
    logger.info(f"label model coverage:     {cov}")
    return acc

if __name__ == "__main__":
    fire.Fire(main)

Credits

We extend the WRENCH codebase to build our framework, and we thank its authors for the inspiration.