

<a name="readme-top"></a>

Liger Kernel: Efficient Triton Kernels for LLM Training

<table style="width: 100%; text-align: center; border-collapse: collapse;"> <tr> <th style="padding: 10px;" colspan="2">Stable</th> <th style="padding: 10px;" colspan="2">Nightly</th> <th style="padding: 10px;">Discord</th> <th style="padding: 10px;">Build</th> </tr> <tr> <td style="padding: 10px;"> <a href="https://pepy.tech/project/liger-kernel"> <img src="https://static.pepy.tech/badge/liger-kernel" alt="Downloads (Stable)"> </a> </td> <td style="padding: 10px;"> <a href="https://pypi.org/project/liger-kernel"> <img alt="PyPI - Version" src="https://img.shields.io/pypi/v/liger-kernel?color=green"> </a> </td> <td style="padding: 10px;"> <a href="https://pepy.tech/project/liger-kernel-nightly"> <img src="https://static.pepy.tech/badge/liger-kernel-nightly" alt="Downloads (Nightly)"> </a> </td> <td style="padding: 10px;"> <a href="https://pypi.org/project/liger-kernel-nightly"> <img alt="PyPI - Version" src="https://img.shields.io/pypi/v/liger-kernel-nightly?color=green"> </a> </td> <td style="padding: 10px;"> <a href="https://discord.gg/gpumode"> <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord"> </a> </td> <td style="padding: 10px;"> <div style="display: block;"> <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml"> <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build"> </a> </div> <div style="display: block;"> <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml"> <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build"> </a> </div> </td> </tr> </table> <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">

Installation | Getting Started | Examples | High-level APIs | Low-level APIs | Cite our work


Liger Kernel is a collection of Triton kernels designed specifically for LLM training. It can increase multi-GPU training throughput by 20% and reduce memory usage by 60%. We have implemented Hugging Face-compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. The kernels work out of the box with Flash Attention, PyTorch FSDP, and Microsoft DeepSpeed. We welcome contributions from the community to gather the best kernels for LLM training.

We've also added optimized Post-Training kernels that deliver up to 80% memory savings for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out how we optimize the memory.

Supercharge Your Model with Liger Kernel


With one line of code, Liger Kernel can increase throughput by more than 20% and reduce memory usage by 60%, thereby enabling longer context lengths, larger batch sizes, and massive vocabularies.
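To see why large vocabularies are memory-hungry, here is a back-of-the-envelope calculation in plain Python. It assumes a batch of 8 sequences of 4,096 tokens, LLaMA 3's 128,256-token vocabulary, bf16 logits, and a hypothetical chunk size of 1,024 tokens. It counts only the logits tensor (not activations, gradients, or optimizer state) and is not a measurement of Liger Kernel.

```python
# Rough estimate of the memory needed to materialize a full logits tensor
# of shape (num_tokens, vocab_size). All shapes below are illustrative assumptions.

GIB = 1024 ** 3


def logits_bytes(num_tokens: int, vocab_size: int, bytes_per_elem: int) -> int:
    """Bytes needed to store a (num_tokens, vocab_size) logits tensor."""
    return num_tokens * vocab_size * bytes_per_elem


batch_size, seq_len = 8, 4096  # assumed training shape
vocab_size = 128_256           # LLaMA 3 vocabulary size
bf16_bytes = 2                 # bytes per bf16 element

full = logits_bytes(batch_size * seq_len, vocab_size, bf16_bytes)

# Chunked computation keeps only one chunk of logits alive at a time
# (the chunk size here is an assumption, not a library default).
chunk_tokens = 1024
chunked = logits_bytes(chunk_tokens, vocab_size, bf16_bytes)

print(f"full logits: {full / GIB:.2f} GiB")    # full logits: 7.83 GiB
print(f"one chunk:   {chunked / GIB:.2f} GiB")  # one chunk:   0.24 GiB
```

Shrinking the number of tokens whose logits are alive at once is what lets longer contexts, larger batches, and bigger vocabularies fit in the same memory budget.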

[Figures: Speed Up | Memory Reduction]


Optimize Post Training with Liger Kernel

<p align="center"> <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/post-training.png" width="50%" alt="Post Training"> </p>

We provide optimized post-training kernels for losses such as DPO, ORPO, SimPO, and more, which can reduce memory usage by up to 80%. You can easily use them as Python modules.

from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

orpo_loss = LigerFusedLinearORPOLoss()
# lm_head: the model's output projection; x: final hidden states; target: label token ids
y = orpo_loss(lm_head.weight, x, target)
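For intuition about what the ORPO loss computes, below is a dependency-free sketch of the ORPO objective as it is commonly formulated, for a single chosen/rejected pair. It is not Liger's implementation: the fused module takes lm_head.weight and hidden states and derives the log-probabilities itself, whereas this sketch assumes you already have per-token log-probabilities. The helper names and the weight lam are illustrative assumptions.

```python
import math
from typing import Sequence


def mean_logp(token_logps: Sequence[float]) -> float:
    """Length-normalized log-probability of a response (average over tokens)."""
    return sum(token_logps) / len(token_logps)


def log_odds(avg_logp: float) -> float:
    """log(p / (1 - p)) from log p; requires avg_logp < 0 (i.e. p < 1)."""
    return avg_logp - math.log1p(-math.exp(avg_logp))


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def orpo_loss(chosen_logps: Sequence[float], rejected_logps: Sequence[float],
              lam: float = 0.1) -> float:
    """ORPO loss for one pair: NLL on chosen minus lam * log-sigmoid(log odds ratio)."""
    nll = -mean_logp(chosen_logps)  # supervised fine-tuning term on the chosen response
    ratio = log_odds(mean_logp(chosen_logps)) - log_odds(mean_logp(rejected_logps))
    return nll - lam * log_sigmoid(ratio)


chosen = [-0.2, -0.4, -0.3]    # per-token log-probs of the preferred response
rejected = [-1.5, -2.0, -1.0]  # per-token log-probs of the dispreferred response
print(orpo_loss(chosen, rejected))
```

Because the odds-ratio term needs no reference model, only the policy's own log-probabilities enter the loss.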

Examples

| Use Case | Description |
| --- | --- |
| Hugging Face Trainer | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on the Alpaca dataset using 4 A100s with FSDP |
| Lightning Trainer | Increase throughput by 15% and reduce memory usage by 40% with LLaMA 3-8B on the MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
| Medusa Multi-head LLM (Retraining Phase) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
| Vision-Language Model SFT | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
| Liger ORPO Trainer | Align Llama 3.2 using the Liger ORPO Trainer with FSDP, with 50% memory reduction |

Key Features

Installation

Dependencies

CUDA

ROCm

Optional Dependencies

Note: Our kernels inherit the full spectrum of hardware compatibility offered by Triton.

To install the stable version:

$ pip install liger-kernel

To install the nightly version:

$ pip install liger-kernel-nightly

To install from source:

git clone https://github.com/linkedin/Liger-Kernel.git
cd Liger-Kernel

# Install default dependencies
# setup.py will detect whether you are using AMD or NVIDIA
pip install -e .

# Install development dependencies
pip install -e ".[dev]"

Getting Started

There are three ways to apply Liger kernels, depending on the level of customization required.

1. Use AutoLigerKernelForCausalLM

Using AutoLigerKernelForCausalLM is the simplest approach, as you don't have to import a model-specific patching API. If the model type is supported, the modeling code is automatically patched using the default settings.

from liger_kernel.transformers import AutoLigerKernelForCausalLM

# This AutoModel wrapper class automatically monkey-patches the
# model with the optimized Liger kernels if the model is supported.
model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
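The automatic patching can be pictured as a lookup from the Hugging Face config's model_type to a model-specific patch function, applied before the model is built. The sketch below is hypothetical: patch_llama, patch_mistral, PATCH_REGISTRY, and maybe_patch are illustrative stand-ins, not Liger's internal names or source.

```python
from typing import Callable, Dict, Optional


# Hypothetical stand-ins for model-specific patch functions.
def patch_llama(**kwargs) -> str:
    return "llama patched"


def patch_mistral(**kwargs) -> str:
    return "mistral patched"


# Hypothetical registry keyed by the Hugging Face config's `model_type` string.
PATCH_REGISTRY: Dict[str, Callable[..., str]] = {
    "llama": patch_llama,
    "mistral": patch_mistral,
}


def maybe_patch(model_type: str, **kwargs) -> Optional[str]:
    """Apply the patch for a supported model type; leave other models untouched."""
    patch_fn = PATCH_REGISTRY.get(model_type)
    if patch_fn is None:
        return None  # unsupported: the stock modeling code is used as-is
    return patch_fn(**kwargs)


print(maybe_patch("llama"))  # llama patched
print(maybe_patch("gpt2"))   # None
```

Falling back to a no-op for unknown model types is what makes a wrapper like this safe to use on any model: unsupported models simply load unpatched.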

2. Apply Model-Specific Patching APIs

Using the model-specific patching APIs, you can replace components of Hugging Face models with optimized Liger kernels.

import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

# 1a. Adding this line automatically monkey-patches the model with the optimized Liger kernels
apply_liger_kernel_to_llama()

# 1b. You could alternatively specify exactly which kernels are applied
apply_liger_kernel_to_llama(
  rope=True,
  swiglu=True,
  cross_entropy=True,
  fused_linear_cross_entropy=False,
  rms_norm=False
)

# 2. Instantiate patched model
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
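Monkey-patching works by rebinding names in the modeling module so that code which looks them up later picks up the replacement. The following self-contained sketch uses hypothetical names (modeling, SlowRMSNorm, FastRMSNorm, apply_patch) rather than the real transformers or Liger modules. It also suggests why the patch call comes before instantiating the model in the example above: in this sketch, an object built before patching keeps its original class.

```python
import types

# Hypothetical stand-in for a modeling module (e.g. one model family's modeling file).
modeling = types.SimpleNamespace()


class SlowRMSNorm:
    name = "reference"


class FastRMSNorm:
    name = "optimized"


modeling.RMSNorm = SlowRMSNorm


def build_layer():
    # The model code resolves the class through the module when it is built,
    # so whichever class is bound at that moment is the one used.
    return modeling.RMSNorm()


def apply_patch(rms_norm: bool = True) -> None:
    """Rebind the module attribute; the flag mimics a per-kernel on/off switch."""
    if rms_norm:
        modeling.RMSNorm = FastRMSNorm


before = build_layer()      # built before patching
apply_patch(rms_norm=True)
after = build_layer()       # built after patching
print(before.name, after.name)  # reference optimized
```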

3. Compose Your Own Model

You can use individual kernels to compose your own models.

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
import torch.nn as nn
import torch

model = nn.Linear(128, 256).cuda()

# fuses linear + cross entropy layers together and performs chunk-by-chunk computation to reduce memory
loss_fn = LigerFusedLinearCrossEntropyLoss()

input = torch.randn(4, 128, requires_grad=True, device="cuda")
target = torch.randint(256, (4, ), device="cuda")

loss = loss_fn(model.weight, input, target)
loss.backward()
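The idea behind the fused linear cross-entropy comment above can be shown in plain Python: compute the logits for a few rows at a time, reduce them to a loss, and discard them before moving on, so the full (tokens × vocab) matrix never exists at once. This is a forward-only sketch with hypothetical names (fused_linear_ce, chunk_size); a real kernel also has to handle the backward pass and runs on the GPU.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def logits_row(x: Sequence[float], weight: Matrix) -> List[float]:
    """One row of x @ weight.T, where weight has shape (vocab, hidden)."""
    return [sum(xi * wi for xi, wi in zip(x, w_row)) for w_row in weight]


def token_loss(logits: Sequence[float], target: int) -> float:
    """Cross entropy for one token: logsumexp(logits) - logits[target]."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


def fused_linear_ce(weight: Matrix, inputs: Matrix, targets: Sequence[int],
                    chunk_size: int = 2) -> float:
    """Mean cross entropy of (inputs @ weight.T) against targets, chunk by chunk.

    At most `chunk_size` rows of logits exist at any time, instead of the
    full (num_tokens, vocab) matrix.
    """
    total = 0.0
    for start in range(0, len(inputs), chunk_size):
        chunk_logits = [logits_row(x, weight) for x in inputs[start:start + chunk_size]]
        chunk_targets = targets[start:start + chunk_size]
        total += sum(token_loss(z, t) for z, t in zip(chunk_logits, chunk_targets))
        # chunk_logits is rebound on the next iteration, so its memory can be reclaimed
    return total / len(inputs)


weight = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.2]]  # vocab=3, hidden=2
inputs = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0]]   # 3 tokens
targets = [1, 0, 2]
print(fused_linear_ce(weight, inputs, targets, chunk_size=2))
```

The chunk size trades peak memory against the number of passes; the resulting loss is the same for any chunk size up to floating-point rounding.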

High-level APIs

AutoModel

| AutoModel Variant | API |
| --- | --- |
| AutoModelForCausalLM | liger_kernel.transformers.AutoLigerKernelForCausalLM |

Patching

| Model | API | Supported Operations |
| --- | --- | --- |
| LLaMA 2 & 3 | liger_kernel.transformers.apply_liger_kernel_to_llama | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| LLaMA 3.2-Vision | liger_kernel.transformers.apply_liger_kernel_to_mllama | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mistral | liger_kernel.transformers.apply_liger_kernel_to_mistral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mixtral | liger_kernel.transformers.apply_liger_kernel_to_mixtral | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma1 | liger_kernel.transformers.apply_liger_kernel_to_gemma | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma2 | liger_kernel.transformers.apply_liger_kernel_to_gemma2 | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2, Qwen2.5, & QwQ | liger_kernel.transformers.apply_liger_kernel_to_qwen2 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2-VL | liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Phi3 & Phi3.5 | liger_kernel.transformers.apply_liger_kernel_to_phi3 | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

Low-level APIs

Model Kernels

| Kernel | API |
| --- | --- |
| RMSNorm | liger_kernel.transformers.LigerRMSNorm |
| LayerNorm | liger_kernel.transformers.LigerLayerNorm |
| RoPE | liger_kernel.transformers.liger_rotary_pos_emb |
| SwiGLU | liger_kernel.transformers.LigerSwiGLUMLP |
| GeGLU | liger_kernel.transformers.LigerGEGLUMLP |
| CrossEntropy | liger_kernel.transformers.LigerCrossEntropyLoss |
| Fused Linear CrossEntropy | liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss |
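As a reference for what some of these kernels compute, here are plain-Python versions of RMSNorm and the gated activations used in SwiGLU and GeGLU MLPs, operating on a single vector. The surrounding linear projections (gate, up, down) are omitted, this GELU uses the tanh approximation, and eps is an illustrative value; these are mathematical references, not the Triton kernels.

```python
import math
from typing import List, Sequence


def rms_norm(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> List[float]:
    """RMSNorm: y_i = x_i / sqrt(mean(x^2) + eps) * weight_i."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


def silu(z: float) -> float:
    """SiLU (swish): z * sigmoid(z), written to avoid overflow in exp."""
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)


def gelu_tanh(z: float) -> float:
    """GELU, tanh approximation."""
    return 0.5 * z * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (z + 0.044715 * z ** 3)))


def swiglu(gate: Sequence[float], up: Sequence[float]) -> List[float]:
    """Gated activation of a SwiGLU MLP: SiLU(gate) * up, elementwise."""
    return [silu(g) * u for g, u in zip(gate, up)]


def geglu(gate: Sequence[float], up: Sequence[float]) -> List[float]:
    """Gated activation of a GeGLU MLP: GELU(gate) * up, elementwise."""
    return [gelu_tanh(g) * u for g, u in zip(gate, up)]


print(rms_norm([3.0, 4.0], [1.0, 1.0]))
print(swiglu([0.0, 1.0], [5.0, 2.0]))
```

Each output element depends only on its own inputs (plus one row-wide statistic for RMSNorm), which is why these operations are natural candidates for a single fused GPU kernel.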

Alignment Kernels

| Kernel | API |
| --- | --- |
| Fused Linear CPO Loss | liger_kernel.chunked_loss.LigerFusedLinearCPOLoss |
| Fused Linear DPO Loss | liger_kernel.chunked_loss.LigerFusedLinearDPOLoss |
| Fused Linear ORPO Loss | liger_kernel.chunked_loss.LigerFusedLinearORPOLoss |
| Fused Linear SimPO Loss | liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss |
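The alignment losses operate on log-probabilities of chosen and rejected responses. Below is a minimal sketch of the standard DPO loss and a reference-free SimPO-style loss for one preference pair, assuming the log-probabilities are already computed (the fused Liger modules instead derive them from lm_head.weight and hidden states). The function names and the beta and gamma defaults are illustrative assumptions.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """DPO loss for one pair; inputs are summed response log-probs."""
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    return -log_sigmoid(beta * margin)


def simpo_loss(avg_chosen: float, avg_rejected: float,
               beta: float = 2.0, gamma: float = 0.5) -> float:
    """SimPO-style loss: no reference model, length-normalized log-probs, target margin gamma."""
    return -log_sigmoid(beta * (avg_chosen - avg_rejected) - gamma)


print(dpo_loss(-10.0, -12.0, -10.0, -12.0))  # zero margin: equals log(2)
print(simpo_loss(-0.5, -2.0))
```

In both cases the loss shrinks as the policy assigns relatively more probability to the chosen response than to the rejected one.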

Distillation Kernels

| Kernel | API |
| --- | --- |
| KLDivergence | liger_kernel.transformers.LigerKLDIVLoss |
| JSD | liger_kernel.transformers.LigerJSD |
| Fused Linear JSD | liger_kernel.transformers.LigerFusedLinearJSD |
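For the distillation losses, the sketch below computes KL divergence and the symmetric Jensen-Shannon divergence between two discrete distributions, such as teacher and student next-token probabilities. It fixes equal 0.5/0.5 weighting toward the midpoint distribution; a library may expose a weighting parameter for a generalized form, which this sketch does not model. Function names are illustrative.

```python
import math
from typing import Sequence


def kl_div(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(P || Q) for discrete distributions; terms with p_i == 0 contribute 0.

    Assumes q_i > 0 wherever p_i > 0.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p: Sequence[float], q: Sequence[float]) -> float:
    """Symmetric Jensen-Shannon divergence: mean KL of P and Q to their midpoint M."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_div(p, m) + 0.5 * kl_div(q, m)


teacher = [0.7, 0.2, 0.1]
student = [0.1, 0.3, 0.6]
print(kl_div(teacher, student), jsd(teacher, student))
```

Unlike KL, this JSD is symmetric and bounded above by log(2), and it stays finite even when one distribution puts zero mass where the other does not.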

Experimental Kernels

| Kernel | API |
| --- | --- |
| Embedding | liger_kernel.transformers.experimental.LigerEmbedding |
| Matmul int2xint8 | liger_kernel.transformers.experimental.matmul |

Contributing, Acknowledgements, and License

Sponsorship and Collaboration

Contact

Cite this work

Biblatex entry:

@article{hsu2024ligerkernelefficienttriton,
      title={Liger Kernel: Efficient Triton Kernels for LLM Training},
      author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
      year={2024},
      eprint={2410.10989},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.10989},
      journal={arXiv preprint arXiv:2410.10989},
}


<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;"> <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;"> ↑ Back to Top ↑ </a> </p>