
<div align="center">

Neural Network Compression Framework (NNCF)

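[Key Features](#key-features) • [Installation](#installation-guide) • [Documentation](#documentation) • [Usage](#usage) • [Tutorials and Samples](#demos-tutorials-and-samples) • [Third-party integration](#third-party-repository-integration) • Model Zoo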

License: Apache License Version 2.0

</div>

Neural Network Compression Framework (NNCF) provides a suite of post-training and training-time algorithms for optimizing inference of neural networks in OpenVINO™ with a minimal accuracy drop.

NNCF is designed to work with models from PyTorch, TorchFX, TensorFlow, ONNX and OpenVINO™.

NNCF provides samples that demonstrate the usage of compression algorithms for different use cases and models. See compression results achievable with the NNCF-powered samples on the NNCF Model Zoo page.

The framework is organized as a Python* package that can be built and used in a standalone mode. The framework architecture is unified to make it easy to add different compression algorithms for both PyTorch and TensorFlow deep learning frameworks.

<a id="key-features"></a>

Key Features

Post-Training Compression Algorithms

| Compression algorithm | OpenVINO | PyTorch | TorchFX | TensorFlow | ONNX |
|:---|:---|:---|:---|:---|:---|
| Post-Training Quantization | Supported | Supported | Experimental | Supported | Supported |
| Weights Compression | Supported | Supported | Experimental | Not supported | Not supported |
| Activation Sparsity | Not supported | Experimental | Not supported | Not supported | Not supported |
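
Weight compression quantizes only the model weights while activations stay in floating point, which mainly reduces model size and memory traffic. NNCF exposes it through `nncf.compress_weights`, but the snippet below does not call that API. It is a minimal plain-Python sketch that assumes symmetric per-output-channel INT8 quantization, to show what such a scheme does to a weight matrix. NNCF's actual modes (for example 4-bit and group-wise variants) are not modeled here.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def compress_weights_int8(weights: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Illustrative only: quantize each output channel (row) to signed 8-bit with its own scale."""
    q_rows, scales = [], []
    for row in weights:
        max_abs = max(abs(w) for w in row)
        # Map [-max_abs, max_abs] onto [-127, 127]; guard against all-zero rows.
        scale = max_abs / 127 if max_abs > 0 else 1.0
        q_rows.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales


def decompress(q_rows: List[List[int]], scales: List[float]) -> Matrix:
    """Recover approximate float weights from the int8 values and per-row scales."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


w = [[0.5, -1.0, 0.25], [0.02, 0.01, -0.03]]
q, s = compress_weights_int8(w)
restored = decompress(q, s)
max_err = max(abs(a - b) for ra, rb in zip(w, restored) for a, b in zip(ra, rb))
print(q, max_err)
```

Because each row gets its own scale, the small-magnitude second row keeps far more precision than it would with one scale shared across the whole matrix.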

Training-Time Compression Algorithms

| Compression algorithm | PyTorch | TensorFlow |
|:---|:---|:---|
| Quantization Aware Training | Supported | Supported |
| Mixed-Precision Quantization | Supported | Not supported |
| Sparsity | Supported | Supported |
| Filter pruning | Supported | Supported |
| Movement pruning | Experimental | Not supported |
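
Filter pruning removes entire convolution filters (output channels), so the pruned layers become smaller rather than just sparse. The sketch below shows only the selection step and assumes that each filter's L2 norm is the importance criterion. It is plain Python and not NNCF's pruning algorithm. A real implementation would also have to keep dependent layers consistent, and is combined with fine-tuning.

```python
import math
from typing import List

Filter = List[float]  # a flattened convolution filter (one output channel)


def l2_norm(f: Filter) -> float:
    return math.sqrt(sum(v * v for v in f))


def select_filters_to_prune(filters: List[Filter], pruning_rate: float) -> List[int]:
    """Return indices of the lowest-L2-norm filters (assumed least important)."""
    num_to_prune = int(len(filters) * pruning_rate)
    ranked = sorted(range(len(filters)), key=lambda i: l2_norm(filters[i]))
    return sorted(ranked[:num_to_prune])


def prune(filters: List[Filter], pruning_rate: float) -> List[Filter]:
    """Drop the selected filters; downstream layers would need matching input channels removed."""
    drop = set(select_filters_to_prune(filters, pruning_rate))
    return [f for i, f in enumerate(filters) if i not in drop]


filters = [[1.0, 1.0], [0.1, 0.1], [2.0, 0.0], [0.0, 0.2]]
print(select_filters_to_prune(filters, 0.5))  # [1, 3]
```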

<a id="documentation"></a>

Documentation

This documentation provides detailed information about the NNCF algorithms and functions you need to know in order to contribute to NNCF.

The latest user documentation for NNCF is available here.

NNCF API documentation can be found here.

<a id="usage"></a>

Usage

Post-Training Quantization

NNCF Post-Training Quantization (PTQ) is the simplest way to apply 8-bit quantization. To run the algorithm, you only need your model and a small calibration dataset (~300 samples).

OpenVINO is the preferred backend to run PTQ with, while PyTorch, TensorFlow, and ONNX are also supported.
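
Conceptually, PTQ runs the calibration dataset through the model to collect activation statistics, then derives quantization parameters from them. The sketch below assumes the simplest choice: a plain min/max range mapped onto an unsigned 8-bit grid with a scale and a zero point. NNCF's own range estimators and quantizer placement are more elaborate, and none of these function names are part of its API.

```python
from typing import Iterable, List, Tuple


def calibrate_range(samples: Iterable[List[float]]) -> Tuple[float, float]:
    """Collect the min/max values observed over the calibration samples."""
    lo, hi = float("inf"), float("-inf")
    for sample in samples:
        lo = min(lo, min(sample))
        hi = max(hi, max(sample))
    # Keep 0.0 inside the range so that it stays exactly representable.
    return min(lo, 0.0), max(hi, 0.0)


def quant_params(lo: float, hi: float, num_bits: int = 8) -> Tuple[float, int]:
    """Derive scale and zero point for an unsigned num_bits grid."""
    qmax = 2 ** num_bits - 1  # 255 for 8 bits
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zero_point = round(-lo / scale)
    return scale, zero_point


def fake_quantize(x: float, scale: float, zero_point: int, num_bits: int = 8) -> float:
    """Quantize to the integer grid (with clamping) and dequantize back to float."""
    q = round(x / scale) + zero_point
    q = max(0, min(2 ** num_bits - 1, q))
    return (q - zero_point) * scale


calibration_samples = [[-1.0, 0.5], [2.0, 0.0]]  # stand-in for real activations
lo, hi = calibrate_range(calibration_samples)
scale, zero_point = quant_params(lo, hi)
print(lo, hi, zero_point)  # -1.0 2.0 85
```

Values outside the calibrated range are clamped. This is one reason why a calibration set that represents real inputs matters for accuracy.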

<details open><summary><b>OpenVINO</b></summary>

```python
import nncf
import openvino as ov
import torch
from torchvision import datasets, transforms

# Instantiate your uncompressed model
model = ov.Core().read_model("/model_path")

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
dataset_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1)

# Step 1: Initialize transformation function
def transform_fn(data_item):
    images, _ = data_item
    return images

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(dataset_loader, transform_fn)
# Step 3: Run the quantization pipeline
quantized_model = nncf.quantize(model, calibration_dataset)
```

</details> <details><summary><b>PyTorch</b></summary>

```python
import nncf
import torch
from torchvision import datasets, models, transforms

# Instantiate your uncompressed model
model = models.mobilenet_v2()

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
dataset_loader = torch.utils.data.DataLoader(val_dataset)

# Step 1: Initialize the transformation function
def transform_fn(data_item):
    images, _ = data_item
    return images

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(dataset_loader, transform_fn)
# Step 3: Run the quantization pipeline
quantized_model = nncf.quantize(model, calibration_dataset)
```

NOTE: If the Post-Training Quantization algorithm does not meet your quality requirements, you can fine-tune the quantized PyTorch model. You can find an example of the Quantization-Aware Training pipeline for a PyTorch model here.

</details> <details><summary><b>TorchFX</b></summary>

```python
import nncf
import torch
from torchvision import datasets, models, transforms
from nncf.torch import disable_patching

# Instantiate your uncompressed model
model = models.mobilenet_v2()

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
dataset_loader = torch.utils.data.DataLoader(val_dataset)

# Step 1: Initialize the transformation function
def transform_fn(data_item):
    images, _ = data_item
    return images

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(dataset_loader, transform_fn)

# Step 3: Export model to TorchFX
input_shape = (1, 3, 224, 224)
ex_input = torch.ones(input_shape)  # example input used to trace the model
with disable_patching():
    fx_model = torch.export.export_for_training(model, args=(ex_input,)).module()
    # or
    # fx_model = torch.export.export(model, args=(ex_input,)).module()

    # Step 4: Run the quantization pipeline
    quantized_fx_model = nncf.quantize(fx_model, calibration_dataset)
```

</details> <details><summary><b>TensorFlow</b></summary>

```python
import nncf
import tensorflow as tf
import tensorflow_datasets as tfds

# Instantiate your uncompressed model
model = tf.keras.applications.MobileNetV2()

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = tfds.load("/path", split="validation",
                        shuffle_files=False, as_supervised=True)

# Step 1: Initialize transformation function
def transform_fn(data_item):
    images, _ = data_item
    return images

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(val_dataset, transform_fn)
# Step 3: Run the quantization pipeline
quantized_model = nncf.quantize(model, calibration_dataset)
```

</details> <details><summary><b>ONNX</b></summary>

```python
import onnx
import nncf
import torch
from torchvision import datasets, transforms

# Instantiate your uncompressed model
onnx_model = onnx.load_model("/model_path")

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
dataset_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1)

# Step 1: Initialize transformation function
input_name = onnx_model.graph.input[0].name
def transform_fn(data_item):
    images, _ = data_item
    return {input_name: images.numpy()}

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(dataset_loader, transform_fn)
# Step 3: Run the quantization pipeline
quantized_model = nncf.quantize(onnx_model, calibration_dataset)
```

</details>
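
Every backend example above wraps the framework's own data source in `nncf.Dataset`, together with a `transform_fn` that turns one data item into model input. For PyTorch, TensorFlow and OpenVINO that input is a tensor, and for ONNX it is an `{input_name: array}` dict. The hypothetical `ToyCalibrationDataset` below mimics that contract in plain Python to make it concrete. It is not NNCF's class, its method name is made up, and the real `nncf.Dataset` API differs.

```python
import itertools
from typing import Any, Callable, Iterable, Iterator, Optional


class ToyCalibrationDataset:
    """Hypothetical stand-in that mimics the nncf.Dataset + transform_fn contract."""

    def __init__(self, data_source: Iterable[Any], transform_fn: Callable[[Any], Any]):
        self._data_source = data_source    # e.g. a DataLoader or tf.data.Dataset
        self._transform_fn = transform_fn  # maps one data item to model input

    def iter_model_inputs(self, limit: Optional[int] = None) -> Iterator[Any]:
        """Yield transformed items lazily, optionally stopping after `limit` items."""
        items: Iterable[Any] = self._data_source
        if limit is not None:
            items = itertools.islice(items, limit)
        for item in items:
            yield self._transform_fn(item)


# Each item is (images, label), like the DataLoader items above
loader = [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 0)]


def transform_fn(data_item):
    images, _ = data_item  # drop the label; only the model input is needed
    return images


ds = ToyCalibrationDataset(loader, transform_fn)
print(list(ds.iter_model_inputs(limit=2)))  # [[0.1, 0.2], [0.3, 0.4]]

# ONNX-style transform: a dict keyed by a (hypothetical) input name
onnx_ds = ToyCalibrationDataset(loader, lambda item: {"input": item[0]})
print(next(onnx_ds.iter_model_inputs()))  # {'input': [0.1, 0.2]}
```

Keeping the transformation separate from the data source lets you reuse an existing validation loader unchanged. Labels are simply discarded, because calibration only needs model inputs.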

Training-Time Quantization

Here is an example of an Accuracy-Aware Quantization pipeline in which model weights and compression parameters can be fine-tuned to achieve higher accuracy.

<details><summary><b>PyTorch</b></summary>

```python
import nncf
import nncf.torch
import torch
from torchvision import datasets, models, transforms

# Instantiate your uncompressed model
model = models.mobilenet_v2()

# Provide validation part of the dataset to collect statistics needed for the compression algorithm
val_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
dataset_loader = torch.utils.data.DataLoader(val_dataset)

# Step 1: Initialize the transformation function
def transform_fn(data_item):
    images, _ = data_item
    return images

# Step 2: Initialize NNCF Dataset
calibration_dataset = nncf.Dataset(dataset_loader, transform_fn)
# Step 3: Run the quantization pipeline
quantized_model = nncf.quantize(model, calibration_dataset)

# Now use quantized_model as a usual torch.nn.Module
# to fine-tune compression parameters along with the model weights

# Save quantization modules and the quantized model parameters
path_to_checkpoint = "quantized_checkpoint.pth"  # placeholder path
checkpoint = {
    'state_dict': quantized_model.state_dict(),
    'nncf_config': quantized_model.nncf.get_config(),
    # ... the rest of the user-defined objects to save
}
torch.save(checkpoint, path_to_checkpoint)

# ...

# Load quantization modules and the quantized model parameters
resuming_checkpoint = torch.load(path_to_checkpoint)
nncf_config = resuming_checkpoint['nncf_config']
state_dict = resuming_checkpoint['state_dict']

# Example input used to trace the model; its shape must match the model input
example_input = torch.ones(1, 3, 224, 224)
# Rebuild the quantized structure on a fresh uncompressed model, then load the weights
quantized_model = nncf.torch.load_from_config(models.mobilenet_v2(), nncf_config, example_input)
quantized_model.load_state_dict(state_dict)
# ... the rest of the usual PyTorch-powered training pipeline
```

</details>
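
Fine-tuning a quantized model relies on fake quantization. The forward pass uses quantized weights, while gradient updates are applied to full-precision "latent" weights. A common way to backpropagate through the non-differentiable rounding is the straight-through estimator (STE). The toy below shows the idea on a single scalar weight in plain Python. It is an illustration under stated assumptions, not NNCF's implementation. NNCF's quantizers are PyTorch modules that also train quantizer parameters, whereas this toy keeps the scale fixed at an assumed value.

```python
from typing import List

SCALE = 0.25  # assumed fixed quantization step for this toy example


def fake_quant(w: float, scale: float = SCALE) -> float:
    """Forward pass: snap the latent float weight onto the quantization grid."""
    return round(w / scale) * scale


def train_qat(xs: List[float], ys: List[float], lr: float = 0.05, steps: int = 100) -> float:
    """Fit y = w * x with a fake-quantized weight, using the straight-through estimator."""
    w = 0.0  # latent full-precision weight updated by the optimizer
    n = len(xs)
    for _ in range(steps):
        wq = fake_quant(w)
        # Gradient of the mean squared error with respect to the quantized weight wq
        grad_wq = sum(2 * x * (wq * x - y) for x, y in zip(xs, ys)) / n
        # STE: treat d(wq)/d(w) as 1, so the same gradient updates the latent weight
        w -= lr * grad_wq
    return w


xs = [1.0, 2.0, -1.0]
ys = [0.75 * x for x in xs]  # target weight 0.75 lies exactly on the grid
w = train_qat(xs, ys)
print(fake_quant(w))  # 0.75
```

Training stops moving the latent weight once its quantized value hits the target. The latent value itself ends up somewhere inside that grid cell rather than exactly at 0.75.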

Training-Time Compression

Here is an example of an Accuracy-Aware RB (regularization-based) Sparsification pipeline in which model weights and compression parameters can be fine-tuned to achieve higher accuracy.

<details><summary><b>PyTorch</b></summary>

```python
import nncf.torch  # Important - must be imported before any other external package that depends on torch
import torch
import torchvision.datasets as datasets
from torchvision import transforms
from torchvision.models.resnet import resnet50

from nncf import NNCFConfig
from nncf.torch import create_compressed_model, register_default_init_args

# Instantiate your uncompressed model
model = resnet50()

# Load a configuration file to specify compression
nncf_config = NNCFConfig.from_json("resnet50_imagenet_rb_sparsity.json")

# Provide data loaders for compression algorithm initialization, if necessary
representative_dataset = datasets.ImageFolder("/path", transform=transforms.Compose([transforms.ToTensor()]))
init_loader = torch.utils.data.DataLoader(representative_dataset)
nncf_config = register_default_init_args(nncf_config, init_loader)

# Apply the specified compression algorithms to the model
compression_ctrl, compressed_model = create_compressed_model(model, nncf_config)

# Now use compressed_model as a usual torch.nn.Module
# to fine-tune compression parameters along with the model weights

# ... the rest of the usual PyTorch-powered training pipeline

# Export to ONNX or .pth when done fine-tuning
compression_ctrl.export_model("compressed_model.onnx")
torch.save(compressed_model.state_dict(), "compressed_model.pth")
```

NOTE (PyTorch): Due to the way NNCF works within the PyTorch backend, import nncf must be done before any other import of torch in your package or in third-party packages that your code utilizes. Otherwise, the compression may be applied incompletely.
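
The reason import order matters can be illustrated with generic Python patching. The illustration assumes that NNCF's PyTorch backend intercepts operations by patching torch functions (the `disable_patching` context used in the TorchFX example suggests that such patching exists). Under that assumption, code that captured a function reference before the patch was applied keeps calling the unpatched original. The snippet uses a hypothetical stand-in module `lib`, not torch or NNCF.

```python
import types

# Hypothetical stand-in for a library module such as torch.nn.functional
lib = types.ModuleType("lib")
lib.relu = lambda x: max(0.0, x)

# A "third-party" module binds the function object at import time,
# as `from torch.nn.functional import relu` would
early_bound_relu = lib.relu

calls = []
original = lib.relu


def traced_relu(x):
    calls.append(x)  # e.g. a compression framework recording the operation
    return original(x)


# Patching replaces the module attribute only after the early binding happened
lib.relu = traced_relu

lib.relu(1.0)          # goes through the patched wrapper, so it is recorded
early_bound_relu(2.0)  # still the original function, so it is NOT recorded
print(calls)  # [1.0]
```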

</details> <details><summary><b>TensorFlow</b></summary>

```python
import tensorflow as tf

from nncf import NNCFConfig
from nncf.tensorflow import create_compressed_model, register_default_init_args

# Instantiate your uncompressed model
from tensorflow.keras.applications import ResNet50
model = ResNet50()

# Load a configuration file to specify compression
nncf_config = NNCFConfig.from_json("resnet50_imagenet_rb_sparsity.json")

# Provide dataset for compression algorithm initialization
representative_dataset = tf.data.Dataset.list_files("/path/*.jpeg")
nncf_config = register_default_init_args(nncf_config, representative_dataset, batch_size=1)

# Apply the specified compression algorithms to the model
compression_ctrl, compressed_model = create_compressed_model(model, nncf_config)

# Now use compressed_model as a usual Keras model
# to fine-tune compression parameters along with the model weights

# ... the rest of the usual TensorFlow-powered training pipeline

# Export to Frozen Graph, TensorFlow SavedModel or .h5 when done fine-tuning
compression_ctrl.export_model("compressed_model.pb", save_format="frozen_graph")
```

</details>
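
Both sparsification pipelines above are driven by a JSON config (`resnet50_imagenet_rb_sparsity.json`) whose contents are not shown here. In RB sparsity the masks are learned during training. The swapped-in technique below is simpler: magnitude-based masking, shown only to make the notion of a target sparsity level concrete. It is a plain-Python sketch and not NNCF code.

```python
from typing import List


def magnitude_sparsity_mask(weights: List[float], sparsity_level: float) -> List[int]:
    """Return a 0/1 mask that zeroes the smallest-|w| fraction of the weights."""
    num_zero = int(len(weights) * sparsity_level)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    mask = [1] * len(weights)
    for i in order[:num_zero]:
        mask[i] = 0
    return mask


def apply_mask(weights: List[float], mask: List[int]) -> List[float]:
    return [w * m for w, m in zip(weights, mask)]


def sparsity(weights: List[float]) -> float:
    """Fraction of weights that are exactly zero."""
    return sum(1 for w in weights if w == 0) / len(weights)


weights = [0.9, -0.05, 0.3, 0.01, -0.7, 0.02]
mask = magnitude_sparsity_mask(weights, 0.5)
print(mask)                                # [1, 0, 1, 0, 1, 0]
print(sparsity(apply_mask(weights, mask)))  # 0.5
```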

For a more detailed description of NNCF usage in your training code, see this tutorial.

<a id="demos-tutorials-and-samples"></a>

Demos, Tutorials and Samples

For a quicker start with NNCF-powered compression, try the sample notebooks and scripts presented below.

Jupyter* Notebook Tutorials and Demos

Ready-to-run Jupyter* notebook tutorials and demos are available that explain and demonstrate NNCF compression algorithms for optimizing models for inference with the OpenVINO Toolkit:

| Notebook Tutorial Name | Compression Algorithm | Backend | Domain |
|:---|:---|:---|:---|
| BERT Quantization | Post-Training Quantization | OpenVINO | NLP |
| MONAI Segmentation Model Quantization | Post-Training Quantization | OpenVINO | Segmentation |
| PyTorch Model Quantization | Post-Training Quantization | PyTorch | Image Classification |
| Quantization with Accuracy Control | Post-Training Quantization with Accuracy Control | OpenVINO | Speech-to-Text,<br>Object Detection |
| PyTorch Training-Time Compression | Training-Time Compression | PyTorch | Image Classification |
| TensorFlow Training-Time Compression | Training-Time Compression | TensorFlow | Image Classification |
| Joint Pruning, Quantization and Distillation for BERT | Joint Pruning, Quantization and Distillation | OpenVINO | NLP |

A list of notebooks demonstrating OpenVINO conversion and inference together with NNCF compression for models from various domains:

| Demo Model | Compression Algorithm | Backend | Domain |
|:---|:---|:---|:---|
| YOLOv8 | Post-Training Quantization | OpenVINO | Object Detection,<br>KeyPoint Detection,<br>Instance Segmentation |
| EfficientSAM | Post-Training Quantization | OpenVINO | Image Segmentation |
| Segment Anything Model | Post-Training Quantization | OpenVINO | Image Segmentation |
| OneFormer | Post-Training Quantization | OpenVINO | Image Segmentation |
| InstructPix2Pix | Post-Training Quantization | OpenVINO | Image-to-Image |
| CLIP | Post-Training Quantization | OpenVINO | Image-to-Text |
| BLIP | Post-Training Quantization | OpenVINO | Image-to-Text |
| Segmind-VegaRT | Post-Training Quantization | OpenVINO | Text-to-Image |
| Latent Consistency Model | Post-Training Quantization | OpenVINO | Text-to-Image |
| Würstchen | Post-Training Quantization | OpenVINO | Text-to-Image |
| ControlNet QR Code Monster | Post-Training Quantization | OpenVINO | Text-to-Image |
| SDXL-turbo | Post-Training Quantization | OpenVINO | Text-to-Image,<br>Image-to-Image |
| ImageBind | Post-Training Quantization | OpenVINO | Multi-Modal Retrieval |
| Distil-Whisper | Post-Training Quantization | OpenVINO | Speech-to-Text |
| Whisper | Post-Training Quantization | OpenVINO | Speech-to-Text |
| MMS Speech Recognition | Post-Training Quantization | OpenVINO | Speech-to-Text |
| Grammar Error Correction | Post-Training Quantization | OpenVINO | NLP, Grammar Correction |
| LLM Instruction Following | Weight Compression | OpenVINO | NLP, Instruction Following |
| Dolly 2.0 | Weight Compression | OpenVINO | NLP, Instruction Following |
| LLM Chat Bots | Weight Compression | OpenVINO | NLP, Chat Bot |

Post-Training Quantization Examples

Compact scripts demonstrating quantization and the corresponding inference speed-up:

| Example Name | Compression Algorithm | Backend | Domain |
|:---|:---|:---|:---|
| OpenVINO MobileNetV2 | Post-Training Quantization | OpenVINO | Image Classification |
| OpenVINO YOLOv8 | Post-Training Quantization | OpenVINO | Object Detection |
| OpenVINO YOLOv8 QwAC | Post-Training Quantization with Accuracy Control | OpenVINO | Object Detection |
| OpenVINO Anomaly Classification | Post-Training Quantization with Accuracy Control | OpenVINO | Anomaly Classification |
| PyTorch MobileNetV2 | Post-Training Quantization | PyTorch | Image Classification |
| PyTorch SSD | Post-Training Quantization | PyTorch | Object Detection |
| TorchFX Resnet18 | Post-Training Quantization | TorchFX | Image Classification |
| TensorFlow MobileNetV2 | Post-Training Quantization | TensorFlow | Image Classification |
| ONNX MobileNetV2 | Post-Training Quantization | ONNX | Image Classification |

Training-Time Compression Examples

Examples of full pipelines including compression, training, and inference for classification, detection, and segmentation tasks:

| Example Name | Compression Algorithm | Backend | Domain |
|:---|:---|:---|:---|
| PyTorch Image Classification | Training-Time Compression | PyTorch | Image Classification |
| PyTorch Object Detection | Training-Time Compression | PyTorch | Object Detection |
| PyTorch Semantic Segmentation | Training-Time Compression | PyTorch | Semantic Segmentation |
| TensorFlow Image Classification | Training-Time Compression | TensorFlow | Image Classification |
| TensorFlow Object Detection | Training-Time Compression | TensorFlow | Object Detection |
| TensorFlow Instance Segmentation | Training-Time Compression | TensorFlow | Instance Segmentation |

<a id="third-party-repository-integration"></a>

Third-party repository integration

NNCF can be easily integrated into the training and evaluation pipelines of third-party repositories.

Used by

<a id="installation-guide"></a>

Installation Guide

For detailed installation instructions, refer to the Installation guide.

NNCF can be installed as a regular PyPI package via pip:

```bash
pip install nncf
```

NNCF is also available via conda:

```bash
conda install -c conda-forge nncf
```

System requirements

This repository is tested on Python* 3.10.14, PyTorch* 2.5.0 (NVIDIA CUDA* Toolkit 12.4) and TensorFlow* 2.12.1 (NVIDIA CUDA* Toolkit 11.8).

NNCF Compressed Model Zoo

A list of models and their compression results can be found on the NNCF Model Zoo page.

Citing

```bibtex
@article{kozlov2020neural,
    title =   {Neural network compression framework for fast model inference},
    author =  {Kozlov, Alexander and Lazarevich, Ivan and Shamporov, Vasily and Lyalyushkin, Nikolay and Gorbachev, Yury},
    journal = {arXiv preprint arXiv:2002.08679},
    year =    {2020}
}
```

Contributing Guide

Refer to the CONTRIBUTING.md file for guidelines on contributions to the NNCF repository.

Useful links

Telemetry

As part of the OpenVINO™ toolkit, NNCF collects anonymous usage data to improve OpenVINO™ tools. You can opt out at any time by running the following command in the Python environment where NNCF is installed:

```bash
opt_in_out --opt_out
```

More information is available on the OpenVINO telemetry page.