Home

Awesome

CritiPrefill

Quick Start

Install

pip install -e . && pip install flash_attn==2.5.8 --no-build-isolation

Usage

from criti_prefill.modeling_patch import replace_llama_eattention, criti_config

model = LlamaForCausalLM.from_pretrained(args.model_name_or_path, 
                                             device_map=device,
                                             torch_dtype=dtype,
                                             attn_implementation="flash_attention_2"
                                             )

criti_config(model,
             segment_size=args.segment_size,
             threshold_len=args.threshold_len,
             block_size=args.block_size,
             budgets=args.budgets,
             layer_fusion=args.layer_fusion,
             layer_skip=args.layer_skip)

Experiments

Time to first token(TTFT) is one of the most intuitive metrics for user experience, yet it tends to be significantly slower compared to decoding time.

CritiPrefill can significantly reduce the TTFT while maintaining generation quality.