X-LoRA

Mixture of LoRA Experts: Leverage the power of fine-tuned LoRA experts with a mixture-of-experts (MoE) technique.

X-LoRA works by learning scaling values for LoRA adapters. These learned scaling values are used to gate the LoRA experts in a dense fashion. Additionally, all LoRA adapters and the base model are frozen, allowing efficient fine-tuning due to the low trainable parameter count.
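The following is a minimal sketch of the dense gating idea, not the library's internal implementation: a small trainable classifier predicts per-token, per-expert scalings from the hidden state, and each frozen LoRA expert's output is weighted by its scaling before being added to the frozen base output. The names `xlora_style_forward`, `lora_deltas`, and `classifier` are illustrative.

import torch

def xlora_style_forward(hidden, base_out, lora_deltas, classifier):
    # Predict one scaling per expert for every token (dense gating)
    # hidden:      (batch, seq, hidden_size), base_out: (batch, seq, out_features)
    # lora_deltas: list of (batch, seq, out_features), one per frozen LoRA expert
    scalings = torch.softmax(classifier(hidden), dim=-1)
    out = base_out
    for i, delta in enumerate(lora_deltas):
        # Each frozen expert's contribution is weighted by its learned scaling
        out = out + scalings[..., i : i + 1] * delta
    return out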

X-LoRA is easily applied to any Hugging Face Transformers model. Please see our weights and our paper.

Token-by-token scalings

Advantages and features

Architecture

<p align="center"> <img src="./res/general_arch_v5.png" alt="General Architecture" width=75%/> </p>

See the examples folder for some examples of how to get started with X-LoRA.

Efficient Inference Support

Mistral.rs is an inference framework that supports X-LoRA. To use it, follow the installation instructions and run the following command to start an X-LoRA inference server.

./mistralrs-server --port 1234 x-lora-mistral -o ordering.json

Base and X-LoRA Hugging Face model IDs may be specified through command-line switches to use your own models. Please see the GitHub page for further details.

Installation

Pending a pip release, run the following command to install X-LoRA.

pip install git+https://github.com/EricLBuehler/xlora.git

Examples

Excerpt from this example.

Converting a model

import torch
import xlora
from transformers import AutoConfig, AutoModelForCausalLM # type: ignore

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

config = AutoConfig.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="auto",
)

### Convert the model to X-LoRA
model_created = xlora.add_xlora_to_model(
    model=model,
    xlora_config=xlora.xLoRAConfig(
        config.hidden_size,
        base_model_id="mistralai/Mistral-7B-Instruct-v0.1",
        xlora_depth=8,
        device=torch.device("cuda"),
        adapters={
            "adapter_1": "./path/to/the/checkpoint/",
            "adapter_2": "./path/to/the/checkpoint/",
            "adapter_n": "./path/to/the/checkpoint/",
        },
    ),
    verbose=True,
)

Loading a trained X-LoRA model from scratch

import torch
import xlora
from transformers import AutoConfig, AutoModelForCausalLM # type: ignore

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

config = AutoConfig.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="auto",
)

model_created = xlora.from_pretrained(
    "./path/to/saved/model",
    model,
    "cuda",
)

Loading a trained X-LoRA model with a convenience function

import torch
from xlora.xlora_utils import load_model  # type: ignore

XLoRA_model_name = "myuser/repo"

model_loaded, tokenizer = load_model(
    model_name=XLoRA_model_name,
    device="cuda:0",
    dtype=torch.bfloat16,
)

Scalings logging

# Enable scalings logging and begin a log
model_created.enable_scalings_logging()

# Run forward passes to accumulate a log

# Write the log to a file (multiple files may be produced if necessary).
model_created.flush_log_scalings("./path/to/output/file")

# Get a shallow copy of the scalings
log_copy = model_created.get_scalings_log()

# Disable scalings logging
model_created.disable_scalings_logging()

# Clear the scalings log
model_created.clear_scalings_log()

# Get the latest scalings prediction
scalings_pred = model_created.get_latest_scalings()

# Load the scalings log from a file (multiple files are handled automatically).
loaded_log = xlora.xlora_utils.load_scalings_log("./path/to/output/file", verbose=True)
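A minimal sketch of inspecting the accumulated log is shown below. It assumes each log entry is a torch.Tensor of scalings whose last dimension indexes the adapters; the exact shape depends on the model and configuration.

# Illustrative only: inspect the accumulated scalings
for step, scalings in enumerate(model_created.get_scalings_log()):
    print(f"step {step}: shape={tuple(scalings.shape)}")

latest = model_created.get_latest_scalings()
if latest is not None:
    # Mean weight assigned to each adapter, assuming adapters are the last dimension
    print(latest.mean(dim=tuple(range(latest.dim() - 1))))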

Trainable parameters

model: xLoRAModel = ... # Load the model

num_trainable, num_all_params = model.get_nb_trainable_parameters()

model.print_trainable_parameters()

Setting trainability of adapters dynamically

model: xLoRAModel = ... # Load the model

# Use trainable adapters: mark all adapters as trainable
model.set_use_trainable_adapters(True)

# Get the current status of the trainable adapters, in this case returning True
model.get_use_trainable_adapters()

Setting and resetting the scaling pass value

model: xLoRAModel = ... # Load the model

# Set the scaling pass value to 0, meaning that no adapters will contribute to the scaling pass output
model.set_scaling_pass_value(0)

# Allow the model to use the default scaling pass value
model.set_scaling_pass_value(None)

Setting and getting the global LoRA weight

model: xLoRAModel = ... # Load the model

# Multiply the output of each LoRA adapter by 2, in addition to the learned scalings.
model.set_global_scaling_weight(2)

# Returns 2
res = model.get_global_scaling_weight()

Setting and getting the top-k LoRA value

# Use the top 2 LoRA experts
model_created.set_topk_lora(2)

# Returns 2
res = model_created.get_topk_lora()

API

The X-LoRA API is composed of 3 parts: the "Global API", the "Model API", and the "Utility API". Generally, the global API is used to create X-LoRA models, the model API is used to interface with the models, and the utility API provides useful utility functions.
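For orientation, a brief sketch grouping calls from the examples above by the part of the API they belong to (the "..." placeholders stand for the arguments shown earlier):

# Global API: create or load an X-LoRA model
model_created = xlora.add_xlora_to_model(...)  # or xlora.from_pretrained(...)

# Model API: interact with the created model
model_created.enable_scalings_logging()

# Utility API: helper functions
loaded_log = xlora.xlora_utils.load_scalings_log("./path/to/output/file", verbose=True)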

X-LoRA Config

The X-LoRA Config saves the full configuration of an X-LoRA model.

Args:
    hidden_size (`int`):
        Hidden size of the base model.
    device (`torch.device`):
        Device for the X-LoRA classifier.
    enable_softmax (`bool`, *optional*, defaults to `True`):
        Enable softmax application for the X-LoRA classifier.
    enable_softmax_topk (`bool`, *optional*, defaults to `False`):
        Enable softmax application for the top-k LoRA adapters. Mutually exclusive with `enable_softmax` and must only be set if `top_k_lora` is set.
    softmax_temperature (`float`, *optional*, defaults to 1.0):
        Softmax temperature; lower values yield sharper predictions.
    layerwise_scalings (`bool`, *optional*, defaults to `False`):
        Generate scalings for each layer; if `False`, the same scalings are broadcast to every layer.
    top_k_lora (`int`, *optional*, defaults to `None`):
        Sparsely select the top_k LoRA experts instead of the default dense method.
    xlora_depth (`int`, *optional*, defaults to 1):
        Depth of the X-LoRA classifier.
    xlora_size (`int`, *optional*, defaults to 2048):
        Hidden size of the X-LoRA classifier, irrelevant if `xlora_depth=1`.
    enable_relu_and_dropout (`bool`, *optional*, defaults to `True`):
        Enable ReLU activation and dropout in the X-LoRA classifier.
    use_bias (`bool`, *optional*, defaults to `True`):
        Enable bias in the X-LoRA classifier.
    xlora_dropout_p (`float`, *optional*, defaults to 0.2):
        Dropout probability of the X-LoRA classifier, irrelevant if `xlora_depth=1` or `enable_relu_and_dropout=False`.
    stop_token_id (`int`, *optional*):
        The ID of the stop token for the input. If this is `None`, the sequence length is calculated using the attention mask.
    use_trainable_adapters (`bool`, *optional*, defaults to `False`):
        Make the adapters trainable.
    scaling_pass_value (`float`, *optional*, defaults to 0):
        Value the scalings are set to during the scaling pass.
    global_scaling_weight (`float`, *optional*, defaults to 1):
        Weight by which the output of each LoRA adapter is multiplied, in addition to the learned scalings.
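As an illustration of how these options combine, the following sketch constructs a config with a deeper classifier, per-layer scalings, and sparse top-k routing. The hidden size and adapter paths are placeholders; substitute the values for your base model and checkpoints.

import torch
import xlora

config = xlora.xLoRAConfig(
    4096,  # hidden_size of the base model (placeholder)
    base_model_id="mistralai/Mistral-7B-Instruct-v0.1",
    device=torch.device("cuda"),
    adapters={
        "adapter_1": "./path/to/the/checkpoint/",
        "adapter_2": "./path/to/the/checkpoint/",
    },
    xlora_depth=4,            # deeper X-LoRA classifier
    layerwise_scalings=True,  # generate scalings for each layer
    top_k_lora=2,             # sparsely route to the top 2 experts
    softmax_temperature=0.5,  # sharper expert selection
)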

Global API

Utility API

Model API

Scalings

Trainable parameters

Setting the trainable adapters

Top-k

Original paper and citation

Cite this work as:

@article{Buehler_XLoRA_2024,
    title   = {X-LoRA: Mixture of Low-Rank Adapter Experts, a Flexible Framework for Large Language Models with Applications in Protein Mechanics and Design},
    author  = {E. L. Buehler and M. J. Buehler},
    journal = {},
    year    = {2024},
    volume  = {},
    pages   = {},
    url     = {https://arxiv.org/abs/2402.07148}
}

Contributing

Please run make style before submitting a PR.