# Hierarchical Context Merging: Better Long Context Understanding for Pre-trained LLMs (ICLR 2024)
Official implementation of "Hierarchical Context Merging: Better Long Context Understanding for Pre-trained LLMs" by Woomin Song, Seunghyuk Oh, Sangwoo Mo, Jaehyung Kim, Sukmin Yun, Jung-Woo Ha, and Jinwoo Shin.

**TL;DR:** We propose a memory-efficient method to extend the context limit of large language models.
<p align="center"> <img src="figure/concept_figure.png" width="900"> </p>

## 1. Dependencies
```bash
conda create -n homer python=3.10 -y
conda activate homer
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install accelerate==0.32.1 matplotlib==3.9.1 sentencepiece==0.2.0 tqdm==4.66.4 transformers==4.42.4 protobuf==5.27.2
pip install flash-attn==2.6.1 --no-build-isolation
```
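After installation, a quick sanity check such as the following (a minimal sketch, not part of this repository) confirms that CUDA, Transformers, and FlashAttention are visible to PyTorch:

```python
# Sanity-check sketch (not in the repo): verify the HOMER environment.
import torch
import transformers

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)

try:
    import flash_attn
    print("flash-attn:", flash_attn.__version__)
except ImportError:
    print("flash-attn is not installed; the flash_attention_2 path will not work.")
```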
## 2. Resources
The following data files are provided here:
- Passkey retrieval data for 4k / 8k / 16k / 32k context lengths
- Example long document from PG19 test set
The bias files for calibrating Llama-2 models can be found here.
## 3. Usage
### Model loading
Use the `LlamaForCausalLM` class provided in `src/homer/modeling_llama.py`.
```python
import torch

from homer.modeling_llama import LlamaForCausalLM

# Setup arguments for HOMER
homer_args = {
    "max_chunk_len": 2048,
    "target_len": 2048,
    "layers_warmup": 12,
    "bias_path": "/path/to/bias_file.pt",
}

# Load model
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    homer_args=homer_args,
)
```
The arguments in `homer_args` are as follows.
- `max_chunk_len`: Maximum length for a single chunk. Typically set to 1/2 of the original model's context limit (e.g., 2048 for plain Llama-2, 4096 for YaRN with scale factor 2); see the sketch after this list.
- `target_len`: Length of the resulting KV-cache after applying HOMER. Typically set to the same value as `max_chunk_len`.
- `layers_warmup`: Number of warmup layers, in which merging does not happen.
- `bias_path`: Path to the bias file used for calibration.
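The rule of thumb above can be captured in a small helper. This is only a sketch: `make_homer_args` is not part of this repository, and the exact values should follow the model and YaRN configuration you actually use.

```python
# Hypothetical helper (not in the repo): pick homer_args following the rule of
# thumb above, i.e. max_chunk_len = half the (scaled) context limit and
# target_len = max_chunk_len.
def make_homer_args(context_limit=4096, yarn_scale=1, layers_warmup=12, bias_path=None):
    max_chunk_len = (context_limit * yarn_scale) // 2
    return {
        "max_chunk_len": max_chunk_len,   # e.g., 2048 for plain Llama-2
        "target_len": max_chunk_len,      # length of the compressed KV-cache
        "layers_warmup": layers_warmup,
        "bias_path": bias_path,
    }

# YaRN with scale factor 2 on a 4096-token Llama-2 -> max_chunk_len = 4096
homer_args = make_homer_args(context_limit=4096, yarn_scale=2, bias_path="/path/to/bias_file.pt")
```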
### Inference and Generation
Use the `create_homer_prefix()` method to perform hierarchical merging and create the `homer_prefix`, which is the compact KV-cache obtained after applying HOMER. To perform inference conditioned on the compressed prompt, forward the remaining inputs (i.e., inputs without the prompt) along with the `homer_prefix`. See `src/perplexity.py` for more concrete examples.
```python
homer_prefix = model.create_homer_prefix(
    prefix_ids, context_ids, suffix_ids
)
output = model(input_ids, homer_prefix=homer_prefix)
```
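For concreteness, the sketch below shows one way the token ids might be prepared with a standard Hugging Face tokenizer. The prefix/context/suffix split and the prompt text are illustrative assumptions only; `src/perplexity.py` remains the authoritative example.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
enc = lambda text: tokenizer(text, add_special_tokens=False, return_tensors="pt").input_ids.to(model.device)

# Hypothetical split of a long prompt: instruction prefix, long document, task suffix.
long_document = open("/path/to/pg19_test_sample.txt").read()
prefix_ids = tokenizer("Read the document below.\n", return_tensors="pt").input_ids.to(model.device)
context_ids = enc(long_document)
suffix_ids = enc("\nSummarize the document.")

# Hierarchical merging produces a compact KV-cache for the long prompt.
homer_prefix = model.create_homer_prefix(prefix_ids, context_ids, suffix_ids)

# Forward only the remaining tokens, conditioned on the compressed prompt.
input_ids = enc("Summary:")
output = model(input_ids, homer_prefix=homer_prefix)
```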
To perform generation, pass the `homer_prefix` to the conventional `generate()` method. See `src/passkey.py` for more concrete examples.
```python
homer_prefix = model.create_homer_prefix(
    prefix_ids, context_ids, suffix_ids
)
output = model.generate(homer_prefix=homer_prefix)
```
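Continuing the sketch above, and assuming `generate()` still accepts the usual Hugging Face generation arguments alongside `homer_prefix` (see `src/passkey.py` for the reference usage), the result can be decoded as usual:

```python
# Sketch: bound the generation length and decode the generated tokens.
output = model.generate(homer_prefix=homer_prefix, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```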
### Alternative method for model loading
We also provide a `patch()` function that patches an existing HuggingFace `LlamaForCausalLM` object to support HOMER.
```python
import torch

from transformers import LlamaForCausalLM

from homer.patch_llama import patch as patch_llama_for_homer

# Load model
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)

# Patch for HOMER
patch_llama_for_homer(
    model,
    max_chunk_len=2048,
    target_len=2048,
    layers_warmup=12,
    bias_path="/path/to/bias_file.pt",
)
```
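After patching, the model is expected to expose the same HOMER interface as before (an assumption based on the description above, not a documented guarantee):

```python
# Sketch: use the patched model exactly as in the earlier examples.
homer_prefix = model.create_homer_prefix(prefix_ids, context_ids, suffix_ids)
output = model.generate(homer_prefix=homer_prefix)
```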
## 4. Language Modeling
```bash
# Plain Llama
python src/perplexity.py --data_path /path/to/pg19_test_sample.txt --model_path meta-llama/Llama-2-7b-hf --model_type plain

# Plain YaRN
python src/perplexity.py --data_path /path/to/pg19_test_sample.txt --model_path meta-llama/Llama-2-7b-hf --model_type yarn --scale 2

# HOMER
python src/perplexity.py --data_path /path/to/pg19_test_sample.txt --model_path meta-llama/Llama-2-7b-hf --model_type homer --bias_path /path/to/7b_homer.pt

# HOMER + YaRN
python src/perplexity.py --data_path /path/to/pg19_test_sample.txt --model_path meta-llama/Llama-2-7b-hf --model_type homer_yarn --scale 2 --bias_path /path/to/7b_homer_yarn_scale_2.pt
```
## 5. Passkey Retrieval
```bash
# Plain Llama
python src/passkey.py --data_path /path/to/passkey_8192_tokens.jsonl --model_path meta-llama/Llama-2-7b-chat-hf --model_type plain

# Plain YaRN
python src/passkey.py --data_path /path/to/passkey_8192_tokens.jsonl --model_path meta-llama/Llama-2-7b-chat-hf --model_type yarn --scale 2

# HOMER
python src/passkey.py --data_path /path/to/passkey_8192_tokens.jsonl --model_path meta-llama/Llama-2-7b-chat-hf --model_type homer --bias_path /path/to/7b_homer_chat.pt

# HOMER + YaRN
python src/passkey.py --data_path /path/to/passkey_8192_tokens.jsonl --model_path meta-llama/Llama-2-7b-chat-hf --model_type homer_yarn --scale 2 --bias_path /path/to/7b_homer_yarn_chat_scale_2.pt
```
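To sweep over the provided 4k / 8k / 16k / 32k passkey sets, a loop such as the following can help. This is a sketch: it assumes the data files follow the `passkey_<tokens>_tokens.jsonl` naming of the 8k example above.

```bash
# Sketch: run HOMER passkey retrieval over all provided context lengths.
for LEN in 4096 8192 16384 32768; do
  python src/passkey.py \
    --data_path /path/to/passkey_${LEN}_tokens.jsonl \
    --model_path meta-llama/Llama-2-7b-chat-hf \
    --model_type homer \
    --bias_path /path/to/7b_homer_chat.pt
done
```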
## 6. Citation
```bibtex
@article{song2024hierarchical,
  title={Hierarchical context merging: Better long context understanding for pre-trained LLMs},
  author={Song, Woomin and Oh, Seunghyuk and Mo, Sangwoo and Kim, Jaehyung and Yun, Sukmin and Ha, Jung-Woo and Shin, Jinwoo},
  journal={ICLR},
  year={2024}
}
```