A Reparameterized Discrete Diffusion Model for Text Generation

This repository contains the official implementation of the paper A Reparameterized Discrete Diffusion Model for Text Generation.

Dependencies

The codebase is implemented with FairSeq. To install the dependencies, run the following commands (preferably within a virtual environment):

pip install -r requirements.txt

# install our package of discrete diffusion models
pip install -e discrete_diffusion

# install our fork of fairseq
cd fairseq
python3 setup.py build develop
cd ..

Note: The environment is tested with Python 3.8.10, PyTorch 1.10.0/1.12.0, and CUDA 11.3. Also note that our fork of fairseq modifies several files in the original codebase, so using more recent versions of fairseq might lead to unexpected dependency conflicts.

Basic Usage of the Discrete-diffusion Library

We implement discrete diffusion models in a self-contained library, discrete_diffusion, for general use. The library provides implementations of several typical discrete diffusion models, including multinomial and absorbing diffusion as well as their reparameterized variants.

<details> <summary> click to check the implementation details as well as their arguments 👇 </summary>

These diffusion models share the same set of interfaces, allowing for external use. In particular, they are defined as subclasses of the DiscreteDiffusion class, which takes the following form:

class DiscreteDiffusion(nn.Module):
    """
    The parent class for discrete denoising diffusion probabilistic models.

    It supports the following methods:
    - q_sample()
        Sample x_t ~ q(x_t | x_0) to construct noisy Transformer inputs.
    - compute_losses()
        Compute the loss L_t = KL(q||p) at the t-th time step.
    - sample_step()
        Sample x_{t-1} ~ p(x_{t-1} | x_t) at the t-th time step.
    """
    
    def __init__(self, num_timesteps):
        super().__init__()
        self.num_timesteps = num_timesteps

    def q_sample(self, x_0, t, **kwargs):
        """

        Sample from q(x_t | x_0); the samples are used as the model inputs.

        Args:
            x_0: token ids with shape [B, N]
            t: current time step, tensor with shape [B]

        Returns:
            return a dict of relevant outputs including x_t.
            
        """

    def compute_losses(self, inputs, **kwargs):
        """
        
        Compute the loss objective KL(q||p) to train our generative process.

        Args:
            inputs: a dict that contains input types specific to different diffusion processes, containing
                - x_t: token ids with shape [B, N]
                - t: scalar timesteps, with shape [B]

        Returns:
            possibly return a dict of relevant outputs, including the loss used for training.
            
        """

    def sample_step(self, decoder_out, denoising_fn, **kwargs):
        """
        Given a time step t, start from x_t and sample x_{t-k} from q(x_{t-k} | x_t).
        
        Args:
            decoder_out: a namedtuple that contains decoding info, including
                - x_t: token ids with shape [B, N]
                - t: scalar timesteps
                - max_steps: the maximum number of decoding steps
                - ...
            
            denoising_fn: a function that takes in x_t and t and returns model logits

            kwargs: other arguments that are used to control decoding.
        
        Returns:
            return a new decoder_out namedtuple.
        """

A DiscreteDiffusion model can be instantiated by configuring its arguments, such as the number of diffusion timesteps (num_timesteps).

See here for a more comprehensive list of arguments.
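
As a rough illustration of how these interfaces fit together, the sketch below wires a DiscreteDiffusion instance into a single training step. The helper name training_step, the keys packed into inputs, the denoiser call signature, and the "loss" key are assumptions for illustration only; consult the library code for the exact signatures.

import torch

def training_step(diffusion, denoiser, x_0):
    """One hypothetical training step with a DiscreteDiffusion instance.

    Args:
        diffusion: an instance of a DiscreteDiffusion subclass
        denoiser: a Transformer mapping (x_t, t) to per-token logits [B, N, V]
        x_0: clean token ids with shape [B, N]
    """
    # Draw a random diffusion step for each sequence in the batch.
    t = torch.randint(0, diffusion.num_timesteps, (x_0.size(0),), device=x_0.device)

    # Corrupt the clean tokens: x_t ~ q(x_t | x_0); returns a dict containing x_t.
    noised = diffusion.q_sample(x_0, t)

    # Run the denoising network on the corrupted inputs.
    logits = denoiser(noised["x_t"], t)

    # Compute the training loss KL(q || p) at step t.
    # The exact keys expected by compute_losses() are an assumption here.
    inputs = {"x_t": noised["x_t"], "t": t, "x_0": x_0, "logits": logits}
    losses = diffusion.compute_losses(inputs)
    return losses["loss"]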

</details>

Decoding Strategies

Vanilla Sampling Scheme

By passing --decoding-strategy default, the vanilla sampling scheme (specific to each discrete diffusion process) is used.

Improved Sampling with Reparameterization

A more advanced decoding approach can be invoked by passing --decoding-strategy reparam-<conditioning-of-v>-<topk_mode>-<schedule>. This approach is based on the reparameterization proposed in our paper and allows for more effective decoding procedures. The placeholders specify the decoding algorithm via the conditioning of v, the top-k mode, and the schedule, respectively.

See the implementation for more details about these options.
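
For intuition, here is a minimal sketch of the reverse-diffusion loop that these decoding strategies plug into: starting from a fully noised sequence, sample_step() is called repeatedly until the token ids converge to the final output. In practice, decoding is driven by the provided generation scripts; the DecoderOut namedtuple below is only a stand-in for the decoding state used inside fairseq, and its field set and the loop structure are assumptions for illustration.

import torch
from collections import namedtuple

# Hypothetical stand-in for the decoding state passed to sample_step().
DecoderOut = namedtuple("DecoderOut", ["x_t", "t", "max_steps"])

@torch.no_grad()
def decode(diffusion, denoising_fn, x_T, num_steps):
    """Run the reverse process from fully noised tokens x_T to the final output."""
    decoder_out = DecoderOut(x_t=x_T, t=diffusion.num_timesteps - 1, max_steps=num_steps)
    for _ in range(num_steps):
        # Each call samples less-noisy tokens given the current x_t and the model
        # logits from denoising_fn, and returns a new decoder_out namedtuple.
        decoder_out = diffusion.sample_step(decoder_out, denoising_fn)
    return decoder_out.x_t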

Machine Translation

Data Preprocessing

Please see the scripts below for details.

Note

IWSLT14 DE-EN

We follow the standard pre-processing in fairseq/examples to prepare the binarized data:

# fetch and preprocess the data to BPE codes
cd examples/translation/
bash prepare-iwslt14.sh
cd ../..

# binarize the data
TEXT=examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --joined-dictionary --source-lang de --target-lang en \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/iwslt14.tokenized.de-en \
    --workers 20

WMT14 EN-DE

We use the data released in fairseq/examples to prepare the dataset:

wget http://dl.fbaipublicfiles.com/nat/original_dataset.zip
unzip original_dataset.zip
TEXT=wmt14_ende
fairseq-preprocess --joined-dictionary \
    --source-lang en --target-lang de \
    --trainpref $TEXT/train.en-de --validpref $TEXT/valid.en-de --testpref $TEXT/test.en-de \
    --destdir data-bin/wmt14_ende --thresholdtgt 0 --thresholdsrc 0 \
    --workers 20

WMT16 EN-RO

For this dataset, we use the raw data wmt16.tar.gz as pre-processed in this repository.

tar xzvf wmt16.tar.gz

TEXT=wmt16/en-ro

# move train/ dev/ test/ bpe codes into the $TEXT folder
mv $TEXT/train/corpus.bpe.en $TEXT/train.bpe.en
mv $TEXT/train/corpus.bpe.ro $TEXT/train.bpe.ro
mv $TEXT/dev/dev.bpe.en $TEXT/dev.bpe.en
mv $TEXT/dev/dev.bpe.ro $TEXT/dev.bpe.ro
mv $TEXT/test/test.bpe.en $TEXT/test.bpe.en
mv $TEXT/test/test.bpe.ro $TEXT/test.bpe.ro

# binarize the data
fairseq-preprocess --joined-dictionary \
    --source-lang en --target-lang ro \
    --trainpref $TEXT/train.bpe --validpref $TEXT/dev.bpe --testpref $TEXT/test.bpe \
    --destdir data-bin/wmt16_enro --thresholdtgt 0 --thresholdsrc 0 \
    --workers 20

Training

We first get into the fairseq folder and then run the following commands to train the models.

######## Training scripts for IWSLT'14, WMT'14, and WMT'16
# first cd to fairseq
# we use 1 GPU for IWSLT'14, 4 GPUs for WMT'14, and 2 GPUs for WMT'16, respectively.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_train.sh -m absorbing -d <iwslt/wmt14/wmt16> -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=1 bash experiments/mt_train.sh -m multinomial -d <iwslt/wmt14/wmt16> -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=2 bash experiments/mt_train.sh -m reparam-absorbing -d <iwslt/wmt14/wmt16> -s default -e True --q-sample-mode coupled  --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=3 bash experiments/mt_train.sh -m reparam-multinomial -d <iwslt/wmt14/wmt16> -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear

Note

Generation & Evaluation

The evaluation pipeline is handled by experiments/mt_generate.sh. The script will generate the translation results and evaluate the BLEU score.

########### IWSLT'14, WMT'14, and WMT'16 datasets
# we recommend putting each checkpoint into a separate folder
# since the script will put the decoded results into a file under the same folder of each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_generate.sh -a false -c <checkpoint_path> -d <iwslt/wmt14/wmt16> 

Arguments:

Trained Model Checkpoints

We also provide the checkpoints of our trained models.

| Dataset | Model | Checkpoint link |
| --- | --- | --- |
| IWSLT'14 | Multinomial | link |
| IWSLT'14 | Absorbing | link |
| IWSLT'14 | Reparam-multinomial | link |
| IWSLT'14 | Reparam-absorbing | link |
| WMT'14 | Multinomial | link |
| WMT'14 | Absorbing | link |
| WMT'14 | Reparam-multinomial | link |
| WMT'14 | Reparam-absorbing | link |
| WMT'16 | Multinomial | link |
| WMT'16 | Absorbing | link |
| WMT'16 | Reparam-multinomial | link |
| WMT'16 | Reparam-absorbing | link |

Question Generation and Paraphrasing Tasks

We follow the experimental setup in DiffuSeq for question generation and paraphrasing tasks.

Data Preprocessing

The raw data of these two tasks can be fetched from the original DiffuSeq repository. We then binarize the data via the provided script.

# put the raw data in the directory ``diffuseq_data/QG``
# Preprocess the question generation dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QG

# put the raw data in the directory ``diffuseq_data/QQP``
# Preprocess the paraphrasing dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QQP

Training

# QQP or QG datasets
# first cd to fairseq
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m absorbing -d <qqp/qg> -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m multinomial -d <qqp/qg> -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m reparam-multinomial -d <qqp/qg> -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m reparam-absorbing -d <qqp/qg> -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear

Generation & Evaluation

We closely follow the generation and evaluation protocols of DiffuSeq to ensure a head-to-head comparison. The whole pipeline is re-implemented in fairseq/diffusion_mt/scripts/decode_diffuseq.py and fairseq/diffusion_mt/scripts/eval_diffuseq.py, respectively, to be compatible with Fairseq. Run the following commands:

# we recommend putting each checkpoint into a separate folder
# since the script will put the decoded results into a file under the same folder of each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/diffuseq_generate.sh -a false -b true -c <checkpoint_path> -d <qqp/qg> 

Arguments:

Trained Model Checkpoints

We also provide the checkpoints of our trained models.

| Dataset | Model | Checkpoint link |
| --- | --- | --- |
| QG | Multinomial | link |
| QG | Absorbing | link |
| QG | Reparam-multinomial | link |
| QG | Reparam-absorbing | link |
| QQP | Multinomial | link |
| QQP | Absorbing | link |
| QQP | Reparam-multinomial | link |
| QQP | Reparam-absorbing | link |

Citation

@article{zheng2023rdm,
  title={A Reparameterized Discrete Diffusion Model for Text Generation},
  author={Zheng, Lin and Yuan, Jianbo and Yu, Lei and Kong, Lingpeng},
  journal={arXiv preprint arXiv:2302.05737},
  year={2023}
}