DataMUX

PyTorch implementation for the paper:

DataMUX: Data Multiplexing for Neural Networks
Vishvak Murahari, Carlos E. Jimenez, Runzhe Yang, Karthik Narasimhan


This repository contains code for reproducing the results in the paper. We provide pretrained model weights and the associated configs, which can be used to run inference or to train these models from scratch. If you find this work useful in your research, please cite the paper using the BibTeX entry in the Reference section below.

Table of Contents

  1. Setup and Dependencies
  2. Usage
  3. Reference
  4. License

Setup and Dependencies

Our code is implemented in PyTorch. To set up, do the following:

  1. Install Python 3.6
  2. Get the source:
git clone https://github.com/princeton-nlp/DataMUX.git datamux
  3. Install requirements into the datamux virtual environment using Anaconda:
conda env create -f env.yaml
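
Then activate the environment before running any of the scripts below (we assume the environment created from env.yaml is named datamux; if activation fails, check the name field in env.yaml):

conda activate datamux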

Usage

Overview

For sentence-level classification tasks, refer to run_glue.py and run_glue.sh. For token-level classification tasks, refer to run_ner.py and run_ner.sh.

Pre-trained checkpoints

We release all pretrained checkpoints on the Hugging Face model hub; they are listed below. For <num_instances>, use one of 2, 5, 10, 20, or 40.

| Task | Model name on hub | Full path |
| --- | --- | --- |
| Retrieval Warmup | datamux-retrieval-<num_instances> | princeton-nlp/datamux-retrieval-<num_instances> |
| MNLI | datamux-mnli-<num_instances> | princeton-nlp/datamux-mnli-<num_instances> |
| QNLI | datamux-qnli-<num_instances> | princeton-nlp/datamux-qnli-<num_instances> |
| QQP | datamux-qqp-<num_instances> | princeton-nlp/datamux-qqp-<num_instances> |
| SST2 | datamux-sst2-<num_instances> | princeton-nlp/datamux-sst2-<num_instances> |
| NER | datamux-ner-<num_instances> | princeton-nlp/datamux-ner-<num_instances> |
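
The training scripts described in the Settings section below can be pointed directly at these hub checkpoints via --model_path. As an illustration (this exact invocation is our assumption, mirroring the finetuning example further down), an MNLI checkpoint could be evaluated without further training by requesting only evaluation:

sh run_glue.sh \
   -N 2 \
   -d index \
   -m gaussian_hadamard \
   -s finetuning \
   --config_name configs/ablations/base_model/roberta.json \
   --lr 5e-5 \
   --task mnli \
   --model_path princeton-nlp/datamux-mnli-2 \
   --do_eval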

Settings

The bash scripts run_ner.sh and run_glue.sh take the following arguments:

| Argument | Flag | Explanation | Argument choices |
| --- | --- | --- | --- |
| NUM_INSTANCES | -N, --num_instances | Number of multiplexing instances | 2, 5, 10, 20, 40 |
| DEMUXING | -d, --demuxing | Demultiplexing architecture | "index", "mlp" |
| MUXING | -m, --muxing | Multiplexing architecture | "gaussian_hadamard", "binary_hadamard", "random_ortho" |
| SETTING | -s, --setting | Training setting | "baseline", "finetuning", "retrieval_pretraining" |
| TASK_NAME | --task | Task name during finetuning | "mnli", "qnli", "sst2", "qqp" for run_glue.py; "ner" for run_ner.py |
| LEARNING_RATE | --lr | Learning rate for optimization | Any float; we use either 2e-5 or 5e-5 |
| BATCH_SIZE | --batch_size | Batch size (after multiplexing); note that the effective batch size is BATCH_SIZE * NUM_INSTANCES | Any integer; if left unset, it is set automatically based on the value of N |
| CONFIG_NAME | --config_name | Config path for the backbone Transformer model | Any config file in the configs directory |
| MODEL_PATH | --model_path | Model path when continuing training from a checkpoint or initializing from a retrieval-pretrained checkpoint | Path to a local checkpoint or to a model on the hub |
| LEARN_MUXING | --learn_muxing | Whether to learn instance embeddings during multiplexing | |
| DO_TRAIN | --do_train | Pass this flag to run training | |
| DO_EVAL | --do_eval | Pass this flag to run evaluation | |

Below we list example commands for the different training settings:

Retrieval pretraining

This command runs retrieval pretraining with N=2:

sh run_glue.sh \
   -N 2 \
   -d index \
   -m gaussian_hadamard \
   -s retrieval_pretraining \
   --config_name configs/ablations/base_model/roberta.json \
   --lr 5e-5 \
   --do_train \
   --do_eval

Finetuning

This command finetunes on MNLI starting from a retrieval-pretrained checkpoint with N=2:

sh run_glue.sh \
   -N 2 \
   -d index \
   -m gaussian_hadamard \
   -s finetuning \
   --config_name configs/ablations/base_model/roberta.json \
   --lr 5e-5 \
   --task mnli \
   --model_path princeton-nlp/datamux-retrieval-2 \
   --do_train \
   --do_eval

Similarly, to run token-level classification tasks such as NER, use run_ner.sh instead of run_glue.sh:

sh run_ner.sh \
   -N 2 \
   -d index \
   -m gaussian_hadamard \
   -s finetuning \
   --config_name configs/ablations/base_model/roberta.json \
   --lr 5e-5 \
   --task ner \
   --model_path princeton-nlp/datamux-retrieval-2 \
   --do_train \
   --do_eval 

Baselines

For the non-multiplexed baseline, run the following command:

sh run_glue.sh \
   -N 1 \
   -s baseline \
   --config_name configs/ablations/base_model/roberta.json \
   --lr 2e-5 \
   --task mnli \
   --do_train \
   --do_eval
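
A corresponding non-multiplexed baseline for NER is not spelled out here; the sketch below is our assumption, mirroring the GLUE baseline above with run_ner.sh and --task ner:

sh run_ner.sh \
   -N 1 \
   -s baseline \
   --config_name configs/ablations/base_model/roberta.json \
   --lr 2e-5 \
   --task ner \
   --do_train \
   --do_eval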

Vision

To reproduce the results on the vision tasks with MLPs and CNNs, please use the accompanying notebook.

Reference

@inproceedings{murahari2022datamux,
  title={Data{MUX}: Data Multiplexing for Neural Networks},
  author={Vishvak Murahari and Carlos E Jimenez and Runzhe Yang and Karthik R Narasimhan},
  booktitle={Thirty-Sixth Conference on Neural Information Processing Systems},
  year={2022},
  url={https://openreview.net/forum?id=UdgtTVTdswg}
}

License

See LICENSE.md.