Beyond Autoregression: Fast LLMs via Self-Distillation Through Time

By Justin Deschenaux and Caglar Gulcehre.

Summary

SDTT (Self-Distillation Through Time) distills pre-trained masked diffusion language models (MDLM) into students that need far fewer sampling steps to generate text, enabling much faster generation.
How to run the code?

Install our code

git clone https://github.com/jdeschena/sdtt.git
pushd sdtt
pip install -r requirements.txt
pip install flash-attn
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -e .
popd
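
To verify the installation, a quick import check like the one below should run without errors (a minimal sketch; the CUDA line only matters if you plan to use flash-attn on a GPU):

# Quick post-install sanity check.
import torch
import sdtt  # the package installed above with `pip install -e .`

print("sdtt imported from:", sdtt.__file__)
print("CUDA available:", torch.cuda.is_available())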

Try our pre-trained models

Load the original MDLM (small) weights

from sdtt import load_mdlm_small
mdlm_small = load_mdlm_small()
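
As a quick sanity check on the loaded weights, you can count the parameters (a sketch that assumes the returned object behaves like a standard torch.nn.Module, e.g. exposes .eval() and .parameters(); the snippet above does not guarantee this):

# Hedged sketch: assumes the loaded model is torch.nn.Module-like.
mdlm_small.eval()  # switch to eval mode before sampling
num_params = sum(p.numel() for p in mdlm_small.parameters())
print(f"MDLM (small) parameters: {num_params / 1e6:.1f}M")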

Load the baseline students

from sdtt import load_small_student
student = load_small_student(loss="kld", round=7)  # load the kld student after the last distillation round
student = load_small_student(loss="mse", round=2)  # load the mse student after the second distillation round
student = load_small_student(loss="tvd", round=1)  # load the tvd student after the first distillation round

Load the students from the scaling experiment

from sdtt import load_scaling_student
student = load_scaling_student(size="sm", round=7)  # load small student after the last distillation round
student = load_scaling_student(size="md", round=1)   # load medium student after the first distillation round
student = load_scaling_student(size="large", round=3)  # load large student after the third distillation round

Load the teachers from the scaling experiment

from sdtt import load_scaling_teacher
student = load_scaling_student(size="sm",)  # load small teacher
student = load_scaling_student(size="md",)   # load medium teacher
student = load_scaling_student(size="large",)  # load large teacher

Sample from the pretrained models

from sdtt import load_small_student, load_scaling_student, load_scaling_teacher
import torch

model = load_small_student(loss="kld", round=7)  # load model, see above
model.cuda()  # put model on gpu

# Unconditional generation
tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
)
# Detokenize
uncond_text = model.tokenizer.batch_decode(tokens)

# Conditional generation, based on a prompt
# Prepare a prompt
prompt = "Today is a great day. The sun is shining,"
prompt_tokens = model.tokenizer(prompt)["input_ids"]
prompt_tokens.insert(0, model.tokenizer.bos_token_id)
prompt_tokens = torch.tensor(prompt_tokens, device="cuda")
prompt_len = len(prompt_tokens)

def project_fn(x):
    # Clamp the first prompt_len tokens of every sample to the prompt
    x[:, :prompt_len] = prompt_tokens
    return x  # Don't forget to return

tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
    project_fn=project_fn
)

cond_text = model.tokenizer.batch_decode(tokens)
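
Since project_fn operates on the full batch of token sequences, the same mechanism can also clamp a suffix for simple infilling. The sketch below reuses model, prompt_tokens and prompt_len from the snippet above; the suffix handling is our own illustration, not part of the documented API:

# Hedged sketch: infilling by clamping both a prefix and a suffix.
suffix = " And that is how the story ends."
suffix_tokens = torch.tensor(model.tokenizer(suffix)["input_ids"], device="cuda")
suffix_len = len(suffix_tokens)

def infill_project_fn(x):
    # Clamp the prompt at the start and the suffix at the end of every sample.
    x[:, :prompt_len] = prompt_tokens
    x[:, -suffix_len:] = suffix_tokens
    return x

tokens = model.sample(
    n_samples=4,
    num_steps=256,
    seq_len=1024,
    verbose=True,
    project_fn=infill_project_fn,
)
infill_text = model.tokenizer.batch_decode(tokens)
print(infill_text[0])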

Distill models

Distill the pre-trained MDLM of Sahoo et al.

python src/sdtt/main.py \
    mode=train \
    parameterization.num_distill_steps=2 \
    model=dit-orig-small \
    time_conditioning=False \
    loader.global_batch_size=128 \
    loader.batch_size=32 \
    trainer.max_steps=80000 \
    hydra.run.dir="./outputs/distill_2_steps_from_hf_sm" \
    loader.num_workers=16 \
    compile=False \
    trainer.val_check_interval=5000 \
    data_preprocess.data_cache=./data_cache \
    wandb.project=debug
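
If you prefer to drive runs from Python (for instance to queue several configurations from one script), the same command can be launched with subprocess; this is only a convenience sketch and uses exactly the flags shown above:

import subprocess

# Hedged sketch: launch the distillation command above from Python.
# Flags are identical to the shell command; adjust values as needed.
cmd = [
    "python", "src/sdtt/main.py",
    "mode=train",
    "parameterization.num_distill_steps=2",
    "model=dit-orig-small",
    "time_conditioning=False",
    "loader.global_batch_size=128",
    "loader.batch_size=32",
    "trainer.max_steps=80000",
    "hydra.run.dir=./outputs/distill_2_steps_from_hf_sm",
    "loader.num_workers=16",
    "compile=False",
    "trainer.val_check_interval=5000",
    "data_preprocess.data_cache=./data_cache",
    "wandb.project=debug",
]
subprocess.run(cmd, check=True)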

Distill a model you trained yourself

python src/sdtt/main.py \
    mode=train \
    parameterization.start_from_hf=False \
    model=dit-orig-medium \
    parameterization.checkpoint_path=<REPLACE_BY:path_to_mdlm_code>/outputs/openwebtext/mdlm_md/checkpoints/0-1000000.ckpt \
    parameterization.num_distill_steps=2 \
    time_conditioning=False \
    loader.global_batch_size=128 \
    loader.batch_size=16 \
    trainer.max_steps=80000 \
    hydra.run.dir="./outputs/distill_2_steps_md" \
    loader.num_workers=16 \
    compile=False \
    trainer.val_check_interval=5000 \
    data_preprocess.data_cache=./data_cache \
    wandb.project=debug

Sample from a distilled model (for evaluation)

python src/sdtt/main.py \
    mode=sample \
    parameterization.num_distill_steps=2 \
    parameterization.start_from_hf=False \
    parameterization.sampling.uncond.run=True \
    parameterization.sampling.cond_prefix.run=True \
    parameterization.sampling.uncond.num_steps=2 \
    parameterization.sampling.cond_prefix.num_steps=2 \
    model=dit-orig-medium \
    parameterization.checkpoint_path=<REPLACE_BY:path_to_mdlm_code>/outputs/openwebtext/mdlm_md/checkpoints/0-1000000.ckpt \
    time_conditioning=False \
    loader.global_batch_size=128 \
    loader.batch_size=32 \
    hydra.run.dir="./outputs/distill_2_steps_md" \
    trainer.val_check_interval=5000 \
    data_preprocess.data_cache=./data_cache \
    wandb.project=debug

Run evaluations

python src/sdtt/main.py \
    mode=eval \
    eval.ppl_with_ar.run=True \
    eval.mauve.run=True \
    eval.lambada_openai.run=True \
    hydra.run.dir="./outputs/distill_2_steps_md" \
    data_preprocess.data_cache=./data_cache \
    loader.num_workers=32 \
    compile=True

Code structure

Citation

@article{deschenaux2024autoregressionfastllmsselfdistillation,
  title={Beyond Autoregression: Fast LLMs via Self-Distillation Through Time},
  author={Deschenaux, Justin and Gulcehre, Caglar},
  year={2024},
  eprint={2410.21035},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.21035},
}

Acknowledgements

Our codebase is inspired by recent discrete diffusion language modeling projects, namely MDLM and SEDD.