<!-- # Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of NVIDIA CORPORATION nor the names of its # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -->

ONNX Runtime Backend

The Triton backend for the ONNX Runtime. You can learn more about Triton backends in the backend repo. Ask questions or report problems on the issues page.

Build the ONNX Runtime Backend

Use a recent cmake to build and install in a local directory. Typically you will want to build an appropriate ONNX Runtime implementation as part of the build. You do this by specifying an ONNX Runtime version and a Triton container version that you want to use with the backend. You can find the combination of versions used in a particular Triton release in the TRITON_VERSION_MAP at the top of build.py in the branch matching the Triton release you are interested in. For example, to build the ONNX Runtime backend for Triton 23.04, use the versions from TRITON_VERSION_MAP in the r23.04 branch of build.py.

$ mkdir build
$ cd build
$ cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DTRITON_BUILD_ONNXRUNTIME_VERSION=1.14.1 -DTRITON_BUILD_CONTAINER_VERSION=23.04 ..
$ make install

The resulting install/backends/onnxruntime directory can be added to a Triton installation as /opt/tritonserver/backends/onnxruntime.
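As a minimal sketch, from the build directory the installed backend can be copied into place and the server started against a model repository; the repository path below is a placeholder:

$ cp -r install/backends/onnxruntime /opt/tritonserver/backends/
$ tritonserver --model-repository=/path/to/model_repository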

The following required Triton repositories will be pulled and used in the build: triton-inference-server/backend, triton-inference-server/core, and triton-inference-server/common. By default the "main" branch/tag will be used for each repo, but the CMake arguments TRITON_BACKEND_REPO_TAG, TRITON_CORE_REPO_TAG, and TRITON_COMMON_REPO_TAG can be used to override the tag that is used for each, as shown below.
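For example, the repo tag arguments can be added alongside the build options above to pin all three repositories to a release branch; the r23.04 tag here is illustrative only:

$ cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DTRITON_BACKEND_REPO_TAG=r23.04 -DTRITON_CORE_REPO_TAG=r23.04 -DTRITON_COMMON_REPO_TAG=r23.04 -DTRITON_BUILD_ONNXRUNTIME_VERSION=1.14.1 -DTRITON_BUILD_CONTAINER_VERSION=23.04 ..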

You can add TensorRT support to the ONNX Runtime backend by using -DTRITON_ENABLE_ONNXRUNTIME_TENSORRT=ON. You can add OpenVINO support by using -DTRITON_ENABLE_ONNXRUNTIME_OPENVINO=ON -DTRITON_BUILD_ONNXRUNTIME_OPENVINO_VERSION=<version>, where <version> is the OpenVINO version to use and should match the TRITON_VERSION_MAP entry as described above. So, to build with both TensorRT and OpenVINO support:

$ mkdir build
$ cd build
$ cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DTRITON_BUILD_ONNXRUNTIME_VERSION=1.14.1 -DTRITON_BUILD_CONTAINER_VERSION=23.04 -DTRITON_ENABLE_ONNXRUNTIME_TENSORRT=ON -DTRITON_ENABLE_ONNXRUNTIME_OPENVINO=ON -DTRITON_BUILD_ONNXRUNTIME_OPENVINO_VERSION=2021.2.200 ..
$ make install

ONNX Runtime with TensorRT optimization

TensorRT can be used in conjunction with an ONNX model to further optimize performance. To enable TensorRT optimization you must set the model configuration appropriately. There are several optimizations available for TensorRT, such as selection of the compute precision and the workspace size; the supported optimization parameters are listed in the mapping table below.

To explore the usage of more parameters, follow the mapping table below and check the ONNX Runtime documentation for details.

Please link against the latest ONNX Runtime binaries in CMake, or build from the main branch of ONNX Runtime, to enable the latest options.

Parameter mapping between ONNX Runtime and Triton ONNXRuntime Backend

| Key in Triton model configuration | Value in Triton model config | Corresponding TensorRT EP option in ONNX Runtime | Type |
| --- | --- | --- | --- |
| max_workspace_size_bytes | e.g: "4294967296" | trt_max_workspace_size | int |
| trt_max_partition_iterations | e.g: "1000" | trt_max_partition_iterations | int |
| trt_min_subgraph_size | e.g: "1" | trt_min_subgraph_size | int |
| precision_mode | "FP16" | trt_fp16_enable | bool |
| precision_mode | "INT8" | trt_int8_enable | bool |
| int8_calibration_table_name |  | trt_int8_calibration_table_name | string |
| int8_use_native_calibration_table | e.g: "1" or "true", "0" or "false" | trt_int8_use_native_calibration_table | bool |
| trt_dla_enable |  | trt_dla_enable | bool |
| trt_dla_core | e.g: "0" | trt_dla_core | int |
| trt_engine_cache_enable | e.g: "1" or "true", "0" or "false" | trt_engine_cache_enable | bool |
| trt_engine_cache_path |  | trt_engine_cache_path | string |
| trt_engine_cache_prefix |  | trt_engine_cache_prefix | string |
| trt_dump_subgraphs | e.g: "1" or "true", "0" or "false" | trt_dump_subgraphs | bool |
| trt_force_sequential_engine_build | e.g: "1" or "true", "0" or "false" | trt_force_sequential_engine_build | bool |
| trt_context_memory_sharing_enable | e.g: "1" or "true", "0" or "false" | trt_context_memory_sharing_enable | bool |
| trt_layer_norm_fp32_fallback | e.g: "1" or "true", "0" or "false" | trt_layer_norm_fp32_fallback | bool |
| trt_timing_cache_enable | e.g: "1" or "true", "0" or "false" | trt_timing_cache_enable | bool |
| trt_timing_cache_path |  | trt_timing_cache_path | string |
| trt_force_timing_cache | e.g: "1" or "true", "0" or "false" | trt_force_timing_cache | bool |
| trt_detailed_build_log | e.g: "1" or "true", "0" or "false" | trt_detailed_build_log | bool |
| trt_build_heuristics_enable | e.g: "1" or "true", "0" or "false" | trt_build_heuristics_enable | bool |
| trt_sparsity_enable | e.g: "1" or "true", "0" or "false" | trt_sparsity_enable | bool |
| trt_builder_optimization_level | e.g: "3" | trt_builder_optimization_level | int |
| trt_auxiliary_streams | e.g: "-1" | trt_auxiliary_streams | int |
| trt_tactic_sources | e.g: "-CUDNN,+CUBLAS" | trt_tactic_sources | string |
| trt_extra_plugin_lib_paths |  | trt_extra_plugin_lib_paths | string |
| trt_profile_min_shapes | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..." | trt_profile_min_shapes | string |
| trt_profile_max_shapes | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..." | trt_profile_max_shapes | string |
| trt_profile_opt_shapes | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..." | trt_profile_opt_shapes | string |
| trt_cuda_graph_enable | e.g: "1" or "true", "0" or "false" | trt_cuda_graph_enable | bool |
| trt_dump_ep_context_model | e.g: "1" or "true", "0" or "false" | trt_dump_ep_context_model | bool |
| trt_ep_context_file_path |  | trt_ep_context_file_path | string |
| trt_ep_context_embed_mode | e.g: "1" | trt_ep_context_embed_mode | int |

The section of the model config file specifying these parameters will look like this:

.
.
.
optimization { execution_accelerators {
  gpu_execution_accelerator : [ {
    name : "tensorrt"
    parameters { key: "precision_mode" value: "FP16" }
    parameters { key: "max_workspace_size_bytes" value: "1073741824" }
    parameters { key: "trt_engine_cache_enable" value: "1" }
  } ]
}}
.
.
.
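As a further illustration, the engine cache and explicit optimization profiles can be combined in the same accelerator block. This is only a sketch: the input name input_1, the shapes, and the cache path are hypothetical placeholders, not recommendations.

optimization { execution_accelerators {
  gpu_execution_accelerator : [ {
    name : "tensorrt"
    parameters { key: "trt_engine_cache_enable" value: "true" }
    parameters { key: "trt_engine_cache_path" value: "/tmp/trt_cache" }
    parameters { key: "trt_profile_min_shapes" value: "input_1:1x3x224x224" }
    parameters { key: "trt_profile_opt_shapes" value: "input_1:8x3x224x224" }
    parameters { key: "trt_profile_max_shapes" value: "input_1:16x3x224x224" }
  } ]
}}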

ONNX Runtime with CUDA Execution Provider optimization

When GPU support is enabled for ORT, the CUDA execution provider is enabled. If TensorRT is also enabled, the CUDA EP is treated as a fallback option (it only comes into the picture for nodes that TensorRT cannot execute). If TensorRT is not enabled, the CUDA EP is the primary EP that executes the models. ORT allows the CUDA EP options to be configured to further optimize for a specific model and user scenario. There are several optimizations available; please refer to the ONNX Runtime documentation for more details. To enable CUDA EP optimization you must set the model configuration appropriately:

optimization { execution_accelerators {
  gpu_execution_accelerator : [ {
    name : "cuda"
    parameters { key: "cudnn_conv_use_max_workspace" value: "0" }
    parameters { key: "use_ep_level_unified_stream" value: "1" }
  } ]
}}

Deprecated Parameters

Specifying these parameters directly, as shown below, is deprecated. They are still supported for backward compatibility, but please use the method described above instead.

In the model config file, specifying these parameters will look like:

.
.
.
parameters { key: "cudnn_conv_algo_search" value: { string_value: "0" } }
parameters { key: "gpu_mem_limit" value: { string_value: "4294967200" } }
.
.
.
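Assuming the cuda execution accelerator passes these keys through to the CUDA EP in the same way as the keys shown earlier, the non-deprecated equivalent of the example above would look like this sketch:

optimization { execution_accelerators {
  gpu_execution_accelerator : [ {
    name : "cuda"
    parameters { key: "cudnn_conv_algo_search" value: "0" }
    parameters { key: "gpu_mem_limit" value: "4294967200" }
  } ]
}}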

ONNX Runtime with OpenVINO optimization

OpenVINO can be used in conjunction with an ONNX model to further optimize performance. To enable OpenVINO optimization you must set the model configuration as shown below.

.
.
.
optimization { execution_accelerators {
  cpu_execution_accelerator : [ {
    name : "openvino"
  } ]
}}
.
.
.

Other Optimization Options with ONNX Runtime

Details regarding when to use these options and what to expect from them can be found in the ONNX Runtime documentation.

Model Config Options

optimization {
  graph : {
    level : 1
}}

parameters { key: "intra_op_thread_count" value: { string_value: "0" } }
parameters { key: "execution_mode" value: { string_value: "0" } }
parameters { key: "inter_op_thread_count" value: { string_value: "0" } }

Command line options

Thread Pools

When the intra and inter op thread counts are set to 0 or to a value higher than 1, ORT by default creates a thread pool per session. This may not be ideal in every scenario, so ORT also supports global thread pools. When global thread pools are enabled, ORT creates one global thread pool which is shared by every session. Use the backend config to enable the global thread pool. When the global thread pool is enabled, the intra and inter op thread counts should also be provided via the backend config; values provided in the model config will be ignored.

--backend-config=onnxruntime,enable-global-threadpool=<0,1>
--backend-config=onnxruntime,intra_op_thread_count=<int>
--backend-config=onnxruntime,inter_op_thread_count=<int>
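For example, a server launch that enables the global thread pool might look like the following; the model repository path and thread counts are illustrative only.

$ tritonserver --model-repository=/path/to/model_repository \
    --backend-config=onnxruntime,enable-global-threadpool=1 \
    --backend-config=onnxruntime,intra_op_thread_count=8 \
    --backend-config=onnxruntime,inter_op_thread_count=2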

Default Max Batch Size

The default-max-batch-size value is used for max_batch_size during Autocomplete when no other value is found. Assuming the server was not launched with the --disable-auto-complete-config command-line option, the onnxruntime backend will set the max_batch_size of the model to this default value under the following conditions:

  1. Autocomplete has determined the model is capable of batching requests.
  2. max_batch_size is 0 in the model configuration or max_batch_size is omitted from the model configuration.

If max_batch_size > 1 and no scheduler is provided, the dynamic batch scheduler will be used.

--backend-config=onnxruntime,default-max-batch-size=<int>

The default value of default-max-batch-size is 4.
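For example, to raise the default to 8 (the value and the model repository path are illustrative only):

$ tritonserver --model-repository=/path/to/model_repository \
    --backend-config=onnxruntime,default-max-batch-size=8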