<p align="center" width="100%"> <img src="fig/Chunkllama.png" alt="chunkllama" style="width: 90%; min-width: 300px; display: block; margin: auto;"> </p>

Training-Free Long-Context Scaling of Large Language Models

Overview

Dual chunk attention (DCA) is a training-free and effective method for extending the context window of large language models (LLMs) to more than 8 times their original pre-training length. We refer to the Llama-based model with DCA as ChunkLlama. DCA can be seamlessly integrated with (1) popular extrapolation methods such as Positional Interpolation (PI), NTK-Aware RoPE, and YaRN; and (2) widely used libraries for memory-efficient inference such as FlashAttention and vLLM.
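To make the idea concrete, here is a minimal NumPy sketch of a DCA-style relative-position matrix, based on our reading of the three components described in the paper (intra-chunk, successive-chunk, and inter-chunk attention). It is not the ChunkLlama or vLLM implementation: the function name and parameters are illustrative, and the sketch assumes `chunk_size + local_size <= pretrain_len` so that no relative distance exceeds what the model saw during pretraining.

```python
import numpy as np


def dca_relative_positions(seq_len, chunk_size, local_size, pretrain_len):
    """Illustrative DCA-style relative-position matrix M[i, j] (query i, key j).

    Assumes chunk_size + local_size <= pretrain_len. Entries above the
    diagonal (future keys) are set to -1 to mark the causal mask.
    """
    idx = np.arange(seq_len)
    chunk = idx // chunk_size                      # chunk id of every token
    p_k = idx % chunk_size                         # key positions restart in each chunk
    p_q_intra = p_k                                # same chunk: ordinary positions
    p_q_inter = np.full(seq_len, pretrain_len - 1) # distant chunks: largest seen position
    # adjacent (successive) chunk: the first local_size queries of a chunk keep
    # counting past chunk_size to preserve locality, the rest use pretrain_len - 1
    p_q_succ = np.where(p_k < local_size, chunk_size + p_k, pretrain_len - 1)

    dist = chunk[:, None] - chunk[None, :]         # chunk distance query -> key
    m = np.where(dist == 0, p_q_intra[:, None] - p_k[None, :],
                 np.where(dist == 1, p_q_succ[:, None] - p_k[None, :],
                          p_q_inter[:, None] - p_k[None, :]))
    return np.where(idx[:, None] >= idx[None, :], m, -1)


M = dca_relative_positions(seq_len=20, chunk_size=4, local_size=2, pretrain_len=8)
print(M.max())  # bounded by pretrain_len - 1 no matter how long seq_len is
```

Every causal entry stays in `[0, pretrain_len - 1]` regardless of `seq_len`, which is the property that lets a DCA-style scheme avoid position indices the model never saw in training.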

Due to the high cost of continual pretraining on longer sequences, previously released long-context models are typically limited to scales of 7B/13B. We demonstrate that by applying DCA to Llama-2/3 70B, the model exhibits surprising extrapolation capabilities (100k context length) and a very strong understanding of practical long-context tasks.

Updates

# step 0: editable installation
cd vllm && pip install -e .

# step 1: add the following fields to your model's config.json (remove the // comments, which are not valid JSON):
{
    "architectures": [
        "LlamaForCausalLM"
    ],
    // ...
    "max_position_embeddings": 131072, // extrapolation length
    "dual_chunk_attention_config": {
        "chunk_size": 8192, // training length (32768 for qwen2)
        "local_size": 512,
        "original_max_position_embeddings": 8192 // training length
    }
}

# step 2: vLLM inference
from vllm import LLM, SamplingParams

model_path = "/path/to/llama3-8b-instruct"
llm = LLM(model=model_path, tensor_parallel_size=1, enforce_eager=True, enable_chunked_prefill=False, max_num_batched_tokens=131072)

passkey_1 = "123456"
passkey_2 = "654321"
prompt1 = f"There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\nThe pass key is {passkey_1}. Remember it. {passkey_1} is the pass key.\n " + \
    "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 4000 + \
    "\nWhat is the pass key?\nThe passkey is "  # The prompt is about 100k tokens long. Increase the repetition count for a longer prompt.
prompt2 = prompt1.replace(passkey_1, passkey_2)
prompts = [prompt1, prompt2]

sampling_params = SamplingParams(top_p=0.8, temperature=0.7, repetition_penalty=1.05, top_k=10, max_tokens=100)
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")
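If you prefer to script step 1 instead of editing config.json by hand, a small helper like the one below (hypothetical, not part of vLLM or this repository) writes the same fields shown in the example config. It defaults `chunk_size` to the training length, as that example does; for Qwen2 the config comment suggests 32768.

```python
import json
from pathlib import Path


def add_dca_config(config_path, pretraining_length, extrapolation_length,
                   local_size=512, chunk_size=None):
    """Hypothetical helper: add the step-1 DCA fields to a model's config.json."""
    path = Path(config_path)
    config = json.loads(path.read_text())
    config["max_position_embeddings"] = extrapolation_length   # extrapolation length
    config["dual_chunk_attention_config"] = {
        # the example config uses the training length as the chunk size
        "chunk_size": chunk_size if chunk_size is not None else pretraining_length,
        "local_size": local_size,
        "original_max_position_embeddings": pretraining_length,  # training length
    }
    path.write_text(json.dumps(config, indent=4))  # plain JSON, no comments
    return config
```

For example, `add_dca_config("/path/to/llama3-8b-instruct/config.json", 8192, 131072)` reproduces the Llama3 values from step 1. Keep a backup of the original file, since the helper overwrites it in place.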

Usage for standard self-attention:
from flash_decoding_llama import replace_with_flashdecoding
replace_with_flashdecoding(max_prompt_length) # max_prompt_length is the maximum input length, e.g. 131072

We suggest using flash_attn>=2.5.3,<2.6.0.
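Flash decoding speeds up decoding over a long KV cache by splitting the keys and values into blocks, attending to each block independently, and then merging the partial results with log-sum-exp rescaling. The NumPy sketch below illustrates only that merge rule for a single query; it is a conceptual illustration of the general technique, not the kernel behind `flash_decoding_llama`.

```python
import numpy as np


def full_attention(q, k, v):
    """Reference softmax attention for one query vector q over keys k and values v."""
    scores = k @ q / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max())
    return (w / w.sum()) @ v


def split_kv_attention(q, k, v, num_splits):
    """Same result, computed block by block over the KV cache and then merged."""
    partial_out, partial_lse = [], []
    for k_blk, v_blk in zip(np.array_split(k, num_splits), np.array_split(v, num_splits)):
        scores = k_blk @ q / np.sqrt(q.shape[-1])
        m = scores.max()
        w = np.exp(scores - m)
        partial_out.append((w / w.sum()) @ v_blk)   # output normalized within the block
        partial_lse.append(m + np.log(w.sum()))     # log-sum-exp of the block's scores
    lse = np.array(partial_lse)
    weights = np.exp(lse - lse.max())
    weights /= weights.sum()                        # each block's share of the softmax mass
    return sum(wt * out for wt, out in zip(weights, partial_out))
```

Because each block only needs its own normalized output and its log-sum-exp, the blocks can be processed in parallel and combined at the end without ever materializing the full score vector in one place.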

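PPL on the PG19 validation set:

| Model | 4k | 8k | 16k | 32k | 64k | 96k | 128k | 160k |
|---|---|---|---|---|---|---|---|---|
| ChunkLlama3-8b | 9.04 | 8.71 | 8.61 | 8.62 | 8.95 | 9.43 | 10.04 | 10.66 |
| ChunkLlama3-70b | 5.36 | 5.16 | 5.14 | 5.14 | 5.21 | 5.32 | 5.40 | 5.45 |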

ChunkLlama3-8b achieves 100% retrieval accuracy across all document depths. Our few-shot results with base models and zero-shot results with chat models show that ChunkLlama3-70b performs on par with GPT-4 (2023/06/13) and Llama2 Long 70b (see the ChunkLlama3 section below for detailed results).

<p align="center" width="100%"> <img src="fig/merge_needle_mistral.png" alt="mistral_needle" style="width: 80%; min-width: 300px; display: block; margin: auto;"> </p>

🚀 Quick Start

Because DCA is training-free, you only need to add one line to your original inference code for Llama2:

# `transformers==4.37.2`
from chunkllama_attn_replace import replace_with_chunkllama
# flash decoding: from flash_decoding_chunkllama import replace_with_chunkllama
replace_with_chunkllama(pretraining_length=4096)  # pretraining_length=8192 if you are using Llama3

For other foundation models:

from chunkllama_attn_replace import replace_with_chunkmistral, replace_with_chunkmixtral
from chunkqwen_attn_replace import replace_with_chunkqwen

replace_with_chunkmistral(pretraining_length=32768) # Mistral-v0.2
replace_with_chunkmixtral(pretraining_length=32768) # Mixtral MOE model
replace_with_chunkqwen(pretraining_length=32768) # Qwen 1.5
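The `pretraining_length` argument must match each model's original context length. The hypothetical helper below only collects the values given in this section (plus the Llama3 value from the comment above) so they are not mistyped; any model family outside this table should be passed explicitly.

```python
# Pretraining lengths as listed in this README (hypothetical lookup table).
PRETRAINING_LENGTH = {
    "llama2": 4096,
    "llama3": 8192,
    "mistral-v0.2": 32768,
    "mixtral": 32768,
    "qwen1.5": 32768,
}


def pretraining_length_for(model_family):
    """Return the pretraining length for a known family name (case-insensitive)."""
    key = model_family.lower()
    if key not in PRETRAINING_LENGTH:
        raise KeyError(f"unknown model family {model_family!r}; pass pretraining_length explicitly")
    return PRETRAINING_LENGTH[key]
```

Usage: `replace_with_chunkllama(pretraining_length=pretraining_length_for("llama3"))`.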

Full inference code

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from flash_decoding_chunkllama import replace_with_chunkllama
# without flash decoding: from chunkllama_attn_replace import replace_with_chunkllama

##### add this line #####
replace_with_chunkllama(pretraining_length=4096)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", attn_implementation="flash_attention_2", trust_remote_code=True, torch_dtype=torch.bfloat16).to("cuda")
inputs = tokenizer("Long...docs\n Q: How to extend the context window of LLMs? ", return_tensors="pt").to("cuda")
prompt_length = inputs.input_ids.size()[-1]
output_ids = model.generate(**inputs, max_new_tokens=64)[0]
# decode only the newly generated tokens
print(tokenizer.decode(output_ids[prompt_length:]))

Chat with a lengthy PDF file

We have provided a collection of influential papers on long-context scaling of LLMs in the Popular_PDFs directory. By passing one of them with the --pdf parameter, you can ask ChunkLlama⭐ about the latest advances in this field.

<p align = "center"> <img src="fig/sample_chain.gif" width="95%" alt="examples" align=center loop=infinite/> </p> All of these papers were released recently, so they could not have been used during pretraining.

Usage Requirements

  1. Prepare the environment for transformers + FlashAttention-2 (FlashAttention >= 2.5.0):

         pip install -r requirements.txt
         pip install flash-attn --no-build-isolation

  2. Download the pretrained weights ("Extended ctx" is the context length enabled by DCA):

     | Supported Models | Extended ctx |
     |---|---|
     | Base Models | |
     | Llama-2-7b-hf (4k) | 32k |
     | Llama-2-13b-hf (4k) | 32k |
     | Llama-2-70b-hf (4k) | 128k |
     | Meta-Llama-3-8B (8k) | 96k |
     | Meta-Llama-3-70B (8k) | 200k+ |
     | Together's LLaMA-2-7b-32k | 200k |
     | SFT Models | |
     | Llama-2-7b-chat-hf (4k) | 32k |
     | Llama-2-13b-chat-hf (4k) | 32k |
     | Llama-2-70b-chat-hf (4k) | 128k |
     | Meta-Llama-3-8B-Instruct (8k) | 96k |
     | Meta-Llama-3-70B-Instruct (8k) | 200k+ |
     | Vicuna-1.5-7b-16k | 200k |
     | Vicuna-1.5-13b-16k | 200k |
     | Mixtral 8x7b & Mistral 7b | 200k+ |
     | Qwen1.5 (Chinese) | 200k |

  3. Deploy your own demo. We provide three examples of how to apply DCA to popular LLMs: run_chunkllama_100k.py, run_together_200k.py, and run_vicuna_200k.py.

Run the demo:

python run_chunkllama_100k.py --max_length 16000 --scale 13b --pdf Popular_PDFs/longlora.pdf  # --scale: 7b/13b/70b

If you run out of GPU memory (OOM) with longer inputs or larger models, we recommend using tensor parallelism:

deepspeed run_chunkllama_100k_ds.py --max_length 64000 --scale 13b --pdf Popular_PDFs/longlora.pdf  # --scale: 7b/13b/70b
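The demo's --max_length caps the input size. When a document is longer than that, one common strategy (an assumption here; the demo scripts may truncate differently) is to drop tokens from the middle so that both the beginning of the document and the question at the end survive. A minimal, hypothetical helper over a list of token ids:

```python
def truncate_middle(token_ids, max_length):
    """Hypothetical helper: keep the head and tail of token_ids within max_length."""
    if max_length < 1:
        raise ValueError("max_length must be positive")
    if len(token_ids) <= max_length:
        return list(token_ids)
    head = max_length // 2          # tokens kept from the start
    tail = max_length - head        # tokens kept from the end (always >= 1 here)
    return list(token_ids[:head]) + list(token_ids[-tail:])
```

Usage with a Hugging Face tokenizer: `input_ids = truncate_middle(tokenizer(text).input_ids, 16000)`.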

📌 Notice: We have found that although 7B models can achieve low perplexity on long contexts, they often make mistakes in practical tasks, even in their fine-tuned versions. We therefore recommend the larger 13B (ChunkLlama-13b, Chunk-Vicuna-13b) or 70B (ChunkLlama-70B) models for higher accuracy.

Fine-tuning

ChunkLlama can be further improved by fine-tuning on long conversations. We further train ChunkLlama with a context window of 16k on concatenated dialogues from the previous SFT datasets ShareGPT and AlpacaGPT4. The data we use is available here.
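Concatenating dialogues means packing several short conversations into one training sequence so that each sample fills most of the 16k window. The greedy packer below is a hypothetical sketch over already-tokenized dialogues, not the code in train_chunkllama_16k.py.

```python
def pack_dialogues(dialogues, max_len=16384):
    """Greedily concatenate tokenized dialogues into sequences of at most max_len tokens."""
    packed, current = [], []
    for ids in dialogues:
        ids = list(ids)[:max_len]                 # clip a single over-long dialogue
        if current and len(current) + len(ids) > max_len:
            packed.append(current)                # current sequence is full, start a new one
            current = []
        current.extend(ids)
    if current:
        packed.append(current)
    return packed
```

Greedy packing keeps dialogues in their original order; a real pipeline would also need to decide how attention and loss masks treat the boundaries between packed dialogues.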

cd fine-tune
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export WANDB_MODE=dryrun

python -m torch.distributed.run --nproc_per_node=8 \
        train_chunkllama_16k.py \
        --model_name_or_path meta-llama/llama-2-7b-chat-hf \
        --bf16 \
        --output_dir checkpoints/chunkllama-7b-release \
        --max_steps 1600 \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --gradient_accumulation_steps 2 \
        --evaluation_strategy no \
        --save_strategy steps \
        --save_steps 400 \
        --save_total_limit 2 \
        --learning_rate 2e-5 \
        --weight_decay 0. \
        --warmup_ratio 0.03 \
        --lr_scheduler_type "cosine" \
        --logging_steps 1 \
        --fsdp "full_shard auto_wrap" \
        --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
        --tf32 True \
        --model_max_length 16384 \
        --gradient_checkpointing True \
        --lazy_preprocess True \
        --pretraining_length 4096

You can change --model_name_or_path and --output_dir to your own paths. In our experiments, we directly train the chat version of Llama2; you can also start from its base version.
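The flags in the command above determine the effective batch size and an upper bound on the number of training tokens. The arithmetic below simply multiplies those flag values; the token counts are upper bounds because a sequence shorter than --model_max_length contributes fewer tokens.

```python
# Values copied from the fine-tuning command above.
num_gpus = 8              # --nproc_per_node
per_device_batch = 1      # --per_device_train_batch_size
grad_accum = 2            # --gradient_accumulation_steps
max_len = 16384           # --model_max_length
max_steps = 1600          # --max_steps

sequences_per_step = num_gpus * per_device_batch * grad_accum
max_tokens_per_step = sequences_per_step * max_len
max_tokens_total = max_tokens_per_step * max_steps
print(sequences_per_step, max_tokens_per_step, max_tokens_total)  # 16 262144 419430400
```

So each optimizer step sees 16 sequences (at most about 262k tokens), and the full 1600-step run sees at most about 419M tokens.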

Experiments

This section contains the data and code for validating ChunkLlama on different types of long-context tasks.

Perplexity validation on PG19

cd ppl
python test_ppl.py --seq_len 16384 --scale 13b --data_path pg19_llama2.validation.bin  # --scale: 7b/13b/70b

Here, --seq_len 16384 sets the length of the input prompts. We use the tokenized validation split of PG19 processed by LongLoRA. The raw data and tokenized data are in the ppl folder.
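Perplexity is the exponential of the mean per-token negative log-likelihood. The sketch below assumes the tokenized validation stream is cut into non-overlapping windows of --seq_len tokens; whether test_ppl.py uses non-overlapping or sliding windows is an assumption here, and both helper names are hypothetical.

```python
import numpy as np


def windows(tokens, seq_len):
    """Split a token stream into non-overlapping windows of exactly seq_len tokens
    (a trailing partial window is dropped)."""
    n_windows = len(tokens) // seq_len
    return [tokens[i * seq_len:(i + 1) * seq_len] for i in range(n_windows)]


def perplexity(token_nll):
    """Perplexity = exp(mean negative log-likelihood of the predicted tokens)."""
    return float(np.exp(np.mean(token_nll)))
```

In practice the per-token NLL values come from the model's cross-entropy on each window (predicting token t+1 from tokens up to t). A model that assigns probability 1/4 to every correct token has a perplexity of exactly 4.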

Passkey Retrieval

We provide a script to test passkey retrieval accuracy. For example:

cd passkey
python test_passkey.py --seq_len 16384 --scale 13b  # --scale: 7b/13b/70b
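For a quick manual check outside the script, the hypothetical helpers below reuse the passkey prompt from the vLLM example earlier in this README and score an answer by comparing the first number in the generated text with the passkey. test_passkey.py may place the passkey and score answers differently.

```python
import re


def build_passkey_prompt(passkey, n_filler):
    """Passkey prompt in the style of the vLLM example (filler repeated n_filler times)."""
    head = ("There is an important info hidden inside a lot of irrelevant text. "
            "Find it and memorize them. I will quiz you about the important information there.\n\n"
            f"The pass key is {passkey}. Remember it. {passkey} is the pass key.\n ")
    filler = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. "
    return head + filler * n_filler + "\nWhat is the pass key?\nThe passkey is "


def passkey_correct(generated, passkey):
    """True if the first number in the model's answer equals the passkey."""
    match = re.search(r"\d+", generated)
    return match is not None and match.group() == passkey
```

Increasing `n_filler` lengthens the prompt; accuracy is the fraction of prompts for which `passkey_correct` returns True.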

Needle In A Haystack

We also provide scripts to run the Needle-in-a-Haystack test and plot the results. For example:

cd need_in_a_haystack
# the following command will generate a jsonl file
python retrieve_needle.py --max_length 192k --model mistral --pretraining_length 32768
# for Llama: python retrieve_needle.py --max_length 192k --model Llama2 --pretraining_length 4096
# get the figure
python draw.py 
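The test places a "needle" at different depths of a long "haystack" and checks whether the model retrieves it, which is what "retrieval accuracy across all document depths" refers to. The hypothetical helpers below show the placement step on lists of token ids; retrieve_needle.py may compute positions differently.

```python
def insert_needle(haystack_tokens, needle_tokens, depth_percent):
    """Hypothetical helper: place the needle at depth_percent of the haystack."""
    if not 0 <= depth_percent <= 100:
        raise ValueError("depth_percent must be in [0, 100]")
    pos = round(len(haystack_tokens) * depth_percent / 100)
    return haystack_tokens[:pos] + needle_tokens + haystack_tokens[pos:]


def depth_grid(num_depths):
    """Evenly spaced depths from 0% (document start) to 100% (document end)."""
    if num_depths < 2:
        raise ValueError("num_depths must be at least 2")
    return [round(100 * i / (num_depths - 1), 1) for i in range(num_depths)]
```

Sweeping `depth_grid(...)` for each context length gives the grid of (length, depth) cells shown in the heatmap figure.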

Few-shot Learning

The experimental settings for few-shot learning are the same as those in Llama2 Long. We use 4 popular long-context benchmarks: NarrativeQA, QMSum, Qasper, and QuALITY. We also release the data, together with the in-context examples, in few-shot-data. We report results on the validation sets. The in-context examples are randomly selected from the training set.

cd few-shot
python test_few_shot.py --data_path data/few_shot_quality.jsonl --max_length 16k --scale 13b 

Here, --data_path is the path to the dataset, assuming the data is saved in few-shot/data/. The generation results will be saved to Predictions/Chunkllama-13b16k/few_shot_quality.json.

We use the validation scripts provided by SCROLLS to obtain the results:

python auto_eval.py   --dataset_name quality  --metrics_output_dir ./  --predictions Predictions/Chunkllama-13b16k/few_shot_quality.json  --test_data_file data/few_shot_quality.jsonl
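The released few-shot-data already includes the in-context examples, so the following is only an illustration of the setup described above: a hypothetical builder that prepends k randomly chosen training pairs to a test input. The prompt layout (input, newline, answer, blank line between shots) is an assumption, not the format used by test_few_shot.py.

```python
import random


def build_few_shot_prompt(train_examples, test_input, k, seed=0):
    """Prepend k randomly chosen (input, output) training pairs to the test input."""
    rng = random.Random(seed)                 # fixed seed so the shots are reproducible
    shots = rng.sample(train_examples, k)     # sampling without replacement
    parts = [f"{x}\n{y}" for x, y in shots]
    parts.append(test_input)
    return "\n\n".join(parts)
```

Fixing the seed keeps the selected shots identical across runs, so scores for different models or context lengths are compared on the same prompts.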

Zero-shot Learning

We also test our method on the chat version of Llama2 on zero-shot tasks. Because fair evaluation of open-ended tasks is challenging, we select 4 closed-ended tasks from L-Eval, with input lengths ranging from 3k to 27k tokens.

cd zero-shot
python test_zero_shot.py --task_path Closed-ended-tasks/coursera.jsonl --max_length 16k --scale 13b

The experimental settings and evaluation scripts are the same as those in the official repository of L-Eval.

python Evaluation/auto_eval.py --pred_file Predictions/Chunkllama-13b16k/coursera.jsonl 
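Closed-ended tasks can be scored by exact match on the chosen option. The hypothetical sketch below extracts the first standalone option letter from each prediction and reports accuracy; it handles single-answer questions (as in TOEFL or QuALITY) only, whereas Coursera questions can have several correct options, so rely on L-Eval's own auto_eval.py for the reported numbers.

```python
import re


def extract_choice(prediction):
    """Return the first standalone option letter A-D in a model answer, or None."""
    match = re.search(r"\b([A-D])\b", prediction)
    return match.group(1) if match else None


def accuracy(predictions, gold_letters):
    """Exact-match accuracy (in %) for single-answer multiple-choice questions."""
    if not gold_letters:
        raise ValueError("need at least one gold answer")
    correct = sum(extract_choice(p) == g for p, g in zip(predictions, gold_letters))
    return 100.0 * correct / len(gold_letters)
```

A regex this simple can be fooled (for example, by an answer that starts with the article "A"), which is one reason the official evaluation script is preferred.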

ChunkLlama3

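PPL on the PG19 validation set:

| Model | 4k | 8k | 16k | 32k | 64k | 96k | 128k | 160k |
|---|---|---|---|---|---|---|---|---|
| Llama3-8b | 9.04 | 8.71 | 78.88 | >100 | >100 | >100 | >100 | >100 |
| ChunkLlama3-8b | 9.04 | 8.71 | 8.61 | 8.62 | 8.95 | 9.43 | 10.04 | 10.66 |
| Llama3-70b | 5.36 | 5.16 | >100 | >100 | >100 | >100 | >100 | >100 |
| ChunkLlama3-70b | 5.36 | 5.16 | 5.14 | 5.14 | 5.21 | 5.32 | 5.40 | 5.45 |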

Few-shot results on 4 research benchmarks:

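| Model | NarrativeQA (0-shot) | Qasper (2-shot) | QuALITY (2-shot) | QMSum (1-shot) |
|---|---|---|---|---|
| ChunkLlama3-8b | 27.4 | 30.5 | 52.6 | 15.4 |
| Llama2 Long-7b | 21.9 | 27.8 | 43.2 | 14.9 |
| ChunkLlama3-70b | 33.7 | 33.1 | 75.4 | 16.0 |
| Llama2 Long-70b | 30.9 | 35.7 | 79.7 | 16.5 |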

Zero-shot results (with Chat models) on L-Eval:

| Model | TOEFL | QuALITY | Coursera | SFiction |
|---|---|---|---|---|
| ChunkLlama3-8b | 83.27 | 63.86 | 56.24 | 70.31 |
| ChunkLlama3-70b | 84.75 | 82.17 | 76.88 | 75.78 |
| GPT4-32k (2023) | 84.38 | 82.17 | 75.58 | 74.99 |

Acknowledgements

We sincerely appreciate the assistance provided by the following people (works) for ChunkLlama:

Citation

@misc{an2024trainingfree,
      title={Training-Free Long-Context Scaling of Large Language Models}, 
      author={Chenxin An and Fei Huang and Jun Zhang and Shansan Gong and Xipeng Qiu and Chang Zhou and Lingpeng Kong},
      year={2024},
      eprint={2402.17463},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

License