Distilling Chain-of-Thought Reasoning from code-davinci-002 to FlanT5

Implementation of Yao Fu, Hao Peng, Litu Ou, Ashish Sabharwal, Tushar Khot. Specializing Smaller Language Models towards Multi-Step Reasoning. ICML 2023. [Arxiv]

Download the data from Google Drive

After downloading the data, put it under the processed_data/ folder; all data are pre-processed and stored as .pkl files.
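
Loading one of the processed files looks roughly like the sketch below (the file name is a placeholder; substitute any of the downloaded .pkl files):

import pickle

# Placeholder file name -- use any of the downloaded .pkl files.
with open('processed_data/example_split.pkl', 'rb') as f:
    data = pickle.load(f)

print(type(data))
print(data)  # inspect the structure; notebooks/inspect_processed_data.ipynb does this more thoroughly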

Much of the engineering effort in this work is not modeling but data engineering, mostly processing the data into four formats (combinations of in-context vs. zero-shot prompts with chain-of-thought vs. answer-only outputs) that are important for imbuing the model with both in-context and zero-shot abilities. See Figure 1B in the paper for details.
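
Roughly, the four formats pair in-context vs. zero-shot inputs with chain-of-thought vs. answer-only targets. The placeholders below are only a schematic sketch of that layout, not the repo's actual prompt templates:

# Schematic layout of the four data formats (placeholders, not the actual templates).
in_context_cot = {
    'input':  '<few-shot demos with reasoning chains>\nQ: <question>\nA:',
    'target': '<reasoning chain> The answer is <answer>.'}
in_context_answer_only = {
    'input':  '<few-shot demos with answers only>\nQ: <question>\nA:',
    'target': 'The answer is <answer>.'}
zero_shot_cot = {
    'input':  'Q: <question>\nA:',
    'target': '<reasoning chain> The answer is <answer>.'}
zero_shot_answer_only = {
    'input':  'Q: <question>\nA:',
    'target': 'The answer is <answer>.'}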

We strongly recommend running notebooks/inspect_processed_data.ipynb to get a sense of what the data looks like. It gives an example of what the in-context chain-of-thought data looks like.

The actual training script, train_distill_simple.py, is pretty simple. Most of the effort goes into data engineering, hyperparameter search, and evaluation.
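
For orientation only, the core of such a script boils down to standard sequence-to-sequence fine-tuning on the teacher-generated data. The sketch below is a simplified stand-in (the .pkl file name and the 'input'/'target' field names are assumptions), not the actual train_distill_simple.py:

import pickle
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def batches(examples, batch_size):
    # Simple sequential batching over the processed examples.
    for i in range(0, len(examples), batch_size):
        yield examples[i:i + batch_size]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')
model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base').to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

with open('processed_data/example_split.pkl', 'rb') as f:  # placeholder file name
    data = pickle.load(f)

model.train()
for batch in batches(data, batch_size=8):
    inputs = tokenizer([ex['input'] for ex in batch], return_tensors='pt',
                       padding=True, truncation=True).to(device)
    labels = tokenizer([ex['target'] for ex in batch], return_tensors='pt',
                       padding=True, truncation=True).input_ids.to(device)
    labels[labels == tokenizer.pad_token_id] = -100  # do not compute loss on padding
    loss = model(**inputs, labels=labels).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()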

The following is a quickstart using the FlanT5 base model. We did not have time to implement DeepSpeed / FairScale / PyTorch FSDP because we were in a rush when developing this work, yet wrapping the model with DeepSpeed should be pretty straightforward (a rough sketch is included after the quickstart below). If you have done this, please submit a pull request and we will be happy to merge it :)

Quickstart:

pip install -r requirements.txt

# inspect data 
# see notebooks/inspect_processed_data.ipynb

# run a small model 
model_version=0.0.5.0 # base model: FlanT5 base (250M)
nohup python -u train_distill_simple.py\
    model_version=${model_version}\
    gpu_id=\'0\'\
    base_model=\'google/flan-t5-base\'\
    batch_size=250\
    grad_accum_steps=3\
    save_per_step=1000\
    log_interval=2\
    lr=0.0005\
    &> logs/beta_${model_version}.log &
tail -f logs/beta_${model_version}.log
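
If you do wrap the model with DeepSpeed, the integration could look roughly like the sketch below. This is an assumption about one possible way to do it, not code that ships with this repo, and the ds_config values are placeholders to tune for your hardware:

import deepspeed
from transformers import AutoModelForSeq2SeqLM

# Placeholder DeepSpeed config -- adjust stage, batch size, and precision for your setup.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 3,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "AdamW", "params": {"lr": 5e-4}},
}

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

# Inside the training loop, replace loss.backward() / optimizer.step() with:
#   loss = model_engine(**inputs, labels=labels).loss
#   model_engine.backward(loss)
#   model_engine.step()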

Notebooks for inspecting the processed data

Notebooks for visualization

Notebooks for prompting FlanT5

Scripts

Distillation

TODO: