



BitNet

PyTorch implementation of the linear methods and model from the paper "BitNet: Scaling 1-bit Transformers for Large Language Models".

Paper link:

BitLinear = tensor -> layernorm -> Binarize -> abs max quantization -> dequant
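The pipeline above can be sketched directly from the paper's formulas. This is a minimal illustration, not the `bitnet` package's implementation, and `BitLinearSketch` is a hypothetical name. Assumptions: activations are absmax-quantized per tensor to 8 bits, then rounded and clipped to the signed range [-Q_b, Q_b - 1] (a common implementation choice; the paper clips to ±(Q_b − ε)); weights are binarized around their mean and rescaled by their mean absolute value β; and a straight-through estimator keeps the layer trainable.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BitLinearSketch(nn.Linear):
    """Hypothetical BitLinear: LayerNorm -> binarize W / absmax-quantize x -> dequantize."""

    def __init__(self, in_features, out_features, bias=False, bits=8):
        super().__init__(in_features, out_features, bias=bias)
        self.norm = nn.LayerNorm(in_features)
        self.qb = 2 ** (bits - 1)  # Q_b = 2^(b-1)
        self.eps = 1e-5

    def quantize_weight(self):
        w = self.weight
        alpha = w.mean()        # centre weights around zero
        beta = w.abs().mean()   # beta = ||W||_1 / (n * m)
        w_bin = torch.where(w > alpha, torch.ones_like(w), -torch.ones_like(w))
        w_dq = w_bin * beta     # dequantized binary weights (+/- beta)
        # straight-through estimator: forward uses w_dq, backward treats it as identity
        return w + (w_dq - w).detach()

    def quantize_activation(self, x):
        gamma = x.abs().max().clamp(min=self.eps)  # gamma = ||x||_inf (per tensor)
        x_q = (x * self.qb / gamma).round().clamp(-self.qb, self.qb - 1)
        x_dq = x_q * gamma / self.qb               # back to the original scale
        return x + (x_dq - x).detach()

    def forward(self, x):
        x = self.norm(x)
        # F.linear on the dequantized operands equals (W_bin @ x_q) * beta * gamma / Q_b
        return F.linear(self.quantize_activation(x), self.quantize_weight(), self.bias)
```

Folding the rescaling factors into the dequantized operands keeps the sketch short; an optimized kernel would instead multiply the raw binary and integer values and apply β·γ/Q_b once to the output.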

"The implementation of the BitNet architecture is quite simple, requiring only the replacement of linear projections (i.e., nn.Linear in PyTorch) in the Transformer." BitNet really is easy to implement: just swap out the linear layers for BitLinear modules!




pip3 install bitnet


We have a wide selection of example scripts here and in the examples folder. Let me know in the Discord if you want assistance with a use case!


import torch

from bitnet import BitLinear

# Input
x = torch.randn(10, 1000, 512)

# BitLinear layer
layer = BitLinear(512, 400)

# Output
y = layer(x)



import torch
from bitnet import BitLinearNew

# Create a random tensor of shape (16, 1000, 512)
x = torch.randn(16, 1000, 512)

# Create an instance of the BitLinearNew class with input size 512 (matching the last dimension of x) and output size 20
layer = BitLinearNew(512, 20)

# Perform a forward pass through the BitLinearNew layer with input x
output = layer(x)

# Print the output tensor
print(output)


# Import the necessary libraries
import torch
from bitnet import BitNetTransformer

# Create a random tensor of integers
x = torch.randint(0, 20000, (1, 1024))

# Initialize the BitNetTransformer model
bitnet = BitNetTransformer(
    num_tokens=20000,  # Number of unique tokens in the input
    dim=1024,  # Dimension of the input and output embeddings
    depth=6,  # Number of transformer layers
    heads=8,  # Number of attention heads
    ff_mult=4,  # Multiplier for the hidden dimension in the feed-forward network
)

# Pass the tensor through the transformer model
logits = bitnet(x)

# Print the shape of the output
print(logits.shape)


BitMGQA: this attention module uses BitLinear instead of the default linear projections. It also uses multi-grouped query attention (MGQA) instead of regular multi-head attention, for faster decoding and longer-context handling.

import torch
from bitnet import BitMGQA

# Create a random tensor of shape (1, 10, 512)
x = torch.randn(1, 10, 512)

# Create an instance of the BitMGQA model with embedding dimension 512, 8 query heads, and 4 key/value heads
gqa = BitMGQA(512, 8, 4)

# Pass the input tensor through the BitMGQA model; the attention weights are requested but discarded here
out, _ = gqa(x, x, x, need_weights=True)

# Print the shape of the output tensor
print(out.shape)
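For intuition about the grouped-query part, here is a minimal sketch of the attention math alone, using illustrative numbers (8 query heads sharing 4 key/value heads). The function name `grouped_query_attention` is hypothetical and the input/output projections are omitted; in BitMGQA those projections are BitLinear layers. Because only the key/value heads are reduced, the K/V cache shrinks by a factor of `query_heads / kv_heads`, which is where the decoding speed-up comes from.

```python
import torch


def grouped_query_attention(q, k, v):
    """q: (batch, query_heads, seq, head_dim); k, v: (batch, kv_heads, seq, head_dim)."""
    query_heads, kv_heads = q.shape[1], k.shape[1]
    assert query_heads % kv_heads == 0, "query heads must be a multiple of kv heads"
    group = query_heads // kv_heads
    # Each key/value head is shared by `group` consecutive query heads
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    # Standard scaled dot-product attention over the expanded heads
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    weights = scores.softmax(dim=-1)
    return weights @ v, weights


q = torch.randn(1, 8, 10, 64)  # 8 query heads
k = torch.randn(1, 4, 10, 64)  # 4 shared key heads
v = torch.randn(1, 4, 10, 64)  # 4 shared value heads
out, attn = grouped_query_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 10, 64])
```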


import torch
from bitnet import BitFeedForward

# Create a random input tensor of shape (10, 512)
x = torch.randn(10, 512)

# Create an instance of the BitFeedForward class with the following parameters:
# - 512: input dimension
# - 512: output dimension
# - 4: multiplier for the hidden dimension (hidden size = 512 * 4)
# - swish: True (use the Swish activation function)
# - post_act_ln: True (apply Layer Normalization after the activation)
# - dropout: 0.1 (apply dropout with a probability of 0.1)
ff = BitFeedForward(512, 512, 4, swish=True, post_act_ln=True, dropout=0.1)

# Apply the BitFeedForward network to the input tensor x
y = ff(x)

# Print the shape of the output tensor y
print(y.shape)  # torch.Size([10, 512])


from bitnet import BitNetInference

bitnet = BitNetInference()
bitnet.load_model("../model_checkpoint.pth")  # Path to a downloaded model checkpoint
output_str = bitnet.generate("The dog jumped over the ", 512)
print(output_str)

Hugging Face Usage

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from bitnet import replace_linears_in_hf

# Load a model from Hugging Face's Transformers
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Replace Linear layers with BitLinear
replace_linears_in_hf(model)

# Example text to classify
text = "Replace this with your text"
inputs = tokenizer(
    text, return_tensors="pt", padding=True, truncation=True, max_length=512
)

# Perform inference
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

# Process predictions
predicted_class_id = predictions.argmax().item()
print(f"Predicted class ID: {predicted_class_id}")

# Optionally, map the predicted class ID to a label, if you know the classification labels
# labels = ["Label 1", "Label 2", ...]  # Define your labels corresponding to the model's classes
# print(f"Predicted label: {labels[predicted_class_id]}")

Drop-in Replacement for PyTorch Models

import torch
from torch import nn
from bitnet import replace_linears_in_pytorch_model

# Define a simple model
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.Linear(20, 30),
)

print("Before replacement:")
print(model)

# Replace nn.Linear with BitLinear
replace_linears_in_pytorch_model(model)

print("After replacement:")
print(model)

# Now you can use the model for training or inference
# For example, pass a random input through the model
x = torch.randn(1, 10)
output = model(x)
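A recursive swap of this kind can be sketched in a few lines. This is not the package's code: `swap_linears` and `StandInBitLinear` are hypothetical names, and the stand-in simply subclasses `nn.Linear` so the example runs without `bitnet` (pass the real BitLinear class as `new_cls`). The sketch copies the original weights and bias into each replacement, which preserves pretrained values; whether `replace_linears_in_pytorch_model` does the same is not stated here.

```python
import torch
from torch import nn


class StandInBitLinear(nn.Linear):
    """Hypothetical placeholder for a BitLinear layer; substitute the real class."""


def swap_linears(module: nn.Module, new_cls=StandInBitLinear) -> nn.Module:
    """Recursively replace every nn.Linear submodule with new_cls, copying its parameters."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not isinstance(child, new_cls):
            new = new_cls(child.in_features, child.out_features, bias=child.bias is not None)
            new = new.to(device=child.weight.device, dtype=child.weight.dtype)
            with torch.no_grad():
                new.weight.copy_(child.weight)  # keep the original (e.g. pretrained) weights
                if child.bias is not None:
                    new.bias.copy_(child.bias)
            setattr(module, name, new)  # re-registers the submodule under the same name
        else:
            swap_linears(child, new_cls)  # descend into containers and nested blocks
    return module


model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Sequential(nn.Linear(20, 30)))
swap_linears(model)
print(model)
```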

Optimized CUDA Kernel

python setup.py build_ext --inplace

import torch
import gemm_lowbit_ext  # This imports the compiled module

# Example usage
a = torch.randn(10, 20, dtype=torch.half, device='cuda')  # Example tensor
b = torch.randn(20, 30, dtype=torch.half, device='cuda')  # Example tensor
c = torch.empty(10, 30, dtype=torch.half, device='cuda')  # Output tensor

w_scale = 1.0  # Example scale factor
x_scale = 1.0  # Example scale factor

# Call the custom CUDA GEMM operation
gemm_lowbit_ext.gemm_lowbit(a, b, c, w_scale, x_scale)

print(c)  # View the result


Implementation of BitLora!

import torch
from bitnet import BitLora

# Random input tensor of shape (batch, sequence length, features)
x = torch.randn(1, 12, 200)

# Create an instance of the BitLora model
model = BitLora(in_features=200, out_features=200, rank=4, lora_alpha=1)

# Perform the forward pass
out = model(x)

# Print the shape of the output tensor
print(out.shape)
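The `rank` and `lora_alpha` arguments look like the standard LoRA parameterization, in which a frozen base layer is augmented with a trainable low-rank update `B(A(x))` scaled by `lora_alpha / rank`. The sketch below illustrates that convention, assuming BitLora follows it; it is not the package's code. `LoRALinearSketch` and its `base_cls` argument are hypothetical, and passing a BitLinear class as `base_cls` would give a 1-bit base layer.

```python
import torch
from torch import nn


class LoRALinearSketch(nn.Module):
    """y = base(x) + (lora_alpha / rank) * B(A(x)), with the base layer frozen."""

    def __init__(self, in_features, out_features, rank=4, lora_alpha=1, base_cls=nn.Linear):
        super().__init__()
        self.base = base_cls(in_features, out_features)
        self.base.requires_grad_(False)  # only the low-rank factors are trained
        self.lora_a = nn.Linear(in_features, rank, bias=False)   # A: down-projection
        self.lora_b = nn.Linear(rank, out_features, bias=False)  # B: up-projection
        nn.init.zeros_(self.lora_b.weight)  # update starts at zero, so output == base output
        self.scaling = lora_alpha / rank

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


layer = LoRALinearSketch(200, 200, rank=4, lora_alpha=1)
print(layer(torch.randn(1, 12, 200)).shape)  # torch.Size([1, 12, 200])
```

Zero-initializing `B` means training starts exactly from the base layer's behaviour, and only `rank * (in_features + out_features)` parameters receive gradients.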


import torch
from bitnet import BitMamba

# Create a tensor of size (2, 10) with random integer token IDs in the range [0, 100)
x = torch.randint(0, 100, (2, 10))

# Create an instance of the BitMamba model with input size 512, hidden size 100, output size 10, and depth size 6
model = BitMamba(512, 100, 10, 6, return_tokens=True)

# Pass the input tensor through the model and get the output
output = model(x)

# Print the output tensor
print(output)

# Print the shape of the output tensor
print(output.shape)


import torch
from bitnet.bit_moe import BitMoE

# Create input tensor
x = torch.randn(2, 4, 8)

# Create BitMoE model with specified input and output dimensions
model = BitMoE(8, 4, 2)

# Forward pass through the model
output = model(x)

# Print the output
print(output)

1-Bit Vision Transformers

This idea came to me out of nowhere, but it seems pretty fun: you can leverage BitLinear for vision tasks to get ultra-compression. It would be nice to train this on ImageNet; if you can make a training script, we'll provide the compute. The next stage would then be to train a joint vision-language model along the lines of GPT-4o.

import torch
from bitnet import OneBitViT

# Create an instance of the OneBitViT model
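# Keyword names and the values marked "assumed" are example settings; adjust as needed
v = OneBitViT(
    image_size=256,  # matches the 256x256 input below
    patch_size=32,  # example value (assumed)
    num_classes=1000,  # matches the (1, 1000) output below
    dim=1024,  # example value (assumed)
    depth=6,  # example value (assumed)
    heads=16,  # example value (assumed)
    mlp_dim=2048,  # example value (assumed)
)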

# Generate a random image tensor
img = torch.randn(1, 3, 256, 256)

# Pass the image through the OneBitViT model to get predictions
preds = v(img)  # (1, 1000)

# Print the predictions
print(preds)
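A vision transformer first cuts the image into non-overlapping patches and flattens each one into a token vector, which a linear layer then projects to the model dimension; in a 1-bit ViT that projection, like the transformer's other linear layers, would be a BitLinear. The sketch below shows only the patching step, assuming a patch size of 32 (a hypothetical value, not confirmed for OneBitViT); `patchify` is a hypothetical helper name.

```python
import torch


def patchify(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    """(batch, channels, H, W) -> (batch, num_patches, patch_size * patch_size * channels)."""
    b, c, h, w = img.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image must divide evenly into patches"
    x = img.reshape(b, c, h // p, p, w // p, p)
    # Bring the patch-grid axes forward, then flatten each patch's pixels and channels
    x = x.permute(0, 2, 4, 3, 5, 1)
    return x.reshape(b, (h // p) * (w // p), p * p * c)


img = torch.randn(1, 3, 256, 256)
tokens = patchify(img, 32)
print(tokens.shape)  # torch.Size([1, 64, 3072])
```

A 256x256 image with 32x32 patches yields an 8x8 grid, so 64 tokens of 32 * 32 * 3 = 3072 values each, which would then go through a `BitLinear(3072, dim)`-style projection.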




Citation

@misc{2310.11453,
Author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Huaijie Wang and Lingxiao Ma and Fan Yang and Ruiping Wang and Yi Wu and Furu Wei},
Title = {BitNet: Scaling 1-bit Transformers for Large Language Models},
Year = {2023},
Eprint = {arXiv:2310.11453},
}
