Awesome
<!-- # Patton<img src="figure/patton.svg" width="30" height="30" />: Language Model Pretraining on Text-rich Networks -->Patton <img src="figure/patton.svg" width="30" height="30" />
This repository contains the source code and datasets for Patton<img src="figure/patton.svg" width="15" height="15" />: Language Model Pretraining on Text-rich Networks, published in ACL 2023.
Links
Requirements
The code is written in Python 3.8. Before running, you need to first install the required packages by typing following commands (Using a virtual environment is recommended):
pip3 install -r requirements.txt
Overview
Patton is a framework to pretrain language models on text-rich networks, with two strategies: network-contextualized masked language modeling and masked node prediction.
<p align="center"> <img src="figure/Patton.png" width="600px"/> </p>Datasets
Download processed data. To reproduce the results in our paper, you need to first download the processed datasets. The extract the data files by
tar -xf data.tar.gz
Create a new ckpt/
folder for checkpoint saving and a new logs/
folder for logs saving.
mkdir ckpt
mkdir logs
Raw data & data processing. Raw data can be downloaded from MAG and Amazon directly. You can also find our data processing codes here. They might be useful if you want to obtain processed dataset (both for pretrain and finetune) for other networks in MAG and Amazon.
Use your own dataset. To pretrain Patton on your own data, you need to prepare the pretraining files: train.tsv, val.tsv, test.tsv. In the three files, each row represents a linked node pair:
{
"q_text": (str) node_1 associated text,
"k_text": (str) node_2 associated text,
"q_n_text": (List(str)) node_1 neighbors' associated text,
"k_n_text": (List(str)) node_2 neighbors' associated text,
}
Please refer to the file in our processed dataset for their detailed format information.
We also provide pre-tokenization code here to improve pretraining/finetuning efficiency.
Pretraining Patton
Pretraining Patton starting from bert-base-uncased.
bash run_pretrain.sh
Pretraining SciPatton starting from scibert-base-uncased.
bash run_pretrain_sci.sh
Change $PROJ_DIR
to your project directory. We support both single GPU training and multi-GPU training.
You can directly download our pretrained checkpoints here. Then extract the checkpoint files by
tar -xf pretrained_ckpt.tar.gz
Finetuning Patton
Classification
Run classification train.
bash nc_class_train.sh
Run classification test.
bash nc_class_test.sh
Change $STEP
to the highest validation set performance step.
Retrieval
Run bm25 to prepare hard negatives.
cd bm25/
bash bm25.sh
Prepare data for retrieval.
cd src/
bash nc_retrieve_gen_bm25neg.sh
bash build_train.sh
Run retrieval train.
bash nc_retrieve_train.sh
Run retrieval test.
bash nc_infer.sh
bash nc_retrieval.sh
Reranking
Prepare data for reranking.
bash scripts/match.sh
Run reranking train.
bash nc_rerank_train.sh
Run reranking test.
bash nc_rerank_test.sh
Link Prediction
Run link prediction train.
bash lp_train.sh
Run link prediction test.
bash lp_test.sh
Citations
Please cite the following paper if you find the code helpful for your research.
@inproceedings{jin2023patton,
title={Patton: Language Model Pretraining on Text-Rich Networks},
author={Jin, Bowen and Zhang, Wentao and Zhang, Yu and Meng, Yu and Zhang, Xinyang and Zhu, Qi and Han, Jiawei},
booktitle={Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics},
year={2023}
}
Acknowledge
Some parts of our code are adapted from the tevatron repository. Huge thanks to the contributors of the amazing repository!
Code base Structure
$CODE_DIR
├── ckpt
├── data
│ ├── amazon
│ │ ├── cloth
│ │ ├── home
│ │ └── sports
│ └── MAG
│ ├── CS
│ ├── Geology
│ └── Mathematics
├── src
│ ├── OpenLP
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ ├── arguments.py
│ │ ├── dataset
│ │ ├── driver
│ │ ├── loss.py
│ │ ├── models
│ │ ├── modeling.py
│ │ ├── retriever
│ │ ├── trainer
│ │ └── utils.py
│ └── scripts
│ ├── build_train.py
│ ├── build_train_ncc.py
│ ├── build_train_neg.py
│ └── bm25_neg.py
└── logs