Probabilistic Transformer

<div align="center"> <img width="200" src="https://github.com/whyNLP/Probabilistic-Transformer/assets/43395692/d0e75012-f52c-470e-adc2-453b72b71fab" /> <p> A probabilistic dependency model <br> with a computation graph similar to the transformer's. <br> That is the Probabilistic Transformer. </p> </div>

This is the code base for the Probabilistic Transformer project, a model of contextual word representation from a syntactic and probabilistic perspective. The paper "Probabilistic Transformer: A Probabilistic Dependency Model for Contextual Word Representation" was accepted to Findings of ACL 2023.

<details> <summary>The Map of AI Approaches</summary> <div align="center"> <img width="400" src="https://github.com/whyNLP/Probabilistic-Transformer/assets/43395692/55df3127-ea9d-4005-9e29-c2d836850f50" /> </div> </details>

Warning
In this git branch, the code is written to be easy to integrate with all kinds of modules, but it is not well optimized for speed. The repo structure is a bit messy, and the framework it uses (flair) is outdated.

Preparation

Code Environment

To prepare the code environment, use

cd src
pip install -r requirements.txt

Due to package compatibility constraints, this installs PyTorch 1.7.1. Feel free to upgrade it with the command:

pip install torch==1.10.2

Or this command:

pip install --upgrade torch

This work was developed with torch==1.10.2.

Dataset

Our code automatically downloads a dataset if it finds that the dataset you want to use is missing. Some datasets require a license or purchase; in that case, the code throws an error telling you where to obtain the dataset. We also provide detailed instructions in the template config file and the docstrings.

How to run

Training

Simply run the following commands:

cd src
python train.py

By default, it uses the config file src/config.toml. To use another config file, pass -c or --config to specify its path:

cd src
python train.py -c ../path/to/config.toml

Prediction / Inference

To run inference on a sentence, run python predict.py. The usage is exactly the same as for training; just use a config that has already been used for training. To change the sentence used for inference, edit the code in predict.py.

Drawing Dependency Parse Trees

To visualize the dependency parse trees produced by our models, run python draw.py. The usage is the same as for inference. It will generate dep_graph.tex in your working directory. You may compile the LaTeX file to get the figures as a PDF.

There are 3 options at the top of the file draw.py that you can adjust.
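
For illustration only, here is a minimal, hypothetical Python sketch of how predicted head indices can be turned into a tikz-dependency LaTeX file; the actual dep_graph.tex written by draw.py may look different, and the words and heads below are toy values.

```python
# Hypothetical sketch: write a dependency tree as a tikz-dependency LaTeX file.
# Variable names and the output format are illustrative; draw.py may differ.
words = ["The", "dog", "barks"]
heads = [2, 3, 0]  # 1-based head index for each word; 0 marks the root

lines = [
    r"\documentclass{standalone}",
    r"\usepackage{tikz-dependency}",
    r"\begin{document}",
    r"\begin{dependency}",
    r"  \begin{deptext}[column sep=1em]",
    "    " + r" \& ".join(words) + r" \\",
    r"  \end{deptext}",
]
for dependent, head in enumerate(heads, start=1):
    if head == 0:
        lines.append(rf"  \deproot{{{dependent}}}{{root}}")
    else:
        lines.append(rf"  \depedge{{{head}}}{{{dependent}}}{{}}")
lines += [r"\end{dependency}", r"\end{document}"]

with open("dep_graph.tex", "w") as f:  # compile with pdflatex to get a PDF
    f.write("\n".join(lines))
```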

If we take the attention scores in transformers as dependency edge scores, we can also draw dependency parse trees from transformers.
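
As a rough, hypothetical illustration of this idea (not the code in this repo), the sketch below greedily picks, for every token, its most-attended token as the head. A well-formed tree would additionally require a maximum-spanning-tree decoder such as Chu-Liu/Edmonds.

```python
# Hypothetical sketch: treat an attention matrix as dependency edge scores and
# pick the most-attended token as each token's head (greedy, may not form a tree).
import numpy as np

def greedy_heads(attn: np.ndarray) -> np.ndarray:
    """attn[i, j]: attention weight from token i to token j (one layer, one head)."""
    scores = attn.copy()
    np.fill_diagonal(scores, -np.inf)  # a token cannot be its own head
    return scores.argmax(axis=1)       # head index for every token

# Toy 3x3 attention matrix (rows: dependents, columns: candidate heads)
attn = np.array([[0.1, 0.7, 0.2],
                 [0.3, 0.2, 0.5],
                 [0.6, 0.3, 0.1]])
print(greedy_heads(attn))  # [1 2 0]
```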

Evaluation for Unsupervised Dependency Parsing Task

To perform unsupervised dependency parsing, run python evaluate.py. The usage is the same as for drawing. It will print the UAS (Unlabeled Attachment Score) to the console.

There are 4 options at the top of the file evaluate.py that you can adjust.
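
For reference, the UAS metric itself is simply head accuracy. Below is a minimal, self-contained sketch of the computation (not the repo's evaluate.py, which may, for example, exclude punctuation):

```python
# Minimal sketch of the UAS computation (hypothetical helper, not evaluate.py):
# the fraction of words whose predicted head matches the gold head.
def unlabeled_attachment_score(pred_heads, gold_heads):
    assert len(pred_heads) == len(gold_heads)
    correct = sum(p == g for p, g in zip(pred_heads, gold_heads))
    return correct / len(gold_heads)

# Toy example: head indices for a 5-word sentence (0 denotes the root)
print(unlabeled_attachment_score([2, 0, 2, 5, 3], [2, 0, 2, 3, 3]))  # 0.8
```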

Result

We provide the config files in configs/best. To reproduce the results, please use the following command:

cd src
python train.py -c ../configs/best/<CONFIG_FILE>

where <CONFIG_FILE> should be replaced by the config file in the tables below.

Note
Some of the results presented below were not included in our paper.

Probabilistic Transformer

| Task | Dataset | Metric | Config | Performance (avg. 5 runs) | # of Parameters | Speed (samples/sec) | Total Time |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MLM | PTB | Perplexity | crf-mlm-ptb.toml | 62.86 $\pm$ 0.40 | 6291456 | 173.95 | 15:26:53 |
| MLM | BLLIP-XS | Perplexity | crf-mlm-bllip.toml | 123.18 $\pm$ 1.50 | 6291456 | 172.01 | 20:30:13 |
| POS | PTB | Accuracy | crf-pos-ptb.toml | 96.29 $\pm$ 0.03 | 3145728 | 222.91 | 5:13:42 |
| POS | UD | Accuracy | crf-pos-ud.toml | 90.96 $\pm$ 0.10 | 2359296 | 385.84 | 1:02:42 |
| UPOS | UD | Accuracy | crf-upos-ud.toml | 91.57 $\pm$ 0.12 | 4194304 | 205.83 | 1:47:38 |
| NER | CONLL03 | F1 | crf-ner-conll03.toml | 75.47 $\pm$ 0.35 | 9437184 | 202.84 | 2:45:25 |
| CLS | SST-2 | Accuracy | crf-cls-sst2.toml | 82.04 $\pm$ 0.88 | 10485760 | 675.78 | 1:54:03 |
| CLS | SST-5 | Accuracy | crf-cls-sst5.toml | 42.77 $\pm$ 1.18 | 2630656 | 185.33 | 1:13:36 |
| SYN | COGS | Accuracy | crf-syn-cogs.toml | 84.60 $\pm$ 2.06 | 147456 | 507.66 | 2:14:25 |
| SYN | CFQ-mcd1 | EM / LAS | crf-syn-cfq-mcd1.toml | 78.88 $\pm$ 2.81 / 97.84 $\pm$ 0.33 | 1114112 | 234.13 | 19:04:35 |
| SYN | CFQ-mcd2 | EM / LAS | crf-syn-cfq-mcd2.toml | 48.41 $\pm$ 4.99 / 91.91 $\pm$ 0.68 | 1114112 | 225.75 | 19:22:46 |
| SYN | CFQ-mcd3 | EM / LAS | crf-syn-cfq-mcd3.toml | 45.68 $\pm$ 4.17 / 90.87 $\pm$ 0.70 | 1114112 | 269.96 | 14:26:53 |

Transformer

| Task | Dataset | Metric | Config | Performance (avg. 5 runs) | # of Parameters | Speed (samples/sec) | Total Time |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MLM | PTB | Perplexity | transformer-mlm-ptb.toml | 58.43 $\pm$ 0.58 | 23809408 | 434.90 | 6:27:05 |
| MLM | BLLIP-XS | Perplexity | transformer-mlm-bllip.toml | 101.91 $\pm$ 1.40 | 11678720 | 616.84 | 7:10:23 |
| POS | PTB | Accuracy | transformer-pos-ptb.toml | 96.44 $\pm$ 0.04 | 15358464 | 527.46 | 2:11:05 |
| POS | UD | Accuracy | transformer-pos-ud.toml | 91.17 $\pm$ 0.11 | 3155456 | 554.10 | 0:39:34 |
| UPOS | UD | Accuracy | transformer-upos-ud.toml | 91.96 $\pm$ 0.06 | 14368256 | 696.49 | 0:31:52 |
| NER | CONLL03 | F1 | transformer-ner-conll03.toml | 74.02 $\pm$ 1.11 | 1709312 | 577.57 | 0:49:38 |
| CLS | SST-2 | Accuracy | transformer-cls-sst2.toml | 82.51 $\pm$ 0.26 | 23214080 | 713.34 | 2:03:30 |
| CLS | SST-5 | Accuracy | transformer-cls-sst5.toml | 40.13 $\pm$ 1.09 | 8460800 | 871.61 | 0:17:42 |
| SYN | COGS | Accuracy | transformer-syn-cogs.toml | 82.05 $\pm$ 2.18 | 100000 | 856.28 | 1:16:25 |
| SYN | CFQ-mcd1 | EM / LAS | transformer-syn-cfq-mcd1.toml | 92.35 $\pm$ 2.37 / 99.21 $\pm$ 0.30 | 1189728 | 618.95 | 7:33:43 |
| SYN | CFQ-mcd2 | EM / LAS | transformer-syn-cfq-mcd2.toml | 80.34 $\pm$ 1.40 / 96.24 $\pm$ 0.68 | 1189728 | 590.35 | 8:15:08 |
| SYN | CFQ-mcd3 | EM / LAS | transformer-syn-cfq-mcd3.toml | 73.43 $\pm$ 6.07 / 94.85 $\pm$ 0.93 | 1189728 | 601.13 | 8:29:28 |

Universal Transformer

| Task | Dataset | Metric | Config | Performance (avg. 5 runs) | # of Parameters | Speed (samples/sec) | Total Time |
| --- | --- | --- | --- | --- | --- | --- | --- |
| SYN | COGS | Accuracy | universal-transformer-syn-cogs.toml | 80.50 $\pm$ 3.49 | 50000 | 1008.65 | 1:15:29 |
| SYN | CFQ-mcd1 | EM / LAS | universal-transformer-syn-cfq-mcd1.toml | 95.48 $\pm$ 2.09 / 99.59 $\pm$ 0.19 | 198288 | 603.01 | 8:20:50 |
| SYN | CFQ-mcd2 | EM / LAS | universal-transformer-syn-cfq-mcd2.toml | 78.63 $\pm$ 3.54 / 95.62 $\pm$ 0.75 | 198288 | 626.53 | 9:07:15 |
| SYN | CFQ-mcd3 | EM / LAS | universal-transformer-syn-cfq-mcd3.toml | 71.49 $\pm$ 5.39 / 94.57 $\pm$ 1.25 | 198288 | 603.23 | 8:17:17 |

<sub><i>* "Universal Transformer" only means weight sharing between layers in transformers. See details in Ontanón et al. (2021).</i></sub>
<sub><i>** The training speed and total time are for reference only. The speed value is sampled at an arbitrary point during training, so the product of speed and time does not equal the number of training samples.</i></sub>
<sub><i>*** The random seeds for the 5 runs are: 0, 1, 2, 3, 4.</i></sub>

Questions

  1. I am working on a cluster where the compute node does not have Internet, so I cannot download the dataset before training. What should I do?

That is simple. Go to src/train.py and add exit(0) before training (line 105). Run the training command on the login node (where you have Internet access); it will download the dataset without training the model. Finally, remove the line you added and train the model on the compute node.

  2. Which type of positional encoding do you use for transformers?

We use absolute positional encoding for transformers in our experiments. Though the computation graph of probabilistic transformers is closer to that of transformers with relative positional encoding, we empirically find that positional encoding hardly makes any difference to the performance of transformers.
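
For reference, the sketch below shows the standard sinusoidal form of absolute positional encoding (Vaswani et al., 2017). It illustrates the general technique only and is not necessarily the exact variant used in our experiments.

```python
# Standard sinusoidal absolute positional encoding (Vaswani et al., 2017).
# Generic illustration only; d_model is assumed to be even.
import numpy as np

def sinusoidal_positional_encoding(max_len: int, d_model: int) -> np.ndarray:
    positions = np.arange(max_len)[:, None]                  # (max_len, 1)
    dims = np.arange(0, d_model, 2)[None, :]                 # (1, d_model / 2)
    angles = positions / np.power(10000.0, dims / d_model)   # (max_len, d_model / 2)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)   # even dimensions
    pe[:, 1::2] = np.cos(angles)   # odd dimensions
    return pe                      # added to the input word embeddings, position by position

print(sinusoidal_positional_encoding(max_len=4, d_model=8).shape)  # (4, 8)
```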

  3. Why not test on the GLUE dataset?

GLUE is a standard benchmark for language understanding, and most recent works with strong pre-trained word representations test their models on it. Our work does not involve pre-training, so its language understanding ability is limited. To better evaluate the quality of our model's word representations, we think it is more suitable to compare it with a vanilla transformer on MLM and POS tagging than on GLUE.

  4. How strong is your baseline?

To make sure our baseline (transformer) implementation is strong enough, some of our experiments use the same settings as previous works, as summarized below:

<details> <summary>Details for Baseline Comparison</summary>

| Task | Dataset | Metric | Source | Performance |
| --- | --- | --- | --- | --- |
| MLM | PTB | Perplexity | Transformer, Shen et al. (2021) | 64.05 |
| MLM | PTB | Perplexity | Structformer, Shen et al. (2021) | 60.94 |
| MLM | PTB | Perplexity | Transformer, Ours | 58.43 |
| SYN | COGS | Accuracy | Universal Transformer, Ontanón et al. (2021) | 78.4 |
| SYN | COGS | Accuracy | Transformer, Ours | 82.05 |
| SYN | COGS | Accuracy | Universal Transformer, Ours | 80.50 |
| SYN | CFQ-mcd1 | EM / LAS | Transformer, Bergen et al. (2021) | 75.3 $\pm$ 1.7 / 97.0 $\pm$ 0.1 |
| SYN | CFQ-mcd1 | EM / LAS | Transformer, Ours | 92.35 $\pm$ 2.37 / 99.21 $\pm$ 0.30 |
| SYN | CFQ-mcd1 | EM / LAS | Universal Transformer, Bergen et al. (2021) | 80.1 $\pm$ 1.7 / 97.8 $\pm$ 0.2 |
| SYN | CFQ-mcd1 | EM / LAS | Universal Transformer, Ours | 95.48 $\pm$ 2.09 / 99.59 $\pm$ 0.19 |
| SYN | CFQ-mcd2 | EM / LAS | Transformer, Bergen et al. (2021) | 59.3 $\pm$ 2.7 / 91.8 $\pm$ 0.4 |
| SYN | CFQ-mcd2 | EM / LAS | Transformer, Ours | 80.34 $\pm$ 1.40 / 96.24 $\pm$ 0.68 |
| SYN | CFQ-mcd2 | EM / LAS | Universal Transformer, Bergen et al. (2021) | 68.6 $\pm$ 2.3 / 92.5 $\pm$ 0.4 |
| SYN | CFQ-mcd2 | EM / LAS | Universal Transformer, Ours | 78.63 $\pm$ 3.54 / 95.62 $\pm$ 0.75 |
| SYN | CFQ-mcd3 | EM / LAS | Transformer, Bergen et al. (2021) | 48.0 $\pm$ 1.6 / 89.4 $\pm$ 0.3 |
| SYN | CFQ-mcd3 | EM / LAS | Transformer, Ours | 73.43 $\pm$ 6.07 / 94.85 $\pm$ 0.93 |
| SYN | CFQ-mcd3 | EM / LAS | Universal Transformer, Bergen et al. (2021) | 59.4 $\pm$ 2.0 / 90.5 $\pm$ 0.5 |
| SYN | CFQ-mcd3 | EM / LAS | Universal Transformer, Ours | 71.49 $\pm$ 5.39 / 94.57 $\pm$ 1.25 |
</details>

  5. I have trouble understanding / running the code. Could you help me with it?

Sure. Feel free to open an issue or email me at wuhy1@shanghaitech.edu.cn.