# Source code for aimet_torch.onnx

# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause


"""Defines onnx export API"""

import contextlib
import io
import itertools
from packaging import version
import traceback
from typing import Any, Mapping, Tuple, Union, Literal
from pathlib import Path

import numpy as np
import onnx
import torch
from torch.onnx import _constants

import aimet_torch
from aimet_torch.v2.utils import patch_attr
from aimet_torch.common.onnx._utils import (
    _add_onnx_qdq_nodes,
    _convert_version,
    _derive_data_movement_op_encodings,
    _derive_const_rescale_op_output_encodings,
    contains_tensor_type,
)

from .nn import QuantizationMixin
from .quantization import DequantizedTensor
from .quantization.base import EncodingBase
from .quantization.affine import AffineQuantizerBase, AffineEncoding
from .quantization.float import FloatQuantizeDequantize, FloatEncoding
from .v2.quantization.float._finfo import _finfo
from .quantsim import QuantizationSimModel
from .v2.experimental import onnx as _onnx
from .v2.experimental.onnx._export import _get_all_constants


_TORCH_VERSION = version.parse(torch.__version__)
_TORCH_DEFAULT_OPSET = _constants.ONNX_DEFAULT_OPSET
_TORCH_MIN_OPSET = _constants.ONNX_MIN_OPSET
_TORCH_MAX_OPSET = _constants.ONNX_MAX_OPSET

# Accept target opsets up to at least 25, even if torch.onnx cannot emit them
# directly: opset >= 21 is required for [u]int4/[u]int16 and blockwise QDQ export.
# Targets above _TORCH_MAX_OPSET are exported at _TORCH_MAX_OPSET and then
# up-converted with the onnx version converter.
_AIMET_MAX_OPSET = max(_TORCH_MAX_OPSET, 25)


@torch.no_grad()
def export(
    model: Union[torch.nn.Module, QuantizationSimModel],
    args: Union[Tuple[Any, ...], torch.Tensor],
    f: Union[str, io.BytesIO],
    *,
    export_int32_bias: bool = True,
    prequantize_constants: bool = False,
    force_activation_as: Literal["unsigned"] | Literal["signed"] | None = "unsigned",
    **kwargs,
):
    """
    Export :class:`QuantizationSimModel` to onnx model with
    onnx `QuantizeLinear`_ and `DequantizeLinear`_ embedded in the graph.

    This function takes the same set of arguments as `torch.onnx.export()`_

    Args:
        model: The model to be exported
        args: Same as `torch.onnx.export()`
        f: Same as `torch.onnx.export()`
        export_int32_bias (bool, optional):
            If true, generate and export int32 bias encoding on the fly (default: `True`).
        prequantize_constants (bool, optional):
            If True, quantized weights will be represented as quantized weight followed by DequantizeLinear.
            If False, they will be represented as floating point weight followed by QuantizeLinear
            and DequantizeLinear (default: `False`).
        force_activation_as (str, optional):
            Force representing quantized activations as signed or unsigned integers
            (default: `"unsigned"`). If None, activation quantizers keep their own signedness.
        **kwargs: Same as `torch.onnx.export()`

    Raises:
        RuntimeError: If ``model`` is not a torch.nn.Module/QuantizationSimModel, if the
            requested opset cannot express the model's quantizers, or if converting the
            exported model to the requested opset fails.
        ValueError: If the opset version or ``force_activation_as`` is not supported.

    .. note::
        For robustness, onnx >=1.19 is highly recommended with this API,
        especially when exporting large models (>2GB).
        This is due to a known bug in onnx <1.19 version converter.
        For more information, see https://github.com/onnx/onnx/issues/6529

    .. _torch.onnx.export(): https://docs.pytorch.org/docs/stable/onnx_torchscript.html#torch.onnx.export
    .. _QuantizeLinear: https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html
    .. _DequantizeLinear: https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html

    Examples:

        >>> aimet_torch.onnx.export(sim.model, x, f="model.onnx",
        ...                         input_names=["input"], output_names=["output"],
        ...                         opset_version=21, dynamo=False,
        ...                         export_int32_bias=True)
        ...
        >>> import onnxruntime as ort
        >>> options = ort.SessionOptions()
        >>> options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        >>> sess = ort.InferenceSession("model.onnx", sess_options=options)
        >>> onnx_output, = sess.run(None, {"input": x.detach().numpy()})
        >>> torch.nn.functional.cosine_similarity(torch.from_numpy(onnx_output), sim.model(x))
        tensor([1.0000, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
               grad_fn=<AliasBackward0>)

        .. image:: ../../images/conv_qdq.onnx.svg
            :align: center
    """
    if isinstance(model, QuantizationSimModel):
        model = model.model

    if not isinstance(model, torch.nn.Module):
        raise RuntimeError(
            f"aimet_torch.export only supports torch.nn.Module or QuantizationSimModel; got {type(model)}"
        )

    base_dir = str(Path(str(f)).absolute().parent)
    Path(base_dir).mkdir(parents=True, exist_ok=True)

    _check_opset_version(kwargs)
    _check_unsupported_args(model, force_activation_as, kwargs)
    _check_non_standard_quantizer(model)

    target_version = kwargs.pop("opset_version", _TORCH_DEFAULT_OPSET)
    # torch.onnx.export cannot emit opsets beyond _TORCH_MAX_OPSET; export at the
    # highest opset torch supports and up-convert the onnx model below if needed.
    kwargs["opset_version"] = min(target_version, _TORCH_MAX_OPSET)
    _assert_minimum_required_opset(model, target_version)

    with contextlib.ExitStack() as stack:
        # Unfold all param quantizers to incorporate QuantizeLinear/DequantizeLinear
        # of those parameters in tracing time
        stack.enter_context(_temporarily_unfold_param_quantizers(model))

        if export_int32_bias:
            # Temporarily instantiate int32 bias quantizers
            stack.enter_context(
                _concretize_int32_bias_quantizers(model, args, kwargs.get("kwargs"))
            )

        if force_activation_as in ("unsigned", "signed"):
            # Temporarily flip activation quantizers whose signedness doesn't match
            signed = force_activation_as == "signed"
            stack.enter_context(
                _temporarily_convert_activation_to(model, signed=signed)
            )

        # Export quantize-dequantized weight
        # pylint: disable=protected-access
        stack.enter_context(QuantizationSimModel._apply_qdq_to_model_parameters(model))

        # Remove [b]float16 quantizers
        stack.enter_context(_remove_fp16_quantizers(model))
        stack.enter_context(_remove_fp16_quantized_parameters(model))

        onnx_model, tensor_to_encoding_map = _to_onnx(model, args, f, **kwargs)

    if _TORCH_MAX_OPSET < target_version:
        try:
            onnx_model = _convert_version(onnx_model, target_version)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Capture the traceback in a dedicated buffer; reusing the name ``f``
            # here would shadow the export target.
            tb_buffer = io.StringIO()
            traceback.print_exc(file=tb_buffer)
            reason = _why_do_i_need_opset21(model)
            if reason:
                detail = (
                    f"torch.onnx.export only supports opset<={_TORCH_MAX_OPSET}, "
                    f"but onnx::QuantizeLinear requires opset>={target_version} for {reason}. "
                    "As a workaround, we tried to torch.onnx.export your model "
                    f"with opset={_TORCH_MAX_OPSET} and convert the onnx model to {target_version}, "
                    "but failed with the following error:\n\n"
                )
            else:
                detail = "\n\n"
            msg = (
                f"Failed to convert onnx model to {target_version} due to {type(e).__name__}. {detail}"
                "==============================================================\n"
                f"{tb_buffer.getvalue()}"
                "==============================================================\n\n"
            )
            raise RuntimeError(msg) from e

    onnx_qdq_model = _to_onnx_qdq(
        onnx_model,
        tensor_to_encoding_map,
        prequantize_constants=prequantize_constants,
        base_dir=base_dir,
    )
    _remove_intermediate_identity_nodes(onnx_qdq_model)
    _remove_dangling_nodes_and_initializers(onnx_qdq_model)

    # Add metadata property to indicate the model is exported by AIMET and its version
    prop = onnx_qdq_model.metadata_props.add()
    prop.key = "producer"
    prop.value = f"aimet-torch {aimet_torch.__version__}"

    onnx.save(onnx_qdq_model, f)
def _why_do_i_need_opset21(model: torch.nn.Module) -> str:
    """
    Explain which quantization features of ``model`` need opset >= 21.

    Only used to build a readable error message when onnx version conversion fails.

    Returns:
        "int4/int16 quantization", "blockwise quantization", both joined with
        " and ", or "" if none of these features is present.
    """
    int4 = False
    int16 = False
    bq = False

    for qtzr in model.modules():
        if not isinstance(qtzr, AffineQuantizerBase):
            continue
        if qtzr.block_size is not None:
            bq = True
        if qtzr.bitwidth == 4:
            int4 = True
        if qtzr.bitwidth == 16:
            int16 = True

    reasons = []
    if int4 or int16:
        reasons.append("int4/int16 quantization")
    if bq:
        reasons.append("blockwise quantization")

    if not reasons:
        # Conversion is attempted for any target opset above _TORCH_MAX_OPSET,
        # so it may be needed for reasons other than the ones above.
        return ""

    if len(reasons) == 1:
        return reasons[0]

    return f"{reasons[0]} and {reasons[1]}"


def _assert_minimum_required_opset(model: torch.nn.Module, target_opset: int):
    """
    Raise RuntimeError if ``target_opset`` cannot express the QDQ features used by ``model``.

    Requirements checked, in order:
      * per-block quantization -> opset >= 21
      * INT4/INT16 quantization -> opset >= 21
      * quantizers with a non-empty shape (per-channel) -> opset >= 13
      * QuantizeLinear/DequantizeLinear at all -> opset >= 10
    """
    if target_opset < 21 and any(
        qtzr.block_size is not None
        for qtzr in model.modules()
        if isinstance(qtzr, AffineQuantizerBase)
    ):
        raise RuntimeError(
            "onnx::QuantizeLinear and DequantizeLinear with per-block are only supported in opset >= 21;"
            f" got opset={target_opset}"
        )

    if target_opset < 21 and any(
        qtzr.bitwidth in (4, 16)
        for qtzr in model.modules()
        if isinstance(qtzr, AffineQuantizerBase)
    ):
        raise RuntimeError(
            "onnx::QuantizeLinear and DequantizeLinear with INT4/INT16 are only supported in opset >= 21;"
            f" got opset={target_opset}"
        )

    if target_opset < 13 and any(
        tuple(qtzr.shape)
        for qtzr in model.modules()
        if isinstance(qtzr, AffineQuantizerBase)
    ):
        raise RuntimeError(
            "onnx::QuantizeLinear and DequantizeLinear with per-channel are only supported in opset >= 13;"
            f" got opset={target_opset}"
        )

    if target_opset < 10:
        raise RuntimeError(
            "onnx::QuantizeLinear and DequantizeLinear are only supported in opset >= 10;"
            f" got opset={target_opset}"
        )


def _check_opset_version(kwargs):
    """Raise ValueError if the requested opset is outside [_TORCH_MIN_OPSET, _AIMET_MAX_OPSET]."""
    opset_version = kwargs.get("opset_version", _TORCH_DEFAULT_OPSET)
    if not _TORCH_MIN_OPSET <= opset_version <= _AIMET_MAX_OPSET:
        raise ValueError(f"Unsupported ONNX opset version: {opset_version}")


def _check_unsupported_args(model, force_activation_as, kwargs):
    """
    Reject export arguments/configurations that AIMET export does not support.

    Raises:
        RuntimeError: torch.compile-d model on torch < 2.11, dynamo export on torch < 2.8,
            export_modules_as_functions=True, or ONNX_ATEN operator export type.
        ValueError: invalid ``force_activation_as``.
        NotImplementedError: export_params=False, keep_initializers_as_inputs=True,
            or do_constant_folding=False.
    """
    if _TORCH_VERSION >= version.parse("2.0.0") and isinstance(
        model,
        torch._dynamo.OptimizedModule,  # pylint: disable=protected-access
    ):
        if _TORCH_VERSION < version.parse("2.11.0.dev"):
            raise RuntimeError(
                "Exporting a torch.compile-d quantsim model is only supported in torch >= 2.11.0. "
                "For more information, see https://github.com/pytorch/pytorch/issues/171674"
            )

    if force_activation_as not in ("unsigned", "signed", None):
        raise ValueError(
            f"force_activation_as must be either 'unsigned', 'signed' or None; got {force_activation_as}"
        )

    # Mirror torch.onnx.export's default: dynamo export is the default from torch 2.9
    dynamo = kwargs.get("dynamo", _TORCH_VERSION >= version.parse("2.9.0"))
    if dynamo and _TORCH_VERSION < version.parse("2.8.0"):
        raise RuntimeError("AIMET dynamo export is only supported in torch >= 2.8.0")

    export_params = kwargs.get("export_params", True)
    if not export_params:
        raise NotImplementedError("export_params=False is not supported yet")

    keep_initializers_as_inputs = kwargs.get("keep_initializers_as_inputs", False)
    if keep_initializers_as_inputs:
        raise NotImplementedError(
            "keep_initializers_as_inputs=True is not supported yet"
        )

    do_constant_folding = kwargs.get("do_constant_folding", True)
    if not do_constant_folding:
        raise NotImplementedError("do_constant_folding=False is not supported yet")

    export_modules_as_functions = kwargs.get("export_modules_as_functions", False)
    if export_modules_as_functions:
        raise RuntimeError("export_modules_as_functions=True is not supported")

    operator_export_type = kwargs.get(
        "operator_export_type", torch.onnx.OperatorExportTypes.ONNX
    )
    if operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN:
        raise RuntimeError(
            "operator_export_type=OperatorExportTypes.ONNX_ATEN is not supported"
        )


def _check_non_standard_quantizer(model: torch.nn.Module):
    """Raise RuntimeError if any affine quantizer uses a bitwidth QDQ export can't represent."""
    supported_bitwidths = (2, 4, 8, 16, 32)

    for name, qtzr in model.named_modules():
        if not isinstance(qtzr, AffineQuantizerBase):
            continue

        if qtzr.bitwidth not in supported_bitwidths:
            supported_bitwidths_str = "/".join(str(b) for b in supported_bitwidths)
            raise RuntimeError(
                f"torch.onnx.export only supports {supported_bitwidths_str}-bit integers; "
                f"got '{name}' with bitwidth={qtzr.bitwidth}"
            )


# (domain, op_type) pairs of the aimet custom quantize-dequantize nodes emitted by _onnx.export
_AIMET_QDQ_OP_TYPES = frozenset(
    {
        ("aimet", "quantize_dequantize"),
        ("aimet", "QuantizeDequantize"),
        ("aimet", "FloatQuantizeDequantize"),
    }
)


def _is_aimet_qdq(node: onnx.NodeProto) -> bool:
    """Return True if ``node`` is one of the aimet custom quantize-dequantize nodes."""
    return (node.domain, node.op_type) in _AIMET_QDQ_OP_TYPES


def _duplicate_shared_qdq_inputs(
    onnx_model: onnx.ModelProto, base_dir: str | None
) -> dict[str, str]:
    """
    Duplicate input tensors associated with multiple QDQ nodes to avoid name collision

    For example,

    Before: One tensor associated with multiple encodings

        "conv1.weight" -+--------------> QDQ -> Conv
                        |
                        +--------------> QDQ -> Conv

    After:
      Case 1. Insert Identity nodes for each QDQ if QDQ nodes have different encodings

        "conv1.weight" -+--------------> QDQ -> Conv
                        |
                        +-> Identity --> QDQ -> Conv
                          "conv1.weight_dup_0"

      Case 2. Otherwise, consolidate multiple QDQs into single QDQ

        "conv1.weight" -> QDQ -+--------------> Conv
                               |
                               +--------------> Conv

    The model is modified in place.

    Returns:
        Mapping from each newly created duplicate tensor name to the original tensor name.
    """
    consumers: dict[str, list[onnx.NodeProto]] = {}
    # Resolves the output of a (chain of) existing Identity node(s) to its root tensor
    identity_roots: dict[str, str] = {}
    qdq_nodes_with_shared_input: dict[str, list[onnx.NodeProto]] = {}

    for node in onnx_model.graph.node:
        for inp in node.input:
            consumers.setdefault(inp, []).append(node)

    for identity in onnx_model.graph.node:
        if identity.op_type == "Identity":
            output = identity.output[0]
            inp = identity_roots.get(identity.input[0], identity.input[0])
            identity_roots[output] = inp

    for node in onnx_model.graph.node:
        if not _is_aimet_qdq(node):
            continue
        input_name = identity_roots.get(node.input[0], node.input[0])
        qdq_nodes_with_shared_input.setdefault(input_name, []).append(node)

    aliases: dict[str, str] = {}
    producers: dict[str, onnx.NodeProto] = {}
    constants = _get_all_constants(onnx_model, consumers)

    for input_name, qdq_nodes in qdq_nodes_with_shared_input.items():
        if len(qdq_nodes) <= 1:
            continue

        first_qdq = qdq_nodes[0]
        for i, qdq_node in enumerate(qdq_nodes[1:]):
            if _encoding_equal(qdq_node, first_qdq, constants, base_dir):
                # If multiple QDQ nodes have the same encoding,
                # we can reuse the same QDQ's output for multiple consumers.
                for consumer in consumers.get(qdq_node.output[0], []):
                    for j, inp in enumerate(consumer.input):
                        if inp == qdq_node.output[0]:
                            consumer.input[j] = first_qdq.output[0]
            else:
                # Duplicate QDQ nodes for each usage
                # Create an Identity node to duplicate the tensor
                alias = f"{input_name}_dup_{i}"
                aliases[alias] = input_name
                identity_node = onnx.helper.make_node(
                    "Identity",
                    inputs=[input_name],
                    outputs=[alias],
                    name=f"Identity_{input_name}_dup_{i}",
                )
                qdq_node.input[0] = identity_node.output[0]
                producers[identity_node.output[0]] = identity_node

    # Insert Identity nodes to the graph in topological order
    # (each one right before the QDQ node that consumes it)
    all_nodes = []
    for node in onnx_model.graph.node:
        if node.input and node.input[0] in producers:
            all_nodes.append(producers[node.input[0]])
        all_nodes.append(node)

    onnx_model.graph.ClearField("node")
    onnx_model.graph.node.extend(all_nodes)

    return aliases


def _remove_dangling_nodes_and_initializers(onnx_model: onnx.ModelProto):
    """
    Remove nodes and initializers that are not connected to any output.

    Walks backwards from the graph outputs; a node is kept if any of its outputs is
    reachable, an initializer if its name is reachable. The model is modified in place.
    """
    producers: dict[str, onnx.NodeProto] = {
        out: node for node in onnx_model.graph.node for out in node.output
    }
    graph_outputs = {output.name for output in onnx_model.graph.output}

    queue = list(graph_outputs)
    reachable = set(queue)

    while queue:
        tensor_name = queue.pop()
        producer = producers.get(tensor_name)
        if producer is None:
            continue
        for inp in producer.input:
            if inp not in reachable:
                reachable.add(inp)
                queue.append(inp)

    all_nodes = [
        node
        for node in onnx_model.graph.node
        if any(output in reachable for output in node.output)
    ]
    all_initializers = [
        initializer
        for initializer in onnx_model.graph.initializer
        if initializer.name in reachable
    ]

    onnx_model.graph.ClearField("node")
    onnx_model.graph.ClearField("initializer")
    onnx_model.graph.node.extend(all_nodes)
    onnx_model.graph.initializer.extend(all_initializers)


def _decouple_back_to_back_qdqs(onnx_model: onnx.ModelProto) -> None:
    """
    Insert Identity node between back-to-back QDQ nodes to avoid name collision
    when one tensor is associated with multiple encodings.

    For example,

    Before: One tensor associated with multiple encodings

        "conv1.weight" -> QDQ -------------> QDQ -> Conv
                         (float4)            (int8)

    After: Duplicate QDQ nodes for each usage

        "conv1.weight" -> QDQ -> Identity -> QDQ -> Conv
                         (float4)            (int8)

    The model is modified in place.
    """
    producers: dict[str, onnx.NodeProto] = {
        node.output[0]: node for node in onnx_model.graph.node
    }
    qdq_nodes = [node for node in onnx_model.graph.node if _is_aimet_qdq(node)]

    # Keyed by the tensor the Identity consumes (= output of the upstream QDQ)
    new_identity_nodes: dict[str, onnx.NodeProto] = {}

    for qdq_node in qdq_nodes:
        producer = producers.get(qdq_node.input[0])
        if producer is None or not _is_aimet_qdq(producer):
            continue

        identity_node = onnx.helper.make_node(
            "Identity",
            inputs=[qdq_node.input[0]],
            outputs=[f"{qdq_node.input[0]}_alias"],
            name=f"Identity_{qdq_node.input[0]}_alias",
        )
        qdq_node.input[0] = identity_node.output[0]
        new_identity_nodes[identity_node.input[0]] = identity_node

    # Insert Identity nodes to the graph in topological order
    # (each one right after the upstream QDQ node)
    all_nodes = []
    for node in onnx_model.graph.node:
        all_nodes.append(node)
        if node.output[0] in new_identity_nodes:
            all_nodes.append(new_identity_nodes[node.output[0]])

    onnx_model.graph.ClearField("node")
    onnx_model.graph.node.extend(all_nodes)


def _remove_intermediate_identity_nodes(onnx_model: onnx.ModelProto):
    """
    Bypass temporarily added Identity nodes between DequantizeLinear and QuantizeLinear nodes.

    Before:
        (weight) -> Q -> DQ -> Identity -> Q -> DQ -> MatMul

    After:
        (weight) -> Q -> DQ -------------> Q -> DQ -> MatMul

    Only Identity nodes fed by a DequantizeLinear and consumed by exactly one
    QuantizeLinear are bypassed. The bypassed Identity nodes are left in the graph
    (now dangling); run _remove_dangling_nodes_and_initializers afterwards to drop them.
    """
    producers: dict[str, onnx.NodeProto] = {}
    consumers: dict[str, list[onnx.NodeProto]] = {}

    for node in onnx_model.graph.node:
        producers[node.output[0]] = node
        for inp in node.input:
            consumers.setdefault(inp, []).append(node)

    for identity in onnx_model.graph.node:
        if identity.op_type != "Identity":
            continue

        producer = producers.get(identity.input[0])
        if producer is None or producer.op_type != "DequantizeLinear":
            continue

        identity_consumers = consumers.get(identity.output[0], [])
        if len(identity_consumers) != 1:
            continue

        (consumer,) = identity_consumers
        if consumer.op_type != "QuantizeLinear":
            continue

        for i, inp in enumerate(consumer.input):
            if inp == identity.output[0]:
                consumer.input[i] = identity.input[0]


def _encoding_equal(
    qdq_1: onnx.NodeProto,
    qdq_2: onnx.NodeProto,
    constants: dict[str, onnx.TensorProto],
    base_dir: str | None,
) -> bool:
    """
    Return True if two aimet QDQ nodes carry the same encoding.

    Compares the (dtype or qmin/qmax, block_size, zero_point_shift) attributes, the
    scale constants, and the offset constants (a missing or all-zero offset counts
    as "no offset"). Arrays are compared by shape AND value, so encodings whose
    scales merely broadcast against each other are not considered equal.

    Raises:
        ValueError: If either node is not an aimet QDQ node.
        RuntimeError: If a scale/offset input is not a known constant.
    """
    for i, qdq in enumerate([qdq_1, qdq_2]):
        if not _is_aimet_qdq(qdq):
            raise ValueError(
                f"Expected input {i} to be one of aimet::quantize_dequantize, "
                "aimet::QuantizeDequantize or aimet::FloatQuantizeDequantize node; "
                f"got {qdq.domain}::{qdq.op_type}"
            )

    def get_attributes(qdq_node: onnx.NodeProto):
        finfo = None
        qmin = None
        qmax = None
        block_size = None
        zero_point_shift = None

        for attr in qdq_node.attribute:
            if attr.name == "qmin":
                qmin = attr.i
            if attr.name == "qmax":
                qmax = attr.i
            if attr.name == "block_size":
                block_size = tuple(attr.ints) or None
            if attr.name == "zero_point_shift":
                zero_point_shift = attr.f
            if attr.name == "dtype":
                finfo = _finfo.from_onnx_dtype(attr.i)

        # Float QDQs are identified by their dtype, integer QDQs by (qmin, qmax)
        dtype = finfo if finfo else (qmin, qmax)
        return dtype, block_size, zero_point_shift

    def get_scale(qdq_node: onnx.NodeProto) -> np.ndarray:
        scale_proto = constants.get(qdq_node.input[1])
        if scale_proto is None:
            raise RuntimeError(
                f"Cannot find constant with name {qdq_node.input[1]} in onnx model"
            )
        return onnx.numpy_helper.to_array(scale_proto, base_dir=base_dir)

    def get_offset(qdq_node: onnx.NodeProto) -> np.ndarray | None:
        if len(qdq_node.input) < 3:
            return None
        offset_proto = constants.get(qdq_node.input[2])
        if offset_proto is None:
            raise RuntimeError(
                f"Cannot find constant with name {qdq_node.input[2]} in onnx model"
            )
        offset = onnx.numpy_helper.to_array(offset_proto, base_dir=base_dir)
        if np.all(offset == 0):
            offset = None
        return offset

    def offsets_equal(offset_1, offset_2) -> bool:
        if offset_1 is None or offset_2 is None:
            return offset_1 is None and offset_2 is None
        return bool(np.array_equal(offset_1, offset_2))

    # np.array_equal checks shapes first; ``np.all(a == b)`` would broadcast
    # (or raise on non-broadcastable shapes in recent NumPy).
    return bool(
        get_attributes(qdq_1) == get_attributes(qdq_2)
        and np.array_equal(get_scale(qdq_1), get_scale(qdq_2))
        and offsets_equal(get_offset(qdq_1), get_offset(qdq_2))
    )


def _remove_redundant_qdqs(onnx_model: onnx.ModelProto, base_dir):
    """
    Bypass aimet QDQ nodes that directly follow another aimet QDQ node with an identical encoding.

    Consumers of the redundant QDQ are rewired to its input (the previous QDQ's
    output); the redundant node is left dangling for later cleanup.
    """
    producers: dict[str, onnx.NodeProto] = {}
    consumers: dict[str, list[onnx.NodeProto]] = {}

    for node in onnx_model.graph.node:
        producers[node.output[0]] = node
        for inp in node.input:
            consumers.setdefault(inp, []).append(node)

    constants: dict[str, onnx.TensorProto] = _get_all_constants(onnx_model, consumers)

    qdq_nodes = [node for node in onnx_model.graph.node if _is_aimet_qdq(node)]

    for qdq_node in qdq_nodes:
        prev_qdq = producers.get(qdq_node.input[0])
        if prev_qdq is None or not _is_aimet_qdq(prev_qdq):
            continue

        if not _encoding_equal(qdq_node, prev_qdq, constants, base_dir):
            continue

        # Redirect consumers of redundant QDQ nodes to the previous QDQ nodes
        for consumer in consumers.get(qdq_node.output[0], []):
            for i, inp in enumerate(consumer.input):
                if inp == qdq_node.output[0]:
                    consumer.input[i] = qdq_node.input[0]


def _to_onnx(
    model: torch.nn.Module,
    args: Union[Tuple[Any, ...], torch.Tensor],
    f: Union[str, io.BytesIO],
    **kwargs,
) -> Tuple[onnx.ModelProto, dict]:
    """
    Export ``model`` with aimet QDQ nodes, then strip them into an encoding map.

    Returns:
        (onnx_model, tensor_to_encoding_map) where tensor_to_encoding_map maps a tensor
        name to (encoding, is_param). is_param is True for quantized parameters
        (or duplicates of them); derived encodings are always marked False.
    """
    base_dir = str(Path(str(f)).absolute().parent)
    _onnx.export(model, args, f, **kwargs)
    # NOTE(review): for io.BytesIO targets this relies on _onnx.export leaving the
    # stream positioned so that onnx.load reads the exported model — confirm.
    onnx_model = onnx.load(f, load_external_data=False)

    aliases = _duplicate_shared_qdq_inputs(onnx_model, base_dir)
    _remove_redundant_qdqs(onnx_model, base_dir)
    _decouple_back_to_back_qdqs(onnx_model)
    _remove_dangling_nodes_and_initializers(onnx_model)

    param_names = {
        f"{layer_name}.{param_name}"
        for layer_name, layer in model.named_modules()
        if isinstance(layer, QuantizationMixin)
        for param_name, quantizer in layer.param_quantizers.items()
        if quantizer
    }

    tensor_to_encoding_map: dict[str, Tuple[EncodingBase, bool]] = {
        name: (encoding, name in param_names or aliases.get(name) in param_names)
        for name, encoding in _onnx.remove_quantization_nodes_from_onnx_graph(
            onnx_model,
            base_dir=base_dir,
        ).items()
    }
    encoding_dict = {
        name: enc.to_qnn_encoding_dict("2.0.0")
        for name, (enc, _) in tensor_to_encoding_map.items()
    }
    derived_encodings = _derive_const_rescale_op_output_encodings(
        onnx_model, encoding_dict, base_dir
    )
    derived_encodings |= _derive_data_movement_op_encodings(
        onnx_model, encoding_dict | derived_encodings
    )
    # pylint: disable=protected-access
    tensor_to_encoding_map |= {
        name: (AffineEncoding._from_qnn_encoding_dict(encoding, version="2.0.0"), False)
        for name, encoding in derived_encodings.items()
    }
    return onnx_model, tensor_to_encoding_map


@contextlib.contextmanager
def _concretize_int32_bias_quantizers(model, args, kwargs=None):
    """
    Temporarily create int32 bias quantizers for modules that have a bias but no bias quantizer.

    For every QuantizationMixin with a non-None ``bias``, a ``"bias"`` param-quantizer
    slot that is None, and an AffineQuantizerBase weight quantizer, a forward hook
    (``_create_int32_bias_quantizer``) is registered and one forward pass is run with
    ``args``/``kwargs``. User-defined bias quantizers are left untouched. On exit, every
    module's original bias quantizer (possibly None) is restored.
    """
    if not isinstance(args, (tuple, list)):
        args = (args,)
    kwargs = kwargs or {}

    handles = []
    orig_bias_quantizers = {
        qmodule: qmodule.param_quantizers["bias"]
        for qmodule in model.modules()
        if isinstance(qmodule, QuantizationMixin)
        and "bias" in qmodule.param_quantizers
        and qmodule.bias is not None
    }

    try:
        try:
            for qmodule, qtzr in orig_bias_quantizers.items():
                if qtzr is not None:
                    # Bias quantizer already exists.
                    # This means the user created bias quantizer by him/herself
                    # In this case, we honor the custom bias quantizer defined by the user
                    continue

                if "weight" in qmodule.param_quantizers and isinstance(
                    qmodule.param_quantizers["weight"], AffineQuantizerBase
                ):
                    # pylint: disable=protected-access
                    handle = qmodule.register_forward_hook(
                        type(qmodule)._create_int32_bias_quantizer
                    )
                    handles.append(handle)

            with contextlib.ExitStack() as stack:
                for qmodule in model.modules():
                    if not isinstance(qmodule, QuantizationMixin):
                        continue
                    # bias_scale will be derived as input_scale * weight_scale.
                    # Here, weight_scale is statically available in the weight quantizer,
                    # and we don't need to perform actual weight Q/DQ to capture weight_scale.
                    # Therefore, we temporarily set weight quantizers to "pass-through" mode
                    # to speed up export.
                    for qtzr in qmodule.param_quantizers.children():
                        stack.enter_context(patch_attr(qtzr, "forward", lambda _: _))
                model(*args, **kwargs)
        finally:
            # Also runs if hook registration fails partway, so no hook is leaked
            for handle in handles:
                handle.remove()

        yield
    finally:
        for qmodule, qtzr in orig_bias_quantizers.items():
            qmodule.param_quantizers["bias"] = qtzr


@contextlib.contextmanager
def _temporarily_unfold_param_quantizers(model: torch.nn.Module):
    # pylint: disable=protected-access
    """
    Temporarily re-instantiate param quantizers for ease of export
    """
    with contextlib.ExitStack() as stack:
        for qmodule in model.modules():
            if isinstance(qmodule, QuantizationMixin):
                stack.enter_context(qmodule._unfold_param_quantizers())
        yield


@contextlib.contextmanager
def _remove_fp16_quantized_parameters(model: torch.nn.Module):
    """
    Temporarily detach float encodings from pre-quantized parameters.

    Parameters that are DequantizedTensors with a FloatEncoding whose dtype is
    float16, bfloat16, float32 or float64 are temporarily replaced with plain
    torch.nn.Parameters carrying the same values.
    """
    # pylint: disable=protected-access
    with contextlib.ExitStack() as stack:
        for qmodule in model.modules():
            if not isinstance(qmodule, QuantizationMixin):
                continue

            for name, param in qmodule.named_parameters(recurse=False):
                if (
                    isinstance(param, DequantizedTensor)
                    and isinstance(param.encoding, FloatEncoding)
                    and param.encoding._finfo.to_torch_dtype()
                    in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
                ):
                    param = torch.nn.Parameter(param.as_subclass(torch.Tensor))
                    stack.enter_context(patch_attr(qmodule, name, param))

        yield


def _is_fp16_quantizer(qtzr) -> bool:
    """Return True if ``qtzr`` is a float16 or bfloat16 FloatQuantizeDequantize."""
    return isinstance(qtzr, FloatQuantizeDequantize) and (
        qtzr.is_float16() or qtzr.is_bfloat16()
    )


@contextlib.contextmanager
def _remove_fp16_quantizers(model: torch.nn.Module):
    """
    Temporarily remove [b]float16 quantizers for sim.onnx.export,
    as sim.onnx.export does NOT support exporting [b]float16 quantizers.

    Param, input and output quantizer slots holding a [b]float16 quantizer are set
    to None and restored on exit.
    """
    original_containers = {}

    try:
        for qmodule in model.modules():
            if not isinstance(qmodule, QuantizationMixin):
                continue

            for name, qtzr in qmodule.param_quantizers.items():
                if _is_fp16_quantizer(qtzr):
                    original_containers[(qmodule.param_quantizers, name)] = qtzr
                    qmodule.param_quantizers[name] = None

            for container in (qmodule.input_quantizers, qmodule.output_quantizers):
                for i, qtzr in enumerate(container):
                    if _is_fp16_quantizer(qtzr):
                        original_containers[(container, i)] = qtzr
                        container[i] = None

        yield
    finally:
        for (container, key), qtzr in original_containers.items():
            container[key] = qtzr


def _to_onnx_qdq(
    onnx_model: onnx.ModelProto,
    tensor_to_encoding_map: Mapping[str, Tuple[EncodingBase, bool]],
    prequantize_constants: bool,
    base_dir: str,
) -> onnx.ModelProto:
    """
    Insert onnx QuantizeLinear/DequantizeLinear nodes for every encoded tensor.

    Each encoded tensor ``t`` gets a QDQ pair producing ``t_qdq``; model outputs are
    then renamed so the graph outputs keep their original names.
    The model is modified in place and returned.

    Raises:
        RuntimeError: If the model contains float16/bfloat16 tensors, or has no
            default-domain opset import.
    """
    qnn_encodings = {
        name: encoding.to_qnn_encoding_dict("2.0.0")
        for name, (encoding, _) in tensor_to_encoding_map.items()
    }
    # Skip tensors whose encoding serializes to an empty/None dict
    qnn_encodings = {
        name: encoding for name, encoding in qnn_encodings.items() if encoding
    }
    qdq_tensor_names = {
        fp_tensor_name: f"{fp_tensor_name}_qdq" for fp_tensor_name in qnn_encodings
    }

    # The default onnx operator set may be spelled either "" or "ai.onnx"
    onnx_opset_version = next(
        (
            opset.version
            for opset in onnx_model.opset_import
            if opset.domain in ("", "ai.onnx")
        ),
        None,
    )
    if onnx_opset_version is None:
        raise RuntimeError(
            "Cannot find the default ONNX opset ('ai.onnx') in the model's opset_import"
        )

    # TODO: Support exporting (b)float16 models
    if contains_tensor_type(
        onnx_model, (onnx.TensorProto.BFLOAT16, onnx.TensorProto.FLOAT16)
    ):
        raise RuntimeError(
            "Exporting to onnx QDQ is only supported for float32 models."
        )

    float_types = [np.float32] * len(qnn_encodings)

    # Add onnx QDQ nodes in batch
    _add_onnx_qdq_nodes(
        onnx_model,
        input_names=qnn_encodings.keys(),
        output_names=qdq_tensor_names.values(),
        node_name_prefixes=qnn_encodings.keys(),
        encodings=qnn_encodings.values(),
        float_types=float_types,
        onnx_opset=onnx_opset_version,
        prequantize_constants=prequantize_constants,
        base_dir=base_dir,
    )

    # Restore model output names from "{output}_qdq" to "{output}"
    _restore_model_output_names(onnx_model, qdq_tensor_names)

    return onnx_model


def _check_float16_quantizers(module: torch.nn.Module):
    """Raise RuntimeError for any FloatQuantizeDequantize that is not [b]float16."""
    for qtzr in module.modules():
        if isinstance(qtzr, FloatQuantizeDequantize) and not _is_fp16_quantizer(qtzr):
            msg = " ".join(
                [
                    "sim.onnx.export doesn't support exporting floating point encodings",
                    f"except [b]float16. Got {qtzr.bitwidth}-bit float encoding",
                ]
            )
            raise RuntimeError(msg)


def _restore_model_output_names(
    onnx_model: onnx.ModelProto, qdq_tensor_name_map: Mapping[str, str]
):
    """
    Rename model outputs.

    Assuming "output" is the model output,

    before:
        Softmax ----> output -------> QDQ -------> output_qdq

    after:
        Softmax ----> output__ -----> QDQ -------> output

    Args:
        onnx_model: onnx model to be modified in-place
        qdq_tensor_name_map: mapping from original tensor names to QDQ tensor names
    """
    _new_names = {
        output.name: f"{output.name}__"
        for output in onnx_model.graph.output
        if output.name in qdq_tensor_name_map
    }
    _new_names.update(
        {
            qdq_tensor_name_map[output.name]: output.name
            for output in onnx_model.graph.output
            if output.name in qdq_tensor_name_map
        }
    )

    # At this point, _new_names consists of:
    # {
    #     "output": "output__",
    #     "output_qdq": "output",
    # }
    #
    # Replacing all tensors accordingly will transform the graph as below:
    #
    # before:
    #     Softmax ----> output -------> QDQ -------> output_qdq
    # after:
    #     Softmax ----> output__ -----> QDQ -------> output

    for node in onnx_model.graph.node:
        for i, old_name in enumerate(node.input):
            new_name = _new_names.get(old_name, None)
            if new_name is not None:
                node.input[i] = new_name

        for i, old_name in enumerate(node.output):
            new_name = _new_names.get(old_name, None)
            if new_name is not None:
                node.output[i] = new_name


@torch.no_grad()
def _absorb_zero_point_shift(model: torch.nn.Module):
    """
    Absorb zero point shift to weights by promoting bitwidth from 2 to 4.

    Applies to initialized, symmetric affine param quantizers with
    zero_point_shift == 0.5: the weight is overwritten with its quantize-dequantized
    value, the bitwidth is doubled, the scale halved, and zero_point_shift reset to 0.

    NOTE: This function is only meant for internal testing purpose.
    """
    for qmodule in model.modules():
        if not isinstance(qmodule, QuantizationMixin):
            continue

        for param_name, qtzr in qmodule.param_quantizers.items():
            if not isinstance(qtzr, AffineQuantizerBase):
                continue
            if not qtzr.is_initialized():
                continue
            if qtzr.zero_point_shift != 0.5:
                continue
            if not qtzr.symmetric:
                continue

            weight = getattr(qmodule, param_name)
            weight_qdq = qtzr(weight).dequantize()
            weight.copy_(weight_qdq)

            # weight_qdq ∈ {-1.5 * s,  -0.5 * s,  0.5 * s,  1.5 * s }
            #            = {  -3 * s/2,  -1 * s/2,  1 * s/2,  3 * s/2}
            new_scale = qtzr.get_scale() / 2
            qtzr.bitwidth *= 2
            qtzr.zero_point_shift = 0.0
            new_min = new_scale * qtzr.qmin
            new_max = new_scale * qtzr.qmax
            qtzr.set_range(new_min, new_max)


@contextlib.contextmanager
def _temporarily_convert_activation_to(model: torch.nn.Module, signed: bool):
    """
    Temporarily set all affine activation quantizers' signedness to ``signed`` for onnx export.

    Only input/output quantizers of QuantizationMixin modules whose ``signed`` differs
    from the requested value are modified; they are flipped back on exit.

    TODO: This is a temporary workaround for QAIRT/HTP bug in handling signed activation encodings.
          Remove this once the bug is fixed.
          (https://github.qualcomm.com/qualcomm-ai/aimet/issues/5236)
    """
    target_activation_quantizers = [
        qtzr
        for module in model.modules()
        if isinstance(module, QuantizationMixin)
        for qtzr in itertools.chain(module.input_quantizers, module.output_quantizers)
        if isinstance(qtzr, AffineQuantizerBase) and qtzr.signed != signed
    ]

    try:
        for qtzr in target_activation_quantizers:
            qtzr.signed = signed
        yield
    finally:
        # Every target quantizer originally had the opposite signedness
        for qtzr in target_activation_quantizers:
            qtzr.signed = not signed