# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
"""Defines onnx export API"""
import contextlib
import io
import itertools
from aimet_torch.v2.utils import patch_attr
from packaging import version
import traceback
from typing import Any, Mapping, Tuple, Union, Literal
from pathlib import Path
import numpy as np
import onnx
import torch
from torch.onnx import _constants
from aimet_torch.common.onnx._utils import (
    _add_onnx_qdq_nodes,
    _convert_version,
    _derive_data_movement_op_encodings,
    contains_tensor_type,
)
from .nn import QuantizationMixin
from .quantization import DequantizedTensor
from .quantization.base import EncodingBase
from .quantization.affine import AffineQuantizerBase, AffineEncoding
from .quantization.float import FloatQuantizeDequantize, FloatEncoding
from .v2.quantization.float._finfo import _finfo
from .quantsim import QuantizationSimModel
from .v2.experimental import onnx as _onnx

_TORCH_VERSION = version.parse(torch.__version__)
_TORCH_DEFAULT_OPSET = _constants.ONNX_DEFAULT_OPSET
_TORCH_MIN_OPSET = _constants.ONNX_MIN_OPSET
_TORCH_MAX_OPSET = _constants.ONNX_MAX_OPSET
# Allow at least up to opset 21 to enable [u]int16 QDQ export
_AIMET_MAX_OPSET = max(_TORCH_MAX_OPSET, 23)


@torch.no_grad()
def export(
    model: Union[torch.nn.Module, QuantizationSimModel],
    args: Union[Tuple[Any, ...], torch.Tensor],
    f: Union[str, io.BytesIO],
    *,
    export_int32_bias: bool = True,
    prequantize_constants: bool = False,
    force_activation_as: Literal["unsigned", "signed"] | None = "unsigned",
    **kwargs,
):
    """
    Export :class:`QuantizationSimModel` to an onnx model with
    onnx `QuantizeLinear`_ and `DequantizeLinear`_ embedded in the graph.

    This function takes the same set of arguments as `torch.onnx.export()`_

    Args:
        model: The model to be exported
        args: Same as `torch.onnx.export()`
        f: Same as `torch.onnx.export()`
        export_int32_bias (bool, optional):
            If True, generate and export int32 bias encodings on the fly (default: `True`).
        prequantize_constants (bool):
            If True, quantized weights will be represented as quantized weights followed by DequantizeLinear.
            If False, they will be represented as floating point weights followed by
            QuantizeLinear and DequantizeLinear (default: `False`).
        force_activation_as (str):
            Force representing quantized activations as signed or unsigned integers (default: `"unsigned"`)
        **kwargs: Same as `torch.onnx.export()`

    .. note::
        For robustness, onnx >= 1.19 is highly recommended with this API,
        especially when exporting large models (>2GB).
        This is due to a known bug in the onnx < 1.19 version converter.
        For more information, see https://github.com/onnx/onnx/issues/6529

    .. _torch.onnx.export(): https://docs.pytorch.org/docs/stable/onnx_torchscript.html#torch.onnx.export
    .. _QuantizeLinear: https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html
    .. _DequantizeLinear: https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html

    Examples:

        >>> aimet_torch.onnx.export(sim.model, x, f="model.onnx",
        ...                         input_names=["input"], output_names=["output"],
        ...                         opset_version=21, dynamo=False,
        ...                         export_int32_bias=True)
        ...
        >>> import onnxruntime as ort
        >>> options = ort.SessionOptions()
        >>> options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        >>> sess = ort.InferenceSession("model.onnx", sess_options=options)
        >>> onnx_output, = sess.run(None, {"input": x.detach().numpy()})
        >>> torch.nn.functional.cosine_similarity(torch.from_numpy(onnx_output), sim.model(x))
        tensor([1.0000, 0.9999, 1.0000, ..., 1.0000, 1.0000, 1.0000],
               grad_fn=<AliasBackward0>)

    .. image:: ../../images/conv_qdq.onnx.svg
        :align: center
    """
    if isinstance(model, QuantizationSimModel):
        model = model.model

    if not isinstance(model, torch.nn.Module):
        raise RuntimeError(
            f"aimet_torch.export only supports torch.nn.Module or QuantizationSimModel; got {type(model)}"
        )

    base_dir = str(Path(str(f)).absolute().parent)

    _check_opset_version(kwargs)
    _check_unsupported_args(model, force_activation_as, kwargs)
    _check_non_standard_quantizer(model)

    target_version = kwargs.pop("opset_version", _TORCH_DEFAULT_OPSET)
    kwargs["opset_version"] = min(target_version, _TORCH_MAX_OPSET)
    _assert_minimum_required_opset(model, target_version)

    with contextlib.ExitStack() as stack:
        # Unfold all param quantizers to incorporate QuantizeLinear/DequantizeLinear
        # of those parameters at tracing time
        stack.enter_context(_temporarily_unfold_param_quantizers(model))

        if export_int32_bias:
            # Temporarily instantiate int32 bias quantizers
            stack.enter_context(
                _concretize_int32_bias_quantizers(model, args, kwargs.get("kwargs"))
            )

        # Export quantize-dequantized weight
        # pylint: disable=protected-access
        if force_activation_as in ("unsigned", "signed"):
            signed = force_activation_as == "signed"
            stack.enter_context(
                _temporarily_convert_activation_to(model, signed=signed)
            )
        stack.enter_context(QuantizationSimModel._apply_qdq_to_model_parameters(model))

        # Remove [b]float16 quantizers
        stack.enter_context(_remove_fp16_quantizers(model))
        stack.enter_context(_remove_fp16_quantized_parameters(model))

        onnx_model, tensor_to_encoding_map = _to_onnx(model, args, f, **kwargs)

        if _TORCH_MAX_OPSET < target_version:
            try:
                onnx_model = _convert_version(onnx_model, target_version)
            except Exception as e:  # pylint: disable=broad-exception-caught
                buf = io.StringIO()
                traceback.print_exc(file=buf)
                reason = _why_do_i_need_opset21(model)
                if reason:
                    detail = (
                        f"torch.onnx.export only supports opset<={_TORCH_MAX_OPSET}, "
                        f"but onnx::QuantizeLinear requires opset>={target_version} for {reason}. "
                        "As a workaround, we tried to torch.onnx.export your model "
                        f"with opset={_TORCH_MAX_OPSET} and convert the onnx model to {target_version}, "
                        "but failed with the following error:\n\n"
                    )
                else:
                    detail = "\n\n"

                msg = (
                    f"Failed to convert onnx model to {target_version} due to {type(e).__name__}. {detail}"
                    "==============================================================\n"
                    f"{buf.getvalue()}"
                    "==============================================================\n\n"
                )
                raise RuntimeError(msg) from e

    onnx_qdq_model = _to_onnx_qdq(
        onnx_model,
        tensor_to_encoding_map,
        prequantize_constants=prequantize_constants,
        base_dir=base_dir,
    )
    _remove_intermediate_identity_nodes(onnx_qdq_model)
    onnx.save(onnx_qdq_model, f)
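
# Illustrative sketch (not part of the API above): since the docstring recommends
# onnx >= 1.19 when exporting large (>2GB) models, a caller could check the
# installed onnx version up front before invoking aimet_torch.onnx.export:
#
#     >>> import onnx
#     >>> from packaging import version
#     >>> if version.parse(onnx.__version__) < version.parse("1.19"):
#     ...     print("Consider upgrading onnx before exporting models larger than 2GB")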


def _why_do_i_need_opset21(model: torch.nn.Module) -> str:
    int4 = False
    int16 = False
    bq = False

    for qtzr in model.modules():
        if not isinstance(qtzr, AffineQuantizerBase):
            continue
        if qtzr.block_size is not None:
            bq = True
        if qtzr.bitwidth == 4:
            int4 = True
        if qtzr.bitwidth == 16:
            int16 = True

    reasons = []
    if int4 or int16:
        reasons.append("int4/int16 quantization")
    if bq:
        reasons.append("blockwise quantization")

    if not reasons:
        return ""  # This should never happen

    if len(reasons) == 1:
        return reasons[0]

    return f"{reasons[0]} and {reasons[1]}"


def _assert_minimum_required_opset(model: torch.nn.Module, target_opset: int):
    if target_opset < 21 and any(
        qtzr.block_size is not None
        for qtzr in model.modules()
        if isinstance(qtzr, AffineQuantizerBase)
    ):
        raise RuntimeError(
            "onnx::QuantizeLinear and DequantizeLinear with per-block are only supported in opset >= 21;"
            f" got opset={target_opset}"
        )

    if target_opset < 21 and any(
        qtzr.bitwidth in (4, 16)
        for qtzr in model.modules()
        if isinstance(qtzr, AffineQuantizerBase)
    ):
        raise RuntimeError(
            "onnx::QuantizeLinear and DequantizeLinear with INT4/INT16 are only supported in opset >= 21;"
            f" got opset={target_opset}"
        )

    if target_opset < 13 and any(
        tuple(qtzr.shape)
        for qtzr in model.modules()
        if isinstance(qtzr, AffineQuantizerBase)
    ):
        raise RuntimeError(
            "onnx::QuantizeLinear and DequantizeLinear with per-channel are only supported in opset >= 13;"
            f" got opset={target_opset}"
        )

    if target_opset < 10:
        raise RuntimeError(
            "onnx::QuantizeLinear and DequantizeLinear are only supported in opset >= 10;"
            f" got opset={target_opset}"
        )


def _check_opset_version(kwargs):
    opset_version = kwargs.get("opset_version", _TORCH_DEFAULT_OPSET)

    if not (_TORCH_MIN_OPSET <= opset_version <= _AIMET_MAX_OPSET):
        raise ValueError(f"Unsupported ONNX opset version: {opset_version}")


def _check_unsupported_args(model, force_activation_as, kwargs):
    if _TORCH_VERSION >= version.parse("2.0.0") and isinstance(
        model,
        torch._dynamo.OptimizedModule,  # pylint: disable=protected-access
    ):
        if _TORCH_VERSION < version.parse("2.11.0.dev"):
            raise RuntimeError(
                "Exporting a torch.compile-d quantsim model is only supported in torch >= 2.11.0. "
                "For more information, see https://github.com/pytorch/pytorch/issues/171674"
            )

    if force_activation_as not in ("unsigned", "signed", None):
        raise ValueError(
            f"force_activation_as must be either 'unsigned', 'signed' or None; got {force_activation_as}"
        )

    dynamo = kwargs.get("dynamo", _TORCH_VERSION >= version.parse("2.9.0"))
    if dynamo and _TORCH_VERSION < version.parse("2.8.0"):
        raise RuntimeError("AIMET dynamo export is only supported in torch >= 2.8.0")

    export_params = kwargs.get("export_params", True)
    if not export_params:
        raise NotImplementedError("export_params=False is not supported yet")

    keep_initializers_as_inputs = kwargs.get("keep_initializers_as_inputs", False)
    if keep_initializers_as_inputs:
        raise NotImplementedError(
            "keep_initializers_as_inputs=True is not supported yet"
        )

    do_constant_folding = kwargs.get("do_constant_folding", True)
    if not do_constant_folding:
        raise NotImplementedError("do_constant_folding=False is not supported yet")

    export_modules_as_functions = kwargs.get("export_modules_as_functions", False)
    if export_modules_as_functions:
        raise RuntimeError("export_modules_as_functions=True is not supported")

    operator_export_type = kwargs.get(
        "operator_export_type", torch.onnx.OperatorExportTypes.ONNX
    )
    if operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN:
        raise RuntimeError(
            "operator_export_type=OperatorExportTypes.ONNX_ATEN is not supported"
        )


def _check_non_standard_quantizer(model: torch.nn.Module):
    for name, qtzr in model.named_modules():
        if not isinstance(qtzr, AffineQuantizerBase):
            continue

        if qtzr.bitwidth not in (4, 8, 16, 32):
            raise RuntimeError(
                "torch.onnx.export only supports 4/8/16/32-bit integers; "
                f"got '{name}' with bitwidth={qtzr.bitwidth}"
            )


def _duplicate_shared_qdq_inputs(onnx_model: onnx.ModelProto) -> dict[str, str]:
    """
    Duplicate input tensors associated with multiple QDQ nodes to avoid name collision

    For example,

    Before: One tensor associated with multiple encodings

        "conv1.weight" -+--------------> QDQ -> Conv
                        |
                        +--------------> QDQ -> Conv

    After: Duplicate QDQ nodes for each usage

            "conv1.weight_dup_0"
                    ↓
        "conv1.weight" -+-- Identity --> QDQ -> Conv
                        |
                        +-> Identity --> QDQ -> Conv
                                 ↑
                        "conv1.weight_dup_1"
    """
    consumers: dict[str, list[onnx.NodeProto]] = {}

    for node in onnx_model.graph.node:
        if f"{node.domain}::{node.op_type}" in (
            "aimet::quantize_dequantize",
            "aimet::QuantizeDequantize",
            "::QuantizeLinear",
            "::DequantizeLinear",
        ):
            consumers.setdefault(node.input[0], []).append(node)

    aliases: dict[str, str] = {}
    producers: dict[str, onnx.NodeProto] = {}

    for input_name, qdq_nodes in consumers.items():
        if len(qdq_nodes) <= 1:
            continue

        # Duplicate QDQ nodes for each usage
        for i, qdq_node in enumerate(qdq_nodes):
            # Create an Identity node to duplicate the tensor
            alias = f"{input_name}_dup_{i}"
            aliases[alias] = input_name
            identity_node = onnx.helper.make_node(
                "Identity",
                inputs=[input_name],
                outputs=[alias],
                name=f"Identity_{input_name}_dup_{i}",
            )
            qdq_node.input[0] = identity_node.output[0]
            producers[identity_node.output[0]] = identity_node

    # Insert Identity nodes to the graph in topological order
    all_nodes = []
    for node in onnx_model.graph.node:
        if node.input and node.input[0] in producers:
            identity_node = producers[node.input[0]]
            all_nodes.append(identity_node)
        all_nodes.append(node)

    onnx_model.graph.ClearField("node")
    onnx_model.graph.node.extend(all_nodes)

    return aliases
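
# Illustrative example (hypothetical tensor names): if "conv1.weight" feeds two
# QuantizeLinear nodes, the alias map returned above would look like
#     {"conv1.weight_dup_0": "conv1.weight", "conv1.weight_dup_1": "conv1.weight"}
# so downstream passes (e.g. _to_onnx) can map each duplicated input back to the
# original parameter name when deciding whether an encoding belongs to a parameter.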


def _decouple_back_to_back_qdqs(onnx_model: onnx.ModelProto):
    """
    Insert Identity node between back-to-back QDQ nodes to avoid name collision
    when one tensor is associated with multiple encodings.

    For example,

    Before: One tensor associated with multiple encodings

        "conv1.weight" -> QDQ -------------> QDQ -> Conv
                        (float4)           (int8)

    After: Identity inserted between the back-to-back QDQ nodes

        "conv1.weight" -> QDQ -> Identity -> QDQ -> Conv
                        (float4)           (int8)
    """
    producers: dict[str, onnx.NodeProto] = {
        node.output[0]: node for node in onnx_model.graph.node
    }
    qdq_nodes = [
        node
        for node in onnx_model.graph.node
        if f"{node.domain}::{node.op_type}"
        in (
            "aimet::quantize_dequantize",
            "aimet::QuantizeDequantize",
            "aimet::FloatQuantizeDequantize",
        )
    ]

    new_identity_nodes: dict[str, onnx.NodeProto] = {}

    for qdq_node in qdq_nodes:
        producer = producers.get(qdq_node.input[0], None)
        if not producer:
            continue

        if f"{producer.domain}::{producer.op_type}" not in (
            "aimet::quantize_dequantize",
            "aimet::QuantizeDequantize",
            "aimet::FloatQuantizeDequantize",
        ):
            continue

        identity_node = onnx.helper.make_node(
            "Identity",
            inputs=[qdq_node.input[0]],
            outputs=[f"{qdq_node.input[0]}_alias"],
            name=f"Identity_{qdq_node.input[0]}_alias",
        )
        qdq_node.input[0] = identity_node.output[0]
        new_identity_nodes[identity_node.input[0]] = identity_node

    # Insert Identity nodes to the graph in topological order
    all_nodes = []
    for node in onnx_model.graph.node:
        all_nodes.append(node)
        if node.output[0] in new_identity_nodes:
            identity_node = new_identity_nodes[node.output[0]]
            all_nodes.append(identity_node)

    onnx_model.graph.ClearField("node")
    onnx_model.graph.node.extend(all_nodes)


def _remove_intermediate_identity_nodes(onnx_model: onnx.ModelProto):
    """
    Remove temporarily added Identity nodes between DequantizeLinear and QuantizeLinear nodes.

    Before:
        (weight) -> Q -> DQ -> Identity -> Q -> DQ -> MatMul

    After:
        (weight) -> Q -> DQ -------------> Q -> DQ -> MatMul
    """
    producers: dict[str, onnx.NodeProto] = {}
    consumers: dict[str, list[onnx.NodeProto]] = {}

    for node in onnx_model.graph.node:
        producers[node.output[0]] = node
        for inp in node.input:
            consumers.setdefault(inp, []).append(node)

    removable: dict[str, onnx.NodeProto] = {}

    for identity in onnx_model.graph.node:
        if identity.op_type != "Identity":
            continue

        producer = producers.get(identity.input[0], None)
        if not (producer and producer.op_type == "DequantizeLinear"):
            continue

        if len(consumers.get(identity.output[0], [])) != 1:
            continue

        (consumer,) = consumers[identity.output[0]]
        if consumer.op_type != "QuantizeLinear":
            continue

        removable[identity.name] = identity

    for identity in removable.values():
        for consumer in consumers.get(identity.output[0], []):
            for i, inp in enumerate(consumer.input):
                if inp == identity.output[0]:
                    consumer.input[i] = identity.input[0]

    # Remove removable Identity nodes
    all_nodes = [node for node in onnx_model.graph.node if node.name not in removable]
    onnx_model.graph.ClearField("node")
    onnx_model.graph.node.extend(all_nodes)
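
# Context note: the Identity nodes removed here are the ones inserted earlier by
# _duplicate_shared_qdq_inputs and _decouple_back_to_back_qdqs. Once the aimet QDQ
# placeholders have been lowered to onnx QuantizeLinear/DequantizeLinear pairs,
# a DQ -> Identity -> Q chain carries no information and can be collapsed.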


def _remove_redundant_qdqs(onnx_model: onnx.ModelProto, base_dir):
    from onnx.external_data_helper import _get_all_tensors

    constants = {
        const.name: const for const in _get_all_tensors(onnx_model) if const.name
    }
    constants |= {
        const_node.output[0]: attr.t
        for const_node in onnx_model.graph.node
        if const_node.op_type == "Constant"
        for attr in const_node.attribute
        if attr.HasField("t")
    }

    producers: dict[str, onnx.NodeProto] = {}
    consumers: dict[str, list[onnx.NodeProto]] = {}

    for node in onnx_model.graph.node:
        producers[node.output[0]] = node
        for inp in node.input:
            consumers.setdefault(inp, []).append(node)

    qdq_nodes = [
        node
        for node in onnx_model.graph.node
        if f"{node.domain}::{node.op_type}"
        in (
            "aimet::quantize_dequantize",
            "aimet::QuantizeDequantize",
            "aimet::FloatQuantizeDequantize",
        )
    ]

    removable: dict[str, onnx.NodeProto] = {}

    def get_attributes(qdq_node: onnx.NodeProto):
        finfo = None
        qmin = None
        qmax = None
        block_size = None
        zero_point_shift = None

        for attr in qdq_node.attribute:
            if attr.name == "qmin":
                qmin = attr.i
            if attr.name == "qmax":
                qmax = attr.i
            if attr.name == "block_size":
                block_size = tuple(attr.ints) or None
            if attr.name == "zero_point_shift":
                zero_point_shift = attr.f
            if attr.name == "dtype":
                finfo = _finfo.from_onnx_dtype(attr.i)

        if finfo:
            dtype = finfo
        else:
            dtype = (qmin, qmax)

        return dtype, block_size, zero_point_shift

    def get_scale(qdq_node: onnx.NodeProto) -> np.ndarray:
        scale_proto = constants.get(qdq_node.input[1])
        if scale_proto is None:
            raise RuntimeError(
                f"Cannot find constant with name {qdq_node.input[1]} in onnx model"
            )
        return onnx.numpy_helper.to_array(scale_proto, base_dir=base_dir)

    def get_offset(qdq_node: onnx.NodeProto) -> np.ndarray | None:
        if len(qdq_node.input) < 3:
            return None

        offset_proto = constants.get(qdq_node.input[2])
        if offset_proto is None:
            raise RuntimeError(
                f"Cannot find constant with name {qdq_node.input[2]} in onnx model"
            )

        offset = onnx.numpy_helper.to_array(offset_proto, base_dir=base_dir)
        if np.all(offset == 0):
            offset = None
        return offset

    for qdq_node in qdq_nodes:
        prev_qdq = producers.get(qdq_node.input[0], None)
        if not prev_qdq:
            continue

        if f"{prev_qdq.domain}::{prev_qdq.op_type}" not in (
            "aimet::quantize_dequantize",
            "aimet::QuantizeDequantize",
            "aimet::FloatQuantizeDequantize",
        ):
            continue

        if (
            get_attributes(qdq_node) == get_attributes(prev_qdq)
            and np.all(get_scale(qdq_node) == get_scale(prev_qdq))
            and np.all(get_offset(qdq_node) == get_offset(prev_qdq))
        ):
            removable[qdq_node.name] = qdq_node

    # Redirect consumers of redundant QDQ nodes to the previous QDQ nodes
    for qdq_node in removable.values():
        for consumer in consumers.get(qdq_node.output[0], []):
            for i, inp in enumerate(consumer.input):
                if inp == qdq_node.output[0]:
                    consumer.input[i] = qdq_node.input[0]

    # Remove removable qdq nodes
    all_nodes = [node for node in onnx_model.graph.node if node.name not in removable]
    onnx_model.graph.ClearField("node")
    onnx_model.graph.node.extend(all_nodes)


def _to_onnx(
    model: torch.nn.Module,
    args: Union[Tuple[Any, ...], torch.Tensor],
    f: Union[str, io.BytesIO],
    **kwargs,
) -> Tuple[onnx.ModelProto, dict]:
    base_dir = str(Path(str(f)).absolute().parent)
    _onnx.export(model, args, f, **kwargs)

    onnx_model = onnx.load(f, load_external_data=False)
    aliases = _duplicate_shared_qdq_inputs(onnx_model)
    _remove_redundant_qdqs(onnx_model, base_dir)
    _decouple_back_to_back_qdqs(onnx_model)

    param_names = {
        f"{layer_name}.{param_name}"
        for layer_name, layer in model.named_modules()
        if isinstance(layer, QuantizationMixin)
        for param_name, quantizer in layer.param_quantizers.items()
        if quantizer
    }

    tensor_to_encoding_map: Mapping[str, Tuple[EncodingBase, bool]]
    tensor_to_encoding_map = {
        name: (encoding, name in param_names or aliases.get(name) in param_names)
        for name, encoding in _onnx.remove_quantization_nodes_from_onnx_graph(
            onnx_model,
            base_dir=base_dir,
        ).items()
    }

    derived_encodings = _derive_data_movement_op_encodings(
        onnx_model,
        {
            name: enc.to_qnn_encoding_dict("2.0.0")
            for name, (enc, _) in tensor_to_encoding_map.items()
        },
    )
    # pylint: disable=protected-access
    tensor_to_encoding_map |= {
        name: (AffineEncoding._from_qnn_encoding_dict(encoding, version="2.0.0"), False)
        for name, encoding in derived_encodings.items()
    }

    return onnx_model, tensor_to_encoding_map


@contextlib.contextmanager
def _concretize_int32_bias_quantizers(model, args, kwargs=None):
    if not isinstance(args, (tuple, list)):
        args = (args,)
    kwargs = kwargs or {}

    handles = []
    orig_bias_quantizers = {
        qmodule: qmodule.param_quantizers["bias"]
        for qmodule in model.modules()
        if isinstance(qmodule, QuantizationMixin)
        and "bias" in qmodule.param_quantizers
        and qmodule.bias is not None
    }

    try:
        for qmodule, qtzr in orig_bias_quantizers.items():
            if qtzr is not None:
                # Bias quantizer already exists.
                # This means the user created the bias quantizer themselves.
                # In this case, we honor the custom bias quantizer defined by the user
                continue

            if "weight" in qmodule.param_quantizers and isinstance(
                qmodule.param_quantizers["weight"], AffineQuantizerBase
            ):
                # pylint: disable=protected-access
                handle = qmodule.register_forward_hook(
                    type(qmodule)._create_int32_bias_quantizer
                )
                handles.append(handle)

        try:
            model(*args, **kwargs)
        finally:
            for handle in handles:
                handle.remove()

        yield
    finally:
        for qmodule, qtzr in orig_bias_quantizers.items():
            qmodule.param_quantizers["bias"] = qtzr


@contextlib.contextmanager
def _temporarily_unfold_param_quantizers(model: torch.nn.Module):
    # pylint: disable=protected-access
    """
    Temporarily re-instantiate param quantizers for ease of export
    """
    with contextlib.ExitStack() as stack:
        for qmodule in model.modules():
            if isinstance(qmodule, QuantizationMixin):
                stack.enter_context(qmodule._unfold_param_quantizers())
        yield


@contextlib.contextmanager
def _remove_fp16_quantized_parameters(model: torch.nn.Module):
    """
    Temporarily detach [b]float16 encodings from pre-quantized parameters
    """
    # pylint: disable=protected-access
    with contextlib.ExitStack() as stack:
        for qmodule in model.modules():
            if not isinstance(qmodule, QuantizationMixin):
                continue

            for name, param in qmodule.named_parameters(recurse=False):
                if (
                    isinstance(param, DequantizedTensor)
                    and isinstance(param.encoding, FloatEncoding)
                    and param.encoding._finfo.to_torch_dtype()
                    in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
                ):
                    param = torch.nn.Parameter(param.as_subclass(torch.Tensor))
                    stack.enter_context(patch_attr(qmodule, name, param))
        yield


@contextlib.contextmanager
def _remove_fp16_quantizers(model: torch.nn.Module):
    """
    Temporarily remove [b]float16 quantizers for sim.onnx.export,
    as sim.onnx.export does NOT support exporting [b]float16 quantizers.
    """
    original_containers = {}

    try:
        for qmodule in model.modules():
            if not isinstance(qmodule, QuantizationMixin):
                continue

            for name, qtzr in qmodule.param_quantizers.items():
                if isinstance(qtzr, FloatQuantizeDequantize) and (
                    qtzr.is_float16() or qtzr.is_bfloat16()
                ):
                    original_containers[(qmodule.param_quantizers, name)] = qtzr
                    qmodule.param_quantizers[name] = None

            for i, qtzr in enumerate(qmodule.input_quantizers):
                if isinstance(qtzr, FloatQuantizeDequantize) and (
                    qtzr.is_float16() or qtzr.is_bfloat16()
                ):
                    original_containers[(qmodule.input_quantizers, i)] = qtzr
                    qmodule.input_quantizers[i] = None

            for i, qtzr in enumerate(qmodule.output_quantizers):
                if isinstance(qtzr, FloatQuantizeDequantize) and (
                    qtzr.is_float16() or qtzr.is_bfloat16()
                ):
                    original_containers[(qmodule.output_quantizers, i)] = qtzr
                    qmodule.output_quantizers[i] = None

        yield
    finally:
        for (container, key), qtzr in original_containers.items():
            container[key] = qtzr


def _to_onnx_qdq(
    onnx_model: onnx.ModelProto,
    tensor_to_encoding_map: Mapping[str, Tuple[EncodingBase, bool]],
    prequantize_constants: bool,
    base_dir: str,
) -> onnx.ModelProto:
    qnn_encodings = {
        name: encoding.to_qnn_encoding_dict("2.0.0")
        for name, (encoding, _) in tensor_to_encoding_map.items()
    }
    qnn_encodings = {
        name: encoding for name, encoding in qnn_encodings.items() if encoding
    }
    qdq_tensor_names = {
        fp_tensor_name: f"{fp_tensor_name}_qdq" for fp_tensor_name in qnn_encodings
    }
    onnx_opset_version = next(
        opset.version for opset in onnx_model.opset_import if opset.domain == ""
    )

    # TODO: Support exporting (b)float16 models
    if contains_tensor_type(
        onnx_model, (onnx.TensorProto.BFLOAT16, onnx.TensorProto.FLOAT16)
    ):
        raise RuntimeError(
            "Exporting to onnx QDQ is only supported for float32 models."
        )

    float_types = [np.float32 for _ in range(len(qnn_encodings))]

    # Add onnx QDQ nodes in batch
    _add_onnx_qdq_nodes(
        onnx_model,
        input_names=qnn_encodings.keys(),
        output_names=qdq_tensor_names.values(),
        node_name_prefixes=qnn_encodings.keys(),
        encodings=qnn_encodings.values(),
        float_types=float_types,
        onnx_opset=onnx_opset_version,
        prequantize_constants=prequantize_constants,
        base_dir=base_dir,
    )

    # Restore model output names from "{output}_qdq" to "{output}"
    _restore_model_output_names(onnx_model, qdq_tensor_names)

    return onnx_model


def _check_float16_quantizers(module: torch.nn.Module):
    for qtzr in module.modules():
        if isinstance(qtzr, FloatQuantizeDequantize):
            if not qtzr.is_float16() and not qtzr.is_bfloat16():
                msg = " ".join(
                    [
                        "sim.onnx.export doesn't support exporting floating point encodings",
                        f"except [b]float16. Got {qtzr.bitwidth}-bit float encoding",
                    ]
                )
                raise RuntimeError(msg)


def _restore_model_output_names(
    onnx_model: onnx.ModelProto, qdq_tensor_name_map: Mapping[str, str]
):
    """
    Rename model outputs. Assuming "output" is the model output,

    before:
        Softmax ----> output -------> QDQ -------> output_qdq
    after:
        Softmax ----> output__ -----> QDQ -------> output

    Args:
        onnx_model: onnx model to be modified in-place
        qdq_tensor_name_map: mapping from original tensor names to QDQ tensor names
    """
    _new_names = {
        output.name: f"{output.name}__"
        for output in onnx_model.graph.output
        if output.name in qdq_tensor_name_map
    }
    _new_names.update(
        {
            qdq_tensor_name_map[output.name]: output.name
            for output in onnx_model.graph.output
            if output.name in qdq_tensor_name_map
        }
    )
    # At this point, _new_names consists of:
    # {
    #     "output": "output__",
    #     "output_qdq": "output",
    # }
    #
    # Replacing all tensors accordingly will transform the graph as below:
    #
    # before:
    #     Softmax ----> output -------> QDQ -------> output_qdq
    # after:
    #     Softmax ----> output__ -----> QDQ -------> output
    for node in onnx_model.graph.node:
        for i, old_name in enumerate(node.input):
            new_name = _new_names.get(old_name, None)
            if new_name is not None:
                node.input[i] = new_name

        for i, old_name in enumerate(node.output):
            new_name = _new_names.get(old_name, None)
            if new_name is not None:
                node.output[i] = new_name


@torch.no_grad()
def _absorb_zero_point_shift(model: torch.nn.Module):
    """
    Absorb zero point shift into weights by promoting bitwidth from 2 to 4.

    NOTE: This function is only meant for internal testing purposes.
    """
    # pylint: disable=redefined-builtin
    for qmodule in model.modules():
        if not isinstance(qmodule, QuantizationMixin):
            continue

        for param_name, qtzr in qmodule.param_quantizers.items():
            if not isinstance(qtzr, AffineQuantizerBase):
                continue
            if not qtzr.is_initialized():
                continue
            if qtzr.zero_point_shift != 0.5:
                continue
            if not qtzr.symmetric:
                continue

            weight = getattr(qmodule, param_name)
            weight_qdq = qtzr(weight).dequantize()
            weight.copy_(weight_qdq)

            # weight_qdq ∈ {-1.5 * s, -0.5 * s, 0.5 * s, 1.5 * s}
            #            = {-3 * s/2, -1 * s/2,  1 * s/2,  3 * s/2}
            new_scale = qtzr.get_scale() / 2
            qtzr.bitwidth *= 2
            qtzr.zero_point_shift = 0.0
            min = new_scale * qtzr.qmin
            max = new_scale * qtzr.qmax
            qtzr.set_range(min, max)
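
# Worked example (illustrative numbers): for a 2-bit symmetric quantizer with
# scale s = 0.1 and zero_point_shift = 0.5, the representable grid is
#     {-1.5*s, -0.5*s, 0.5*s, 1.5*s} = {-0.15, -0.05, 0.05, 0.15}
# After the promotion above, the scale becomes s/2 = 0.05 and the bitwidth 4,
# so the same four values are plain integer multiples {-3, -1, 1, 3} of the new
# scale and no zero-point shift is needed anymore.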


@contextlib.contextmanager
def _temporarily_convert_activation_to(model: torch.nn.Module, signed: bool):
    """
    Temporarily convert all activation quantizers to signed or unsigned ones for onnx export.

    TODO: This is a temporary workaround for a QAIRT/HTP bug in handling signed activation encodings.
          Remove this once the bug is fixed.
          (https://github.qualcomm.com/qualcomm-ai/aimet/issues/5236)
    """
    target_activation_quantizers = [
        qtzr
        for module in model.modules()
        if isinstance(module, QuantizationMixin)
        for qtzr in itertools.chain(module.input_quantizers, module.output_quantizers)
        if isinstance(qtzr, AffineQuantizerBase) and qtzr.signed != signed
    ]

    try:
        for qtzr in target_activation_quantizers:
            qtzr.signed = signed
        yield
    finally:
        for qtzr in target_activation_quantizers:
            qtzr.signed = not signed
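
# Note on the intended equivalence (informal sketch): flipping `signed` does not
# change the real-valued range the affine quantizer represents; it only changes
# which integer grid the range is expressed on. For example, an 8-bit encoding
# can be emitted over [-128, 127] (signed) or over [0, 255] (unsigned) with the
# zero point shifted accordingly, so the dequantized outputs of the exported
# QDQ nodes remain numerically identical.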