Home

Awesome

RegaVAE

This is the official repo for our paper:

RegaVAE: A Retrieval-Augmented Gaussian Mixture Variational Auto-Encoder for Language Modeling

Model Architecture

Architecture of RegaVAE. Based on the training data, we first train a VAE to construct a compact latent space, which ensures that the latent variable z contains both current and future information (see § 3.1 of the paper). We then build a retrieval database and then aggregate the retrieved information into the generator (see § 3.2 of the paper). VAE Encoder and Decoder parameters are the same in all steps. In order to ensure fairness, the Corpus data and the Source data in the training set are the same. $G$ represents the Gaussian mixture distribution, and $π$ is the corresponding parameter.

Datasets

Download three dataset from this link. Unzip them and put them under the data directory.

Step1

Firstly,

cd Step1

Training

For Yelp dataset,

python main.py --train_file ../data/yelp/yelp.train.txt \
--valid_file ../data/yelp/yelp.valid.txt \
--per_gpu_train_batch_size 4 \
--cycle_annealing

For Yahoo dataset,

python main.py --train_file ../data/yahoo/yahoo.train.txt \
--valid_file ../data/yahoo/yahoo.valid.txt \
--per_gpu_train_batch_size 4 \
--cycle_annealing

For WP dataset,

python main.py --train_source_path ../data/writingPrompts/train.wp_source \
--train_target_path ../data/writingPrompts/train.wp_target \
--valid_source_path ../data/writingPrompts/valid.wp_source \
--valid_target_path ../data/writingPrompts/valid.wp_target \
--dataset_type wp \
--per_gpu_train_batch_size 4 \
--cycle_annealing

The above are only the best adjusted hyperparameters. You can get a better Step1 model by passing other parameters. The model we trained is available at this link.

Step2

Firstly,

cd Step2

Step2 here corresponds to Step2 and Step3 in the figure. Before training, please rename the model trained in Step 1 to model_epoch_-1.pth and add it to the model generation path. In addition, please download the file in this link to the Step2 folder.

Training

For Yelp dataset,

python main.py --train_file ../data/yelp/yelp.train.txt \
--valid_file ../data/yelp/yelp.valid.txt \
--per_gpu_train_batch_size 4 \
--load_epoch -1 \
--cycle_annealing

For Yahoo dataset,

python main.py --train_file ../data/yahoo/yahoo.train.txt \
--valid_file ../data/yahoo/yahoo.valid.txt \
--per_gpu_train_batch_size 4 \
--load_epoch -1 \
--cycle_annealing

Test

For Yelp dataset,

python main.py --train_file ../data/yelp/yelp.train.txt \
--valid_file ../data/yelp/yelp.valid.txt \
--per_gpu_train_batch_size 4 \
--load_epoch -1 \
--cycle_annealing \
--eval \
--eval_metrics

For Yahoo dataset,

python main.py --train_file ../data/yahoo/yahoo.train.txt \
--valid_file ../data/yahoo/yahoo.valid.txt \
--per_gpu_train_batch_size 4 \
--load_epoch -1 \
--cycle_annealing \
--eval \
--eval_metrics

Generation

For Yelp dataset,

python main.py --generation \
--test_file ../data/yelp/yelp.test.txt \
 --load_epoch -1 \
--top_k 50 \
--top_p 0.9