Enhancing contextual understanding in large language models through contrastive decoding

Large language models (LLMs) often fail to adequately integrate the input context during text generation and instead rely excessively on the prior knowledge encoded in their parameters, which can produce text that is factually inconsistent or unfaithful to the context. LLMs draw on two primary knowledge sources: 1) prior (parametric) knowledge acquired during pretraining, and 2) contextual (non-parametric) knowledge supplied in the input prompt. This work addresses the open question of how LLMs effectively balance these two sources during generation, focusing on open-domain question answering. To tackle this, we introduce a novel approach that integrates contrastive decoding with adversarial irrelevant passages as negative samples to enhance robust context grounding during generation. Notably, our method operates at inference time and requires no further training. We conduct comprehensive experiments demonstrating its applicability and effectiveness, with empirical evidence that it outperforms existing methods.
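The contrastive step can be illustrated on a single decoding position. The Python sketch below assumes a context-aware-decoding-style combination: next-token logits computed with the relevant (retrieved) passage are scaled by (1 + alpha), and logits computed with an adversarial irrelevant passage are subtracted with weight alpha, so tokens the model would predict regardless of the passage are penalized. The exact formula, the function names, and the idea that this weight corresponds to the --alpha argument of the run script are assumptions for illustration, not the repository's implementation, which lives inside the patched transformers generation code.

```python
import math
from typing import List, Sequence


def contrastive_logits(
    logits_relevant: Sequence[float],
    logits_irrelevant: Sequence[float],
    alpha: float = 0.5,
) -> List[float]:
    """Assumed combination: (1 + alpha) * relevant - alpha * irrelevant."""
    if len(logits_relevant) != len(logits_irrelevant):
        raise ValueError("both logit vectors must cover the same vocabulary")
    return [(1 + alpha) * r - alpha * i for r, i in zip(logits_relevant, logits_irrelevant)]


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn logits into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (greedy decoding choice)."""
    return max(range(len(values)), key=lambda k: values[k])


# Toy 3-token vocabulary. Token 0 scores high with either passage, i.e. it is
# driven by parametric knowledge; token 1 is supported only by the relevant one.
with_relevant = [2.0, 1.5, 0.0]
with_irrelevant = [2.5, 0.5, 0.0]

print(argmax(with_relevant))  # 0 under plain greedy decoding
print(argmax(contrastive_logits(with_relevant, with_irrelevant)))  # 1
```

With alpha = 0 the combination reduces to ordinary decoding on the relevant passage; larger values push harder against predictions that do not depend on the passage.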

Development

First, to create an environment, run the following command using conda:

conda env create -f environment.yml

You will also need an editable install of Hugging Face's transformers library, because the decoding function has to be modified (for example, clone the transformers repository and run pip install -e . from its root).
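To check which copy of transformers your environment actually imports, and therefore which generation/utils.py must be replaced, a small lookup like the one below can help. It is a sketch: the helper name find_generation_utils is hypothetical and not part of this repository, and it only inspects the import path via importlib.util.find_spec without importing the package.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def find_generation_utils(package: str = "transformers") -> Optional[Path]:
    """Return <package>/generation/utils.py for the importable package, or None.

    Hypothetical helper: find_spec locates a top-level package without
    importing it, so this is cheap even for large libraries.
    """
    spec = importlib.util.find_spec(package)
    if spec is None or spec.origin is None:
        return None  # not installed, or a namespace package without __init__.py
    candidate = Path(spec.origin).parent / "generation" / "utils.py"
    return candidate if candidate.is_file() else None
```

Run inside the conda environment, find_generation_utils() should return a path under your clone's src/transformers/generation/; a path inside site-packages suggests the editable install did not take effect.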

Once the library is installed, replace the file src/transformers/generation/utils.py in your local copy of the transformers repository with src/contrastive_decoding/lib/transformers/utils.py from this repository.
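The file swap can also be scripted so that the stock decoding code is kept for later restoration. The sketch below is a hypothetical helper (swap_generation_utils is not part of this repository); it assumes the standard transformers source layout with src/transformers/generation/utils.py and writes a one-time backup named utils.py.orig.

```python
import shutil
from pathlib import Path


def swap_generation_utils(transformers_repo: Path, patched_file: Path) -> Path:
    """Replace src/transformers/generation/utils.py in a transformers clone.

    Keeps a one-time backup next to the original (utils.py.orig) so the stock
    decoding code can be restored. Returns the backup path.
    """
    target = Path(transformers_repo) / "src" / "transformers" / "generation" / "utils.py"
    if not target.is_file():
        raise FileNotFoundError(f"expected {target}; is this a transformers clone?")
    backup = target.with_name(target.name + ".orig")
    if not backup.exists():  # never overwrite the pristine copy on re-runs
        shutil.copy2(target, backup)
    shutil.copy2(patched_file, target)
    return backup
```

For example, swap_generation_utils(Path.home() / "transformers", Path("src/contrastive_decoding/lib/transformers/utils.py")) would patch a clone at ~/transformers (a hypothetical location); copying the .orig file back restores the original decoding code.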

Then, you can start running experiments:

./scripts/run_nq.sh

Alternatively, you can invoke the script directly:

CUDA_VISIBLE_DEVICES=0 python src/contrastive_decoding/run_qa_prompt.py \
 --model_name /home/ec2-user/data/Llama-7b-hf \
 --input_file ./data/nq_test.tsv \
 --eval_method CD \
 --n_examples 5 \
 --ret_path ./data/retrieval/nq_contriever_results.jsonl \
 --bf16 \
 --alpha 0.5 \
 --alias 'nq-alpha-0.5'
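Since the alias encodes the alpha value, it is convenient to repeat the run over several weights. The sketch below rebuilds the command above as an argument list and launches one run per alpha with subprocess. The helper names build_run_command and sweep_alpha are hypothetical, the flags are copied verbatim from the command above, and the alias pattern is inferred from the 'nq-alpha-0.5' example.

```python
import os
import subprocess
from typing import List, Sequence


def build_run_command(
    alpha: float,
    model_name: str,
    input_file: str = "./data/nq_test.tsv",
    ret_path: str = "./data/retrieval/nq_contriever_results.jsonl",
    n_examples: int = 5,
    eval_method: str = "CD",
    bf16: bool = True,
    alias_prefix: str = "nq",
) -> List[str]:
    """Return the argv for run_qa_prompt.py, mirroring the command above."""
    cmd = [
        "python", "src/contrastive_decoding/run_qa_prompt.py",
        "--model_name", model_name,
        "--input_file", input_file,
        "--eval_method", eval_method,
        "--n_examples", str(n_examples),
        "--ret_path", ret_path,
    ]
    if bf16:
        cmd.append("--bf16")
    cmd += ["--alpha", str(alpha), "--alias", f"{alias_prefix}-alpha-{alpha}"]
    return cmd


def sweep_alpha(alphas: Sequence[float], model_name: str, gpu: str = "0") -> None:
    """Run one experiment per alpha on a single GPU, stopping on the first failure."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    for alpha in alphas:
        subprocess.run(build_run_command(alpha, model_name), env=env, check=True)
```

For example, sweep_alpha([0.0, 0.5, 1.0], "/path/to/Llama-7b-hf") would run three experiments sequentially on GPU 0; the model path is a placeholder.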

Here is a brief explanation of the arguments that can be used:

Security

See CONTRIBUTING for more information.

License

This library is licensed under the CC-BY-NC-4.0 License.

Citation

@Inproceedings{Zhao2024,
 author = {Zheng Zhao and Emilio Monti and Jens Lehmann and Haytham Assem},
 title = {Enhancing contextual understanding in large language models through contrastive decoding},
 year = {2024},
 url = {https://www.amazon.science/publications/enhancing-contextual-understanding-in-large-language-models-through-contrastive-decoding},
 booktitle = {NAACL 2024},
}