CoT-Influx

This repository contains the code for CoT-Influx, introduced in our paper "Fewer is More: Boosting LLM Reasoning with Reinforced Context Pruning", accepted to the EMNLP 2024 Main Conference.

<div align=center> <img width=90% src="CoT-Influx.png"/> </div>

🌟 Abstract

Motivated by the observation that adding more concise CoT examples in the prompt can improve LLM reasoning performance, we propose CoT-Influx, which employs a coarse-to-fine pruner to maximize the input of effective and concise CoT examples. The pruner first selects as many crucial CoT examples as possible and then prunes unimportant tokens to fit the context window.
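For intuition only, here is a minimal, self-contained sketch of the coarse-to-fine idea. This is an illustration, not the repository's actual pruner (which is a learned module trained with reinforcement learning): stage one keeps the highest-scoring CoT shots up to a shot budget, and stage two drops the lowest-scoring tokens until the prompt fits a token budget. The scoring functions below are placeholders.

```python
# Illustrative coarse-to-fine pruning sketch (NOT the trained RL pruner
# from this repo): shot-level selection followed by token-level pruning.

def select_shots(shots, shot_scores, max_shots):
    """Coarse stage: keep the highest-scoring CoT examples, up to max_shots."""
    ranked = sorted(zip(shot_scores, shots), key=lambda p: p[0], reverse=True)
    return [shot for _, shot in ranked[:max_shots]]

def prune_tokens(tokens, token_scores, budget):
    """Fine stage: drop the lowest-scoring tokens until at most `budget`
    tokens remain, preserving the original token order."""
    if len(tokens) <= budget:
        return tokens
    keep = sorted(range(len(tokens)), key=lambda i: token_scores[i],
                  reverse=True)[:budget]
    return [tokens[i] for i in sorted(keep)]

# Toy usage with whitespace "tokens" and length-based placeholder scores.
shots = ["Q: 2+2? A: 2+2 = 4. The answer is 4.",
         "Q: 3*5? A: 3*5 = 15. The answer is 15."]
selected = select_shots(shots, shot_scores=[0.9, 0.7], max_shots=2)
prompt_tokens = " ".join(selected).split()
scores = [len(t) for t in prompt_tokens]  # placeholder importance scores
pruned = prune_tokens(prompt_tokens, scores, budget=16)
print(" ".join(pruned))
```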

🌿 Citation

```bibtex
@article{huang2023fewer,
    title={Fewer is More: Boosting LLM Reasoning with Reinforced Context Pruning},
    author={Huang, Xijie and Zhang, Li Lyna and Cheng, Kwang-Ting and Yang, Mao},
    journal={arXiv preprint arXiv:2312.08901},
    year={2023}
}
```

🛠️ Preparation

Requirements

```bash
pip install -r requirements.txt
```

Huggingface Hub Login

```bash
pip install --upgrade huggingface_hub
huggingface-cli login
```
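Alternatively, you can authenticate programmatically with the `huggingface_hub` Python API, which is convenient inside notebooks or scripts:

```python
from huggingface_hub import login

# Prompts for your access token (or pass token="hf_...") and caches the
# credential so gated models such as meta-llama/Llama-2-7b-hf can be loaded.
login()
```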

Preparing the pruner training data and the prompt candidates for evaluation
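As a quick sanity check after this step, you can inspect one of the prepared candidate sets (the evaluation commands below reference files such as `./mrd3/score_revise_difficulty_mrd3.json`). The exact JSON schema depends on the preparation step, so this snippet only reports what is actually in the file rather than assuming field names:

```python
import json

# Peek at a prepared prompt-candidate set produced by the preparation step.
with open("./mrd3/score_revise_difficulty_mrd3.json") as f:
    candidates = json.load(f)

print(f"{len(candidates)} candidate CoT examples")
first = candidates[0] if isinstance(candidates, list) else candidates
if isinstance(first, dict):
    # Field names vary with how the data was prepared; print them as found.
    print("Fields in the first entry:", sorted(first.keys()))
```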

🏃 Run

Pruner training on MRD$^3$

Evaluation on math reasoning datasets

To evaluate the few-shot reasoning performance of LLaMA2-7B with CoT-Influx on GSM8K, run the following command:

```bash
CUDA_VISIBLE_DEVICES=0 python example_retrieval_pruner.py \
--base_model meta-llama/Llama-2-7b-hf \
--pruner_model ./pruner_ckpt/llama2_13b.pth \
--candidate_set ./mrd3/score_revise_difficulty_mrd3.json \
--method few_shot_cot --cot_shot_length 32 --add_16shot \
2>&1 | tee -a ./logs/llama2-7b-gsm8k.log
```

To evaluate LLaMA2-13B with CoT-Influx on GSM8K, run the following command:

```bash
CUDA_VISIBLE_DEVICES=0 python example_retrieval_pruner.py \
--base_model meta-llama/Llama-2-13b-hf \
--pruner_model ./pruner_ckpt/llama2_13b.pth \
--candidate_set ./mrd3/score_increase_reasoning_mrd3.json \
--method few_shot_cot --cot_shot_length 24 --add_16shot \
2>&1 | tee -a ./logs/llama2-13b-gsm8k.log
```

To evaluate LLaMA2-70B with CoT-Influx on GSM8K, run the following command:

```bash
CUDA_VISIBLE_DEVICES=0 python example_retrieval_pruner.py \
--base_model meta-llama/Llama-2-70b-hf \
--pruner_model ./pruner_ckpt/llama2_70b.pth \
--candidate_set ./mrd3/score_add_constraints_mrd3.json \
--method few_shot_cot --cot_shot_length 32 --add_8shot \
2>&1 | tee -a ./logs/llama2-70b-gsm8k.log
```
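The reported metric is exact-match (EM) accuracy on the final numeric answer. For reference, here is a minimal sketch of how such a score is commonly computed for GSM8K; the last-number extraction heuristic below is an assumption and may differ from this repo's own answer-extraction logic:

```python
import re

def extract_answer(text):
    """Return the last number in the text, the common heuristic for GSM8K
    generations that end with a sentence like 'The answer is 42.'"""
    nums = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return nums[-1] if nums else None

def exact_match(predictions, references):
    """Percentage of examples whose extracted final answers agree."""
    hits = sum(extract_answer(p) == extract_answer(r)
               for p, r in zip(predictions, references))
    return 100.0 * hits / len(references)

print(exact_match(["... The answer is 42."], ["#### 42"]))  # 100.0
```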

📚 Results and Logs

| Model | EM (%) on GSM8K | Pruner weights | Evaluation logs |
| --- | --- | --- | --- |
| LLaMA2-7B | 15.85 | llama2_13b.pth | link |
| LLaMA2-13B | 32.22 | llama2_13b.pth | link |
| LLaMA2-70B | 59.59 | llama2_70b.pth | link |

💌 Acknowledgement and Contact

This repo benefits from zero_shot_cot, WizardLM, LLM-Adapters, and OpenICL. Thanks for their wonderful work!

If you have any questions, feel free to contact Xijie HUANG (huangxijie1108 at gmail.com or xhuangbs at connect.ust.hk) and Li Lyna Zhang (lzhani at microsoft.com).