

<h1 style="text-align:center"> <img style="vertical-align:middle" width="500" height="180" src="./images/wrench_logo.png" /> </h1>


🔧 New


  1. Added the Hyper Label Model; see our paper for details.


  1. Added the WS explainer; see our paper for details.


  1. Updated setup.py to make installation more flexible.

Install the latest version with pip install ws-benchmark==1.1.2rc0. We strongly recommend creating a new environment for Wrench. Compatibility will improve in the next stable release. If you run into any installation problems, please let us know.

Known incompatibilities:

tensorflow==2.8.0, albumentations==0.1.12
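As a minimal sketch, the known incompatibilities listed above could be checked in an existing environment before installing. The helper names `installed_versions` and `find_conflicts` are hypothetical and not part of Wrench; the sketch relies only on the standard-library `importlib.metadata`.

```python
from importlib import metadata

# Pins reported above as incompatible with ws-benchmark==1.1.2rc0.
KNOWN_INCOMPATIBLE = {"tensorflow": "2.8.0", "albumentations": "0.1.12"}


def installed_versions(packages):
    """Return {package: version} for the packages that are currently installed."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass  # not installed, so it cannot conflict
    return found


def find_conflicts(installed, incompatible=KNOWN_INCOMPATIBLE):
    """Hypothetical helper: names of installed packages that match a known-bad pin."""
    return sorted(name for name, version in installed.items()
                  if incompatible.get(name) == version)


conflicts = find_conflicts(installed_versions(KNOWN_INCOMPATIBLE))
print("conflicting packages:", conflicts or "none")
```

If the list is non-empty, a fresh environment is probably the simplest way out.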


  1. Wrench is now available on PyPI as ws-benchmark; run pip install ws-benchmark for a quick install.


  1. Added a script to generate LFs for any tabular dataset, along with 5 new tabular datasets: mushroom, spambase, PhishingWebsites, Bioresponse, and bank-marketing.


  1. (beta) Added parallel_fit for torch models to support PyTorch DistributedDataParallel; see the example.


  1. Added a bunch of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
  2. Added support for image classification (dataset class / torchvision backbone) and for the DomainNet and Animals-with-Attributes2 datasets (check out the datasets folder)

🔧 What is it?

Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common, easy-to-use framework for developing and evaluating your own weak supervision models within the benchmark.

For more information, check out our publications:

If you find this repository helpful, feel free to cite our publication:

@inproceedings{zhang2021wrench,
title={{WRENCH}: A Comprehensive Benchmark for Weak Supervision},
author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2021}
}

🔧 What is weak supervision?

Weak Supervision is a paradigm for automated training data creation without manual annotations.

For a brief overview, please check out this blog.

For more context, please check out this survey.

To track recent advances in weak supervision, please follow this repo.
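To make the idea concrete, here is a small sketch of how training labels might be produced without manual annotation: a few heuristic labeling functions vote on each example, and the votes are combined. The labeling functions, the `-1` abstain convention, and the tie handling are all illustrative assumptions, not code from Wrench.

```python
from collections import Counter

ABSTAIN = -1  # assumed convention: a labeling function returns -1 when it does not fire
HAM, SPAM = 0, 1


# Hypothetical keyword-based labeling functions for a spam task.
def lf_contains_link(text):
    return SPAM if "http" in text.lower() else ABSTAIN


def lf_asks_to_subscribe(text):
    return SPAM if "subscribe" in text.lower() else ABSTAIN


def lf_short_comment(text):
    return HAM if len(text.split()) <= 4 else ABSTAIN


LABELING_FUNCTIONS = [lf_contains_link, lf_asks_to_subscribe, lf_short_comment]


def majority_vote(votes, default=ABSTAIN):
    """Most common non-abstain vote; ties and all-abstain rows return `default`."""
    counts = Counter(v for v in votes if v != ABSTAIN)
    if not counts:
        return default
    ranked = counts.most_common()
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        return default
    return ranked[0][0]


texts = [
    "Subscribe to my channel at http://example.com",
    "love this song",
    "this video reminds me of a long road trip from years ago",
]
weak_labels = [[lf(t) for lf in LABELING_FUNCTIONS] for t in texts]
train_labels = [majority_vote(row) for row in weak_labels]
print(weak_labels)   # one row of votes per example
print(train_labels)  # aggregated labels; -1 means no source fired
```

Examples on which every source abstains receive no label here; how such rows are handled in practice depends on the label model.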

🔧 Installation

[1] Install Anaconda. Instructions are available at https://www.anaconda.com/download/

[2] Clone the repository:

git clone https://github.com/JieyuZ2/wrench.git
cd wrench

[3] Create virtual environment:

conda env create -f environment.yml
source activate wrench

If this does not work, or if you only want to use a subset of Wrench's modules, check out this wiki page.

[4] Download datasets:

from huggingface_hub import snapshot_download
path = "path to local dir"
snapshot_download(repo_id="jieyuz2/WRENCH", repo_type="dataset", local_dir=path)

🔧 Available Datasets

Note that some datasets may contain more training examples than reported in the README or paper, because the training split also includes the dev set. The dev-set indices are listed in labeled_id.json when that file exists.

Documentation of the dataset format and usage can be found on this wiki page.
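A rough sketch of how the dev-set indices might be used to separate the labeled examples from the rest of the training split. The JSON layout assumed below (a dict keyed by example id with `label` and `weak_labels` fields, and a `labeled_id.json` holding a list of ids) is an assumption made for illustration; the wiki page remains the authoritative description. The function name `split_train_and_dev` is hypothetical.

```python
import json
import os
import tempfile


def split_train_and_dev(train_path, labeled_id_path):
    """Split training examples into (remaining_train, dev) using labeled_id.json if present."""
    with open(train_path) as f:
        examples = json.load(f)              # assumed layout: {example_id: {...}}
    if not os.path.exists(labeled_id_path):
        return examples, {}                  # no dev ids recorded for this dataset
    with open(labeled_id_path) as f:
        labeled_ids = {str(i) for i in json.load(f)}   # assumed layout: list of ids
    dev = {k: v for k, v in examples.items() if k in labeled_ids}
    train = {k: v for k, v in examples.items() if k not in labeled_ids}
    return train, dev


# Tiny self-contained demo with made-up data.
with tempfile.TemporaryDirectory() as tmp:
    train_path = os.path.join(tmp, "train.json")
    labeled_id_path = os.path.join(tmp, "labeled_id.json")
    toy = {str(i): {"label": i % 2, "weak_labels": [-1, i % 2]} for i in range(5)}
    with open(train_path, "w") as f:
        json.dump(toy, f)
    with open(labeled_id_path, "w") as f:
        json.dump([1, 3], f)
    train, dev = split_train_and_dev(train_path, labeled_id_path)

print(len(train), len(dev))
```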


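| Name | Task | # class | # LF | # train | # validation | # test | data source | LF source |
|---|---|---|---|---|---|---|---|---|
| Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link |
| Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link |
| SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link |
| IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link |
| Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link |
| AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link |
| TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link |
| Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link |
| SemEval | relation classification | 9 | 164 | 1749 | 178 | 600 | link | link |
| CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link |
| Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link |
| Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link |
| Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link |
| Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link |
| DomainNet | image classification | - | - | - | - | - | link | link |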

sequence tagging:

| Name | # class | # LF | # train | # validation | # test | data source | LF source |
|---|---|---|---|---|---|---|---|
| OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link |

The detailed documentation is coming soon.

🔧 Available Models

If you find that any implementation is wrong or problematic, don't hesitate to open an issue or pull request. We really appreciate it!

TODO-list: check this out!


| Model | Model Type | Reference | Link to Wrench |
|---|---|---|---|
| Majority Voting | Label Model | -- | link |
| Weighted Majority Voting | Label Model | -- | link |
| Dawid-Skene | Label Model | link | link |
| Data Programming | Label Model | link | link |
| MeTaL | Label Model | link | link |
| FlyingSquid | Label Model | link | link |
| EBCC | Label Model | link | link |
| IBCC | Label Model | link | link |
| FABLE | Label Model | link | link |
| Hyper Label Model | Label Model | link | link |
| Logistic Regression | End Model | -- | link |
| MLP | End Model | -- | link |
| BERT | End Model | link | link |
| COSINE | End Model | link | link |
| ARS2 | End Model | link | link |
| Denoise | Joint Model | link | link |
| WeaSEL | Joint Model | link | link |
| SepLL | Joint Model | link | link |

sequence tagging:

| Model | Model Type | Reference | Link to Wrench |
|---|---|---|---|
| Hidden Markov Model | Label Model | link | link |
| Conditional Hidden Markov Model | Label Model | link | link |
| LSTM-CNNs-CRF | End Model | link | link |
| BERT-CRF | End Model | link | link |
| LSTM-ConNet | Joint Model | link | link |
| BERT-ConNet | Joint Model | link | link |

classification-to-sequence-tagging wrapper:

Wrench also provides a SeqLabelModelWrapper that adapts a label model built for classification tasks to sequence tagging tasks.
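One plausible way to adapt a classification label model to sequence tagging is to treat every token as an independent classification example, predict per token, and then restore the sequence boundaries. The sketch below illustrates that idea only; it is not the implementation of SeqLabelModelWrapper, and `token_majority_vote` and `tag_sequences` are hypothetical names.

```python
from collections import Counter

ABSTAIN = -1  # assumed abstain value


def token_majority_vote(votes):
    """Classification-style label model: most common non-abstain vote, else ABSTAIN."""
    counts = Counter(v for v in votes if v != ABSTAIN)
    return counts.most_common(1)[0][0] if counts else ABSTAIN


def tag_sequences(seq_weak_labels, token_model=token_majority_vote):
    """Apply a per-token classifier to every token, then restore the sequence structure.

    seq_weak_labels: list of sequences; each sequence is a list of per-token vote lists.
    """
    # 1) Flatten: each token becomes one independent classification example.
    lengths = [len(seq) for seq in seq_weak_labels]
    flat = [votes for seq in seq_weak_labels for votes in seq]
    # 2) Predict one label per token with the classification label model.
    flat_pred = [token_model(votes) for votes in flat]
    # 3) Unflatten back into sequences using the recorded lengths.
    tagged, start = [], 0
    for n in lengths:
        tagged.append(flat_pred[start:start + n])
        start += n
    return tagged


sequences = [
    [[1, 1, -1], [-1, -1, -1]],               # 2 tokens, 3 labeling functions
    [[0, 2, 2], [0, -1, -1], [-1, 3, -1]],    # 3 tokens
]
print(tag_sequences(sequences))
```

Flattening ignores dependencies between neighbouring tags, which is the main trade-off compared with dedicated sequence label models such as the HMM variants listed above.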

methods from related domains:

Robust Learning methods as end model:

| Model | Model Type | Reference | Link to Wrench |
|---|---|---|---|
| Meta-Weight-Net | End Model | link | link |
| Learning2ReWeight | End Model | link | link |

Semi-Supervised Learning methods as end model:

| Model | Model Type | Reference | Link to Wrench |
|---|---|---|---|
| MeanTeacher | End Model | link | link |

Weak Supervision with cleaned labels (Semi-Weak Supervision):

| Model | Model Type | Reference | Link to Wrench |
|---|---|---|---|
| ImplyLoss | Joint Model | link | link |
| ASTRA | Joint Model | link | link |

🔧 Quick examples

🔧 Label model with parallel grid search for hyper-parameters

import logging
import numpy as np
import pprint

from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel 
from wrench.evaluation import AverageMeter

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Specify the hyper-parameter search space for grid search
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'l2': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)

#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
                             metric=target, direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)

#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
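for i in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    # Record this run's score under the metric name
    # (assumes AverageMeter.update accepts keyword arguments keyed by metric name)
    meter.update(**{target: metric_value})

metrics = meter.get_results()
pprint.pprint(metrics)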

For detailed guidance on grid_search, please check out this wiki page.
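For intuition, the search space in the example above has 5 x 5 x 5 = 125 combinations, while the trial budget is 100. The sketch below shows one generic way such a search could work: enumerate the grid, subsample it when it exceeds the budget, and average each configuration over repeated runs. It is not the implementation of wrench.search.grid_search; `simple_grid_search` and `toy_evaluate` are hypothetical, and the sketch always maximizes the score.

```python
import itertools
import math
import random
import statistics


def simple_grid_search(evaluate, search_space, n_trials, n_repeats, seed=0):
    """Try up to n_trials configurations; return (best_params, best_mean_score, n_tried)."""
    names = sorted(search_space)
    grid = [dict(zip(names, values))
            for values in itertools.product(*(search_space[n] for n in names))]
    rng = random.Random(seed)
    if len(grid) > n_trials:
        grid = rng.sample(grid, n_trials)   # subsample when the grid exceeds the budget
    best_params, best_score = None, float("-inf")
    for params in grid:
        # Average over repeated runs to smooth out run-to-run randomness.
        score = statistics.mean(evaluate(params, run) for run in range(n_repeats))
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score, len(grid)


# Same shape as the search space above: 5 x 5 x 5 = 125 combinations.
search_space = {
    "lr": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    "l2": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    "n_epochs": [5, 10, 50, 100, 200],
}


def toy_evaluate(params, run):
    """Stand-in for fit + validation; highest at lr=1e-3, l2=1e-4, n_epochs=50."""
    return -(abs(math.log10(params["lr"]) + 3)
             + abs(math.log10(params["l2"]) + 4)
             + abs(params["n_epochs"] - 50) / 100)


best_params, best_score, n_tried = simple_grid_search(
    toy_evaluate, search_space, n_trials=100, n_repeats=5)
print(n_tried, best_params, best_score)
```

With a budget smaller than the grid, the best configuration may not be among those tried, which is one reason to keep the grid and the budget roughly matched.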

🔧 Run a standard supervised learning pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.endmodel import MLPModel

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

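#### Train an MLP classifier
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # fall back to CPU when no GPU is available
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target = 'acc'  # metric used for validation and testing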

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target, 
                    patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

🔧 Build a two-stage weak supervision pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Generate soft training labels via a label model
#### The weak labels provided by supervision sources are already encoded in the dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)

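#### Train an MLP classifier with soft labels
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # fall back to CPU when no GPU is available
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target = 'acc'  # metric used for validation and testing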

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label, 
                    device=device, metric=target, patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

#### We can also train an MLP classifier with hard labels
from snorkel.utils import probs_to_preds
hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label, 
          device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
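The difference between soft and hard labels in the pipeline above can be illustrated without any dependencies. In this sketch, a majority-vote soft label is the normalized vote count per class, and a hard label is the index of the largest probability. The function names are hypothetical, the `-1` abstain value is an assumption, and ties are broken by taking the first class, which does not reproduce the tie-breaking options of Snorkel's `probs_to_preds`.

```python
ABSTAIN = -1  # assumed abstain value


def majority_vote_proba(votes, n_class):
    """Soft label: normalized vote counts; uniform when every source abstains."""
    counts = [0] * n_class
    for v in votes:
        if v != ABSTAIN:
            counts[v] += 1
    total = sum(counts)
    if total == 0:
        return [1.0 / n_class] * n_class
    return [c / total for c in counts]


def to_hard_label(probs):
    """Hard label: index of the largest probability (first index wins ties in this sketch)."""
    return max(range(len(probs)), key=lambda k: probs[k])


weak = [
    [1, 1, 0],      # two votes for class 1, one for class 0
    [-1, -1, -1],   # every source abstains
    [0, -1, 0],     # unanimous among the sources that fired
]
soft = [majority_vote_proba(row, n_class=2) for row in weak]
hard = [to_hard_label(p) for p in soft]
print(soft)
print(hard)
```

Soft labels keep the disagreement between sources visible to the end model, whereas hard labels discard it.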

🔧 Procedural labeling function generator

import logging
import torch

from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
    alpha=0.75, # mean accuracy
    beta=0.1, # mean propensity
    alpha_radius=0.2, # radius of accuracy
    beta_radius=0.1 # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)

#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Load real-world dataset
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)

#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')
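To show what accuracy and propensity mean for synthetic labeling functions, here is a dependency-free sketch of a conditionally independent generator. It is not the implementation of ConditionalIndependentGenerator: the function name, the `n_class` and `n_lfs` parameters, and the choice to draw each LF's accuracy and propensity uniformly within the given radius are all assumptions made for illustration.

```python
import random

ABSTAIN = -1  # assumed abstain value


def _clip(x):
    return min(1.0, max(0.0, x))


def generate_weak_labels(n_data, n_class=2, n_lfs=5, alpha=0.75, beta=0.1,
                         alpha_radius=0.2, beta_radius=0.1, seed=0):
    """Sample true labels and LF votes that are independent given the true label.

    Each LF gets an accuracy in [alpha - alpha_radius, alpha + alpha_radius] and a
    propensity (probability of not abstaining) in [beta - beta_radius, beta + beta_radius].
    """
    rng = random.Random(seed)
    accs = [_clip(rng.uniform(alpha - alpha_radius, alpha + alpha_radius)) for _ in range(n_lfs)]
    props = [_clip(rng.uniform(beta - beta_radius, beta + beta_radius)) for _ in range(n_lfs)]
    labels, votes = [], []
    for _ in range(n_data):
        y = rng.randrange(n_class)
        row = []
        for acc, prop in zip(accs, props):
            if rng.random() >= prop:
                row.append(ABSTAIN)                      # the LF does not fire
            elif rng.random() < acc:
                row.append(y)                            # the LF fires and is correct
            else:
                wrong = [c for c in range(n_class) if c != y]
                row.append(rng.choice(wrong))            # the LF fires and is wrong
        labels.append(y)
        votes.append(row)
    return labels, votes, accs, props


labels, votes, accs, props = generate_weak_labels(1000)
fired = sum(v != ABSTAIN for row in votes for v in row)
print("fraction of votes that fired:", fired / (len(votes) * len(votes[0])))
```

With a mean propensity of 0.1, most entries are abstains, which is typical of sparse labeling functions.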

🔧 Contact

Contact person: Jieyu Zhang, jieyuzhang97@gmail.com

Don't hesitate to send us an e-mail if you have any questions.

We're also open to any collaboration!

🔧 Contributing Dataset and Model

We sincerely welcome any contribution to the datasets or models!

🔧 Citation

@inproceedings{zhang2021wrench,
title={{WRENCH}: A Comprehensive Benchmark for Weak Supervision},
author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2021}
}