Source code for aimet_torch.onnx

# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2025, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
"""Defines onnx export API"""

import copy
import contextlib
import io
import os
import tempfile
import traceback
from typing import Any, Mapping, Tuple, Union

import onnx
import torch
from torch.onnx import _constants

from aimet_common.onnx._utils import _add_onnx_qdq_nodes, _is_grid_preserving_op

from .nn import QuantizationMixin
from .quantization import DequantizedTensor
from .quantization.base import EncodingBase
from .quantization.affine import AffineQuantizerBase, GroupedBlockQuantizeDequantize
from .quantization.float import FloatQuantizeDequantize
from .quantsim import QuantizationSimModel
from .v2.experimental import onnx as _onnx


_TORCH_DEFAULT_OPSET = _constants.ONNX_DEFAULT_OPSET
_TORCH_MIN_OPSET = _constants.ONNX_MIN_OPSET
_TORCH_MAX_OPSET = _constants.ONNX_MAX_OPSET

# Allow at least up to opset 21 to enable [u]int16 QDQ export
_AIMET_MAX_OPSET = max(_TORCH_MAX_OPSET, 21)


[docs] def export( model: Union[torch.nn.Module, QuantizationSimModel], args: Union[Tuple[Any, ...], torch.Tensor], f: Union[str, io.BytesIO], *, export_int32_bias: bool = True, **kwargs, ): """ Export :class:`QuantizationSimModel` to onnx model with onnx `QuantizeLinear`_ and `DequantizeLinear`_ embedded in the graph. This function takes set of same arguments as `torch.onnx.export()`_ Args: model: The model to be exported args: Same as `torch.onnx.export()` f: Same as `torch.onnx.export()` export_int32_bias (bool, optional): If true, generate and export int32 bias encoding on the fly (default: `True`) **kwargs: Same as `torch.onnx.export()` .. note:: Unlike `torch.onnx.export()`, this function allows up to opset 21. to support 4/16-bit quantization only available in opset 21. However, exporting to opset 21 is a beta feature and not fully stable yet. For robustness, opset 20 or lower is recommended whenever possible. .. note:: Dynamo-based export (`dynamo=True`) is not supported yet .. _torch.onnx.export(): https://docs.pytorch.org/docs/stable/onnx_torchscript.html#torch.onnx.export .. _QuantizeLinear: https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html .. _DequantizeLinear: https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html Examples: >>> aimet_torch.onnx.export(sim.model, x, f="model.onnx", ... input_names=["input"], output_names=["output"], ... opset_version=21, export_int32_bias=True) ... >>> import onnxruntime as ort >>> options = ort.SessionOptions() >>> options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL >>> sess = ort.InferenceSession("model.onnx", sess_options=options) >>> onnx_output, = sess.run(None, {"input": x.detach().numpy()}) >>> torch.nn.functional.cosine_similarity(torch.from_numpy(onnx_output), sim.model(x)) tensor([1.0000, 0.9999, 1.0000, ..., 1.0000, 1.0000, 1.0000], grad_fn=<AliasBackward0>) """ if isinstance(model, QuantizationSimModel): model = model.model if not isinstance(model, torch.nn.Module): raise RuntimeError( f"aimet_torch.export only supports torch.nn.Module or QuantizationSimModel; got {type(model)}" ) _check_opset_version(kwargs) _check_unsupported_args(kwargs) _check_non_standard_quantizer(model) target_version = kwargs.pop("opset_version", _TORCH_DEFAULT_OPSET) kwargs["opset_version"] = min(target_version, _TORCH_MAX_OPSET) _assert_minimum_required_opset(model, target_version) with contextlib.ExitStack() as stack: # Unfold all param quantizers to incorporate QuantizeLinear/DequantizeLinear # of those parameters in tracing time stack.enter_context(_temporarily_unfold_param_quantizers(model)) if export_int32_bias: # Temoprarily instantiate int32 bias quantizers stack.enter_context( _concretize_int32_bias_quantizers(model, args, kwargs.get("kwargs")) ) # Export quantize-dequantized weight # pylint: disable=protected-access stack.enter_context(QuantizationSimModel._apply_qdq_to_model_parameters(model)) # Remove [b]float16 quantizers stack.enter_context(_remove_fp16_quantizers(model)) onnx_model, tensor_to_encoding_map = _to_onnx(model, args, **kwargs) if _TORCH_MAX_OPSET < target_version: try: onnx_model = onnx.version_converter.convert_version( onnx_model, target_version ) except Exception as e: # pylint: disable=broad-exception-caught f = io.StringIO() traceback.print_exc(file=f) reason = _why_do_i_need_opset21(model) if reason: detail = ( f"torch.onnx.export only supports opset<={_TORCH_MAX_OPSET}, " f"but onnx::QuantizeLinear requires opset>={target_version} for {reason}. " "As a workaround, we tried to torch.onnx.export your model " f"with opset={_TORCH_MAX_OPSET} and convert the onnx model to {target_version}, " "but failed with the following error:\n\n" ) else: detail = "\n\n" msg = ( f"Failed to convert onnx model to {target_version} due to {type(e).__name__}. {detail}" "==============================================================\n" f"{f.getvalue()}" "==============================================================\n\n" ) raise RuntimeError(msg) from e onnx_qdq_model = _to_onnx_qdq(onnx_model, tensor_to_encoding_map) onnx.save(onnx_qdq_model, f)
def _why_do_i_need_opset21(model: torch.nn.Module) -> str: int4 = False int16 = False bq = False for qtzr in model.modules(): if not isinstance(qtzr, AffineQuantizerBase): continue if qtzr.block_size is not None: bq = True if qtzr.bitwidth == 4: int4 = True if qtzr.bitwidth == 16: int16 = True reasons = [] if int4 or int16: reasons.append("int4/int16 quantization") if bq: reasons.append("blockwise quantization") if not reasons: return "" # This should never happen if len(reasons) == 1: return reasons[0] return f"{reasons[0]} and {reasons[1]}" def _assert_minimum_required_opset(model: torch.nn.Module, target_opset: int): if target_opset < 21 and any( qtzr.block_size is not None for qtzr in model.modules() if isinstance(qtzr, AffineQuantizerBase) ): raise RuntimeError( "onnx::QuantizeLinear and DequantizeLinear with per-block are only supported in opset >= 21;" f" got opset={target_opset}" ) if target_opset < 21 and any( qtzr.bitwidth in (4, 16) for qtzr in model.modules() if isinstance(qtzr, AffineQuantizerBase) ): raise RuntimeError( "onnx::QuantizeLinear and DequantizeLinear with INT4/INT16 are only supported in opset >= 21;" f" got opset={target_opset}" ) if target_opset < 13 and any( tuple(qtzr.shape) for qtzr in model.modules() if isinstance(qtzr, AffineQuantizerBase) ): raise RuntimeError( "onnx::QuantizeLinear and DequantizeLinear with per-channel are only supported in opset >= 13;" f" got opset={target_opset}" ) if target_opset < 10: raise RuntimeError( "onnx::QuantizeLinear and DequantizeLinear are only supported in opset >= 10;" f" got opset={target_opset}" ) def _check_opset_version(kwargs): opset_version = kwargs.get("opset_version", _TORCH_DEFAULT_OPSET) if not (_TORCH_MIN_OPSET <= opset_version <= _AIMET_MAX_OPSET): raise ValueError(f"Unsupported ONNX opset version: {opset_version}") def _check_unsupported_args(kwargs): export_params = kwargs.get("export_params", True) if not export_params: raise NotImplementedError("export_params=False is not supported yet") keep_initializers_as_inputs = kwargs.get("keep_initializers_as_inputs", False) if keep_initializers_as_inputs: raise NotImplementedError( "keep_initializers_as_inputs=True is not supported yet" ) dynamo = kwargs.get("dynamo", False) if dynamo: raise NotImplementedError("dynamo=True is not supported yet") do_constant_folding = kwargs.get("do_constant_folding", True) if not do_constant_folding: raise NotImplementedError("do_constant_folding=False is not supported yet") export_modules_as_functions = kwargs.get("export_modules_as_functions", False) if export_modules_as_functions: raise RuntimeError("export_modules_as_functions=True is not supported") operator_export_type = kwargs.get( "operator_export_type", torch.onnx.OperatorExportTypes.ONNX ) if operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN: raise RuntimeError( "operator_export_type=OperatorExportTypes.ONNX_ATEN is not supported" ) def _check_non_standard_quantizer(model: torch.nn.Module): for name, qtzr in model.named_modules(): if not isinstance(qtzr, AffineQuantizerBase): continue if isinstance(qtzr, GroupedBlockQuantizeDequantize): raise NotImplementedError( "torch.onnx.exoprt doesn't support GroupedBlockQuantizeDequantize (a.k.a LPBQ) yet; " f"got '{name}' of type GroupedBlockQuantizeDequantize" ) if qtzr.bitwidth not in (4, 8, 16, 32): raise RuntimeError( "torch.onnx.exoprt only supports 4/8/16/32-bit integers; " f"got '{name}' with bitwidth={qtzr.bitwidth}" ) def _to_onnx( model: torch.nn.Module, args: Union[Tuple[Any, ...], torch.Tensor], **kwargs ): _check_float16_quantizers(model) with tempfile.TemporaryDirectory() as tmp_dir: tmp_onnx_path = os.path.join(tmp_dir, "quantized_model.onnx") _onnx.export(model, args, tmp_onnx_path, **kwargs) onnx_model = onnx.load(tmp_onnx_path) param_names = { f"{layer_name}.{param_name}" for layer_name, layer in model.named_modules() if isinstance(layer, QuantizationMixin) for param_name, quantizer in layer.param_quantizers.items() if quantizer } tensor_to_encoding_map: Mapping[str, Tuple[EncodingBase, bool]] tensor_to_encoding_map = { name: (encoding, name in param_names) for name, encoding in _onnx.remove_quantization_nodes_from_onnx_graph( onnx_model ).items() } tensor_to_encoding_map |= _derive_data_movement_op_output_encoding( onnx_model, tensor_to_encoding_map ) return onnx_model, tensor_to_encoding_map def _derive_data_movement_op_output_encoding( model: onnx.ModelProto, tensor_to_encoding_map: Mapping[str, Tuple[EncodingBase, bool]], ) -> Mapping[str, Tuple[EncodingBase, bool]]: data_movement_ops = [ node for node in model.graph.node if _is_grid_preserving_op(node.op_type) ] output_encodings = {} for node in data_movement_ops: input_name = node.input[0] output_name = node.output[0] inp_encoding, _ = tensor_to_encoding_map.get(input_name, (None, None)) if not inp_encoding: inp_encoding, _ = output_encodings.get(input_name, (None, None)) if not inp_encoding: # No input encoding to inherit; skip continue if output_name in tensor_to_encoding_map: # Output encoding already exists; skip continue output_encodings[output_name] = (copy.deepcopy(inp_encoding), False) return output_encodings @contextlib.contextmanager def _concretize_int32_bias_quantizers(model, args, kwargs=None): if not isinstance(args, (tuple, list)): args = (args,) kwargs = kwargs or {} handles = [] orig_bias_quantizers = { qmodule: qmodule.param_quantizers["bias"] for qmodule in model.modules() if isinstance(qmodule, QuantizationMixin) and "bias" in qmodule.param_quantizers and qmodule.bias is not None } try: for qmodule, qtzr in orig_bias_quantizers.items(): if qtzr is not None: # Bias quantizer already exists. # This means the user created bias quantizer by him/herself # In this case, we honor the custom bias quantizer defined by the user continue if "weight" in qmodule.param_quantizers and isinstance( qmodule.param_quantizers["weight"], AffineQuantizerBase ): # pylint: disable=protected-access handle = qmodule.register_forward_hook( type(qmodule)._create_int32_bias_quantizer ) handles.append(handle) try: model(*args, **kwargs) finally: for handle in handles: handle.remove() yield finally: for qmodule, qtzr in orig_bias_quantizers.items(): qmodule.param_quantizers["bias"] = qtzr @contextlib.contextmanager def _temporarily_unfold_param_quantizers(model: torch.nn.Module): # pylint: disable=protected-access """ Temporarily re-instantiate param quantizers for ease of export """ modules_with_folded_parameters = [ qmodule for qmodule in model.modules() if isinstance(qmodule, QuantizationMixin) and any(isinstance(param, DequantizedTensor) for param in qmodule.parameters()) ] try: for qmodule in modules_with_folded_parameters: qmodule._unfold_param_quantizers() yield finally: for qmodule in modules_with_folded_parameters: qmodule._fold_param_quantizers() @contextlib.contextmanager def _remove_fp16_quantizers(model: torch.nn.Module): """ Temporarily remove [b]float16 quantizers for sim.onnx.export, as sim.onnx.export does NOT support exporting [b]float16 quantizers. """ original_containers = {} try: for qmodule in model.modules(): if not isinstance(qmodule, QuantizationMixin): continue for name, qtzr in qmodule.param_quantizers.items(): if isinstance(qtzr, FloatQuantizeDequantize) and ( qtzr.is_float16() or qtzr.is_bfloat16() ): original_containers[(qmodule.param_quantizers, name)] = qtzr qmodule.param_quantizers[name] = None for i, qtzr in enumerate(qmodule.input_quantizers): if isinstance(qtzr, FloatQuantizeDequantize) and ( qtzr.is_float16() or qtzr.is_bfloat16() ): original_containers[(qmodule.input_quantizers, i)] = qtzr qmodule.input_quantizers[i] = None for i, qtzr in enumerate(qmodule.output_quantizers): if isinstance(qtzr, FloatQuantizeDequantize) and ( qtzr.is_float16() or qtzr.is_bfloat16() ): original_containers[(qmodule.output_quantizers, i)] = qtzr qmodule.output_quantizers[i] = None yield finally: for (container, key), qtzr in original_containers.items(): container[key] = qtzr def _to_onnx_qdq( onnx_model: onnx.ModelProto, tensor_to_encoding_map: Mapping[str, Tuple[EncodingBase, bool]], ) -> onnx.ModelProto: qnn_encodings = { name: encoding.to_qnn_encoding_dict("2.0.0") for name, (encoding, _) in tensor_to_encoding_map.items() } qnn_encodings = { name: encoding for name, encoding in qnn_encodings.items() if encoding } qdq_tensor_names = { fp_tensor_name: f"{fp_tensor_name}_qdq" for fp_tensor_name in qnn_encodings } onnx_opset_version = next( opset.version for opset in onnx_model.opset_import if opset.domain == "" ) # Add onnx QDQ nodes in batch _add_onnx_qdq_nodes( onnx_model, input_names=qnn_encodings.keys(), output_names=qdq_tensor_names.values(), node_name_prefixes=qnn_encodings.keys(), encodings=qnn_encodings.values(), onnx_opset=onnx_opset_version, ) # Restore model output names from "{output}_qdq" to "{output}" _restore_model_output_names(onnx_model, qdq_tensor_names) return onnx_model def _check_float16_quantizers(module: torch.nn.Module): for qtzr in module.modules(): if isinstance(qtzr, FloatQuantizeDequantize): if not qtzr.is_float16() and not qtzr.is_bfloat16(): msg = " ".join( [ "sim.onnx.export doesn't support exporting floating point encodings", f"except [b]float16. Got {qtzr.bitwidth}-bit float encoding", ] ) raise RuntimeError(msg) def _rename_inputs(onnx_model: onnx.ModelProto, new_names: Mapping[str, str]): for node in onnx_model.graph.node: for i, old_name in enumerate(node.input): new_name = new_names.get(old_name, None) if new_name is not None: node.input[i] = new_name def _rename_outputs(onnx_model: onnx.ModelProto, new_names: Mapping[str, str]): for node in onnx_model.graph.node: for i, old_name in enumerate(node.output): new_name = new_names.get(old_name, None) if new_name is not None: node.output[i] = new_name def _restore_model_output_names( onnx_model: onnx.ModelProto, new_names: Mapping[str, str] ): """ Rename model outputs. Assuming "output" is the model output, before: Softmax ----> output -------> QDQ -------> output_qdq after: Softmax ----> output__ -----> QDQ -------> output """ _new_names = { output.name: f"{output.name}__" for output in onnx_model.graph.output if output.name in new_names } _rename_inputs(onnx_model, _new_names) _new_names.update( { new_names[output.name]: output.name for output in onnx_model.graph.output if output.name in new_names } ) _rename_outputs(onnx_model, _new_names)