MX Pytorch Emulation Library

OCP MX Formats Specification

This library provides the capability to emulate MX-compatible formats and bfloat quantization in PyTorch, enabling data science exploration for DNNs with different MX formats. The underlying computations are done in float32/bfloat16/fp16, but with values restricted to those representable in the MX-compatible or bfloat data formats.

At a high level, the following operations are supported:

The specific data formats (e.g., FP8_e4m3, bfloat16) can be configured using an mx_specs dictionary, which is an input to nearly every function. See Spec Configuration.

For simulation speed and accuracy, we provide custom CUDA extensions for basic MX/bfloat quantization. The custom CUDA code is faster than the equivalent PyTorch GPU code and, in the case of MX, more numerically accurate. See Pytorch CUDA Inaccuracies.

Requirements

We recommend using the Nvidia PyTorch container.

CUDA is required (11.3+ recommended). For Python packages see requirements.txt.

Trademark Notice

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft’s Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.

Integration Guide

There are two ways to integrate the library into your PyTorch models:

  1. Manually replace each PyTorch module (torch.nn.*) and PyTorch function (torch.nn.functional.*) with its equivalent mx library module/function in the model code. This is time-consuming and error-prone, but gives users more precise control over quantization.
  2. Use mx_mapping.inject_pyt_ops to replace PyTorch modules/functions in the global Python scope, then build your model.

In either case, see MX Option Parsing to set up argument parsing.
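Replacing names in a global scope generally only affects code that looks those names up after the replacement, which is why the model is built afterwards. The toy below illustrates that general monkey-patching behavior with a stand-in namespace instead of torch.nn. The names nn, MxLinear, inject, and build_model are hypothetical, and this is not how mx_mapping.inject_pyt_ops is implemented.

```python
import types

# Hypothetical stand-in for a namespace such as torch.nn (no PyTorch needed).
nn = types.SimpleNamespace()


class Linear:
    """Stand-in for the original layer."""
    kind = "native"


class MxLinear(Linear):
    """Hypothetical drop-in replacement that would quantize its inputs."""
    kind = "mx"


nn.Linear = Linear


def build_model():
    # nn.Linear is looked up when this runs, so it sees the current binding.
    return [nn.Linear(), nn.Linear()]


def inject(namespace, replacements):
    """Toy analogue of global op injection: rebind names in a namespace."""
    for name, new_obj in replacements.items():
        setattr(namespace, name, new_obj)


early_model = build_model()
captured = nn.Linear  # a reference taken before injection
inject(nn, {"Linear": MxLinear})
late_model = build_model()

print([layer.kind for layer in early_model])  # ['native', 'native']
print([layer.kind for layer in late_model])   # ['mx', 'mx']
print(captured is Linear)                     # True
```

In this toy, layers created before the injection, and names bound before it (here captured), keep the original class.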

The repo contains a PDF guide for manual integration. The examples folder also contains working code samples.

Files

Configuration

Basic Operators

Linear Functions and Layers

Elementwise Functions and Layers

Custom CUDA Extension

Spec Configuration

Operators, functions, and layers provided by this library all take an mx_specs dictionary as an argument. The available configuration options are listed in specs.py.

MX Options

The first set of options configures the MX-compatible data format. These options apply only to the linear layers (linear, matmul, bmm).

In the linear layers (linear, matmul, bmm), linear assumes its two inputs are activations and weights, whereas matmul and bmm assume both inputs are activations. See W and A.

The next set of options configures the bfloat and fp data formats. These apply to non-matrix operations such as add, mul, sqrt, and exp; layers such as layernorm, softmax, and GELU are computed using these operations. Only one of bfloat or fp should take a non-zero value. If both bfloat and fp are set to 0, non-matrix operations are computed in native FP32.
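To illustrate what restricting a value to the bfloat16 format means, the sketch below rounds a Python float to the nearest bfloat16 value while keeping it in a wider float type. It assumes round-to-nearest-even and ignores NaN/Inf; round_to_bfloat16 is a hypothetical helper, not the library's implementation, and the library's own rounding modes may differ.

```python
import struct


def round_to_bfloat16(x):
    """Round x to the nearest bfloat16 value (ties to even) and return it as a float.

    x is first converted to float32; NaN/Inf handling is omitted in this sketch.
    """
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits. Add 0x7FFF plus the lowest kept bit so that
    # ties round to an even kept mantissa, then clear the discarded low bits.
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(round_to_bfloat16(1.0))            # 1.0
print(round_to_bfloat16(1 + 2**-8))      # 1.0 (tie, rounds to even)
print(round_to_bfloat16(1 + 3 * 2**-9))  # 1.0078125
```

The result is still stored as a wider float; only its value is limited to what bfloat16 can represent.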

The next set of options configure whether quantization is applied to certain layers.

The final set of options configure miscellaneous settings in the library.

Spec Concrete MX-Compatible Formats

The examples below show how to create the mx_specs dictionary and configure it for each concrete MX-compatible format. For simplicity we ignore the backward pass specs like w_elem_format_bp.

from mx.specs import MxSpecs

# MXFP8_e5m2 matmuls with bfloat16 vector ops, forward pass only
mx_specs = MxSpecs()
mx_specs['scale_bits'] = 8
mx_specs['w_elem_format'] = 'fp8_e5m2'
mx_specs['a_elem_format'] = 'fp8_e5m2'
mx_specs['block_size'] = 32
mx_specs['bfloat'] = 16
mx_specs['custom_cuda'] = True

# MXFP4_e2m1 matmuls with bfloat16 vector ops, forward pass only
mx_specs = MxSpecs()
mx_specs['scale_bits'] = 8
mx_specs['w_elem_format'] = 'fp4_e2m1'
mx_specs['a_elem_format'] = 'fp4_e2m1'
mx_specs['block_size'] = 32
mx_specs['bfloat'] = 16
mx_specs['custom_cuda'] = True

# MXINT8 matmuls with bfloat16 vector ops, forward pass only
mx_specs = MxSpecs()
mx_specs['scale_bits'] = 8
mx_specs['w_elem_format'] = 'int8'
mx_specs['a_elem_format'] = 'int8'
mx_specs['block_size'] = 32
mx_specs['bfloat'] = 16
mx_specs['custom_cuda'] = True

Backward Pass Quantization

For MX quantization of the matmuls on the backward pass, the setting quantize_backprop must be True.

W and A

The Linear layer and the linear function use a_elem_format for the first input and w_elem_format for the second input. The matmul and bmm functions have a bits_config argument that selects the operation mode (weight x act, act x weight, or act x act).

The library currently assumes that:

MX Option Parsing

The function add_mx_args in specs.py will add an argument for each spec option to an argparse.ArgumentParser object. Note that the defaults of the args added this way will be None. Pass the output of parser.parse_args() to the function get_mx_specs to obtain the correct spec dictionary. Example:

# Example on how to set up and parse MX config flags
import argparse
from mx.specs import add_mx_args, get_mx_specs

parser = argparse.ArgumentParser()
parser = add_mx_args(parser)
args = parser.parse_args()
mx_specs = get_mx_specs(args)

If not set explicitly in the arguments, the backward pass MX formats (w_elem_format_bp, a_elem_format_bp, a_elem_format_bp_ex, a_elem_format_bp_os) will be assigned the values of w_elem_format and a_elem_format.

Boolean flags that default to True are added to the ArgumentParser with a "--no-" prefix and a default of False. For example, the flag --quantize-backprop defaults to True, so it appears in the parser as --no-quantize-backprop, which defaults to False.
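The conventions above (None defaults, backward-pass formats falling back to the forward-pass formats, and the "--no-" prefix for default-True booleans) can be sketched with plain argparse. This is a hypothetical mini-version, not the code in specs.py: DEFAULTS, add_spec_args, and specs_from_args are illustrative names, and apart from quantize_backprop defaulting to True, the listed defaults are placeholders.

```python
import argparse

# Hypothetical subset of spec options. Only quantize_backprop's default (True)
# comes from the documentation; the other defaults are placeholders.
DEFAULTS = {
    "w_elem_format": None,
    "a_elem_format": None,
    "w_elem_format_bp": None,
    "a_elem_format_bp": None,
    "block_size": 0,
    "quantize_backprop": True,
}


def add_spec_args(parser):
    for name, default in DEFAULTS.items():
        flag = name.replace("_", "-")
        if default is True:
            # Default-True booleans are exposed as "--no-<flag>" with default False.
            parser.add_argument(f"--no-{flag}", action="store_true", default=False)
        else:
            # Other options default to None so "not set" can be detected later.
            arg_type = int if isinstance(default, int) else str
            parser.add_argument(f"--{flag}", type=arg_type, default=None)
    return parser


def specs_from_args(args):
    specs = dict(DEFAULTS)
    for name, default in DEFAULTS.items():
        if default is True:
            # Passing --no-<flag> turns the option off.
            specs[name] = not getattr(args, "no_" + name)
        else:
            value = getattr(args, name)
            if value is not None:
                specs[name] = value
    # Unset backward-pass formats inherit the forward-pass formats.
    if specs["w_elem_format_bp"] is None:
        specs["w_elem_format_bp"] = specs["w_elem_format"]
    if specs["a_elem_format_bp"] is None:
        specs["a_elem_format_bp"] = specs["a_elem_format"]
    return specs


parser = add_spec_args(argparse.ArgumentParser())
args = parser.parse_args(["--w-elem-format", "fp8_e4m3", "--no-quantize-backprop"])
specs = specs_from_args(args)
print(specs["w_elem_format_bp"], specs["quantize_backprop"])  # fp8_e4m3 False
```

Using None as the parser default is what lets the sketch tell "not given on the command line" apart from an explicit value.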

Testing

A unit test suite is provided with this library under the mx/tests folder. It requires the pytest pip package. To run all tests, run the following from that folder:

python -m pytest .

The unit tests have been verified to pass on these configurations:

Numerics

Pytorch CUDA Inaccuracies

The golden reference for numeric results is pytorch CPU. On the GPU, the custom CUDA code is more numerically robust than the (algorithmically equivalent) pytorch code. An example inaccuracy for pytorch code running on Nvidia V100 with Pytorch 1.9.1 + CUDA 11.3:

>>> # We want to shift x to the left by 16 bits
>>> x = torch.tensor([1.], dtype=torch.float32, device='cuda')
>>> e = torch.tensor([16.], dtype=torch.float32, device='cuda')
>>> x * (2**e)
tensor([65535.9961], device='cuda:0')  # should be 65536

In fact, we disabled unit testing for pytorch GPU (for MX only) because this inaccuracy and rounding bugs (e.g., 0.5 being rounded down) kept causing mismatches against pytorch CPU.
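One way to avoid this class of error is to build the power of two directly from its float32 bit pattern, so that the scaling only changes the exponent and stays exact. The sketch below illustrates the idea in plain Python; exact_pow2_f32 and shift_left are hypothetical helpers and not necessarily how the custom CUDA kernels compute the shift.

```python
import struct


def exact_pow2_f32(e):
    """Return 2**e built from its float32 bit pattern.

    Valid only for normal float32 exponents, -126 <= e <= 127.
    """
    if not -126 <= e <= 127:
        raise ValueError("exponent outside the float32 normal range")
    bits = (e + 127) << 23  # biased exponent field, all-zero mantissa
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def shift_left(x, e):
    # Multiplying by an exact power of two only changes the exponent,
    # so the result is exact unless it overflows or underflows.
    return x * exact_pow2_f32(e)


print(shift_left(1.0, 16))  # 65536.0
```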

NaNs/Infs

MX: NaNs are preserved, and Infs are either preserved or converted to NaNs. Other values in a vector that contains one or more NaNs/Infs have undefined quantization behavior.
Bfloat: NaNs/Infs are fully preserved.

Denorms

MX: denorms are supported by default.
Bfloat: denorms are supported by default and can be flushed to zero by setting bfloat_subnorms to False.
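Flushing bfloat denorms to zero amounts to zeroing any nonzero magnitude below the smallest normal bfloat16 value, 2^-126 (bfloat16 has the same 8-bit exponent as float32). A minimal sketch, assuming the input is already bfloat16-representable and that the sign of zero is kept; flush_subnormal_bf16 is a hypothetical helper, not the library's code.

```python
import math

# bfloat16 has float32's 8-bit exponent, so its smallest normal value is 2**-126.
BF16_MIN_NORMAL = 2.0 ** -126


def flush_subnormal_bf16(x):
    """Return signed zero if x is a nonzero subnormal; otherwise return x unchanged.

    NaN and Inf pass through because the magnitude comparison is False for them.
    """
    if x != 0.0 and abs(x) < BF16_MIN_NORMAL:
        return math.copysign(0.0, x)
    return x


print(flush_subnormal_bf16(2.0 ** -130))  # 0.0
```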

Rounding

In our code, MX and Bfloat support three rounding modes:

Rounding interacts with the shared exponent in the following way. Example: consider a hypothetical MXINT3 (1 sign bit, 2 mantissa bits) and a block whose largest magnitude lies in [1, 2), so the shared scale is 2^0 = 1. The quantization grid points are:

[-1.5, -1.0, -0.5, 0.0, +0.5, +1.0, +1.5]

With this system, any number in [1.5, 2) is rounded to 1.5; even 1.99 is rounded to 1.5. There can't be a grid point at 2, because a block containing a value of 2 or more would get a larger shared exponent.
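The example can be reproduced with a small block-quantization sketch. It assumes the shared scale is 2^floor(log2(max|x|)) and that a b-bit MXINT element is a signed integer k with |k| <= 2^(b-1) - 1, scaled by 2^-(b-2), so that MXINT3 yields exactly the grid above. Rounding here is Python's round-half-to-even followed by clamping, which may not match the library's rounding modes; quantize_mxint_block is a hypothetical name.

```python
import math


def quantize_mxint_block(block, bits):
    """Quantize one block to a hypothetical MXINT<bits> grid; returns (values, shared_exp)."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return [0.0] * len(block), 0
    # frexp gives amax = m * 2**e with 0.5 <= m < 1, so floor(log2(amax)) == e - 1.
    shared_exp = math.frexp(amax)[1] - 1
    scale = 2.0 ** shared_exp
    steps = 2 ** (bits - 2)     # grid points per unit of the shared scale
    kmax = 2 ** (bits - 1) - 1  # symmetric range: the most negative code is unused
    out = []
    for v in block:
        k = round(v / scale * steps)  # round half to even
        k = max(-kmax, min(kmax, k))  # clamp: the grid tops out below 2 * scale
        out.append(k / steps * scale)
    return out, shared_exp


print(quantize_mxint_block([1.99, 0.3, -0.74], bits=3))  # ([1.5, 0.5, -0.5], 0)
print(quantize_mxint_block([2.0, 1.99], bits=3))         # ([2.0, 2.0], 1)
```

The second call shows the interaction described above: once the block contains a 2.0, the shared exponent increases and 1.99 can round up to 2.0.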

MXINT 2's Complement

Following the OCP MX Formats Specification, MXINT elements use 2's complement encoding, with the asymmetric most-negative code (e.g., -128 for 8 bits) left unused. The representable values in this encoding are therefore identical to those of sign-magnitude.
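For 8-bit elements this equivalence can be checked exhaustively: decoding every 2's complement code except 0x80 (which would be -128) gives the same set of integers as decoding every sign-magnitude code. The helper names below are hypothetical, and the check covers only the integer codes, not the element scaling.

```python
def decode_twos_complement8(byte):
    """Interpret an 8-bit code as a 2's complement integer."""
    return byte - 256 if byte & 0x80 else byte


def decode_sign_magnitude8(byte):
    """Interpret an 8-bit code as sign-magnitude (bit 7 is the sign)."""
    magnitude = byte & 0x7F
    return -magnitude if byte & 0x80 else magnitude


# The asymmetric code 0x80 (-128) is excluded from the 2's complement set.
twos = {decode_twos_complement8(b) for b in range(256) if b != 0x80}
# Sign-magnitude has two zeros (0x00 and 0x80); the set collapses them.
sign_mag = {decode_sign_magnitude8(b) for b in range(256)}

print(twos == sign_mag == set(range(-127, 128)))  # True
```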

CUDA Extensions

The cpp directory contains custom C++ and CUDA implementations of MX library functions. These are automatically JIT-compiled via custom_extensions.py.

The following are some references for creating custom extensions for PyTorch:

In the CUDA files, we substitute the following names for MX terminology, because "block_size" already has a different meaning in CUDA:

The CUDA code was compiled and tested on machines with the following configurations:

| Information | Nvidia V100 | Nvidia A100 | Nvidia H100 |
|---|---|---|---|
| Container Image | nvcr.io/nvidia/pytorch:24.06-py3 | nvcr.io/nvidia/pytorch:24.06-py3 | nvcr.io/nvidia/pytorch:24.06-py3 |
| OS | Ubuntu 20.04 | Ubuntu 22.04 | Ubuntu 22.04 |
| Nvidia Driver | 535.171.04 | 535.183.01 | 550.54.15 |
| CUDA | 12.5 | 12.5 | 12.5 |
| cuDNN | 9.1.0.70 | 9.1.0.70 | 9.1.0.70 |
| Python | 3.10.12 | 3.10.12 | 3.10.12 |
| PyTorch | 2.4.0 | 2.4.0 | 2.4.0 |

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.
