#!/usr/bin/env python
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Sequential MSE implementation """
# pylint: disable=no-name-in-module, ungrouped-imports, too-many-lines
import copy
from typing import List, Iterable, Dict
from collections import defaultdict
from dataclasses import dataclass
from contextlib import contextmanager
import numpy as np
import onnx
import onnxruntime
import torch
from onnxruntime.quantization.onnx_quantizer import ONNXModel
from onnx.utils import Extractor
from aimet_common.libpymo import TensorQuantizerOpMode
from aimet_common.defs import QuantScheme
from aimet_common.utils import AimetLogger
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_onnx.qc_quantize_op import QcQuantizeOp
from aimet_onnx.sequential_mse.dependency_graph import DependencyGraph, SUPPORTED_MODULES
from aimet_onnx.sequential_mse.dependency_graph import DependencyNode
_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.SeqMse)
@dataclass
class SeqMseParams:
"""
Sequential MSE parameters
    :param num_batches: Number of batches from the data loader used to compute the reconstruction loss.
    :param num_candidates: Number of candidates used for the grid search. Default 20.
:param inp_symmetry: Input symmetry. Available options are 'asym', 'symfp' and 'symqt'. Default 'symqt'.
:param loss_fn: Loss function. Available options are 'mse', 'l1' and 'sqnr'. Default 'mse'.
"""
num_batches: int
num_candidates: int = 20
inp_symmetry: str = 'symqt'
loss_fn: str = 'mse'
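
# Example (illustrative): SeqMseParams(num_batches=4) keeps the defaults
# num_candidates=20, inp_symmetry='symqt' and loss_fn='mse'.
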
# pylint: disable=too-many-instance-attributes
class SequentialMse:
"""
Sequentially minimizing activation MSE loss in layer-wise way to decide optimal param quantization encodings.
"""
def __init__(self,
model,
sim: QuantizationSimModel,
params: SeqMseParams,
data_loader: Iterable):
"""
Initialize the sequential mse object
:param model: float model
:param sim: QuantizationSimModel object
:param data_loader: Torch Dataloader
:param params: Sequential MSE parameters
"""
# pylint: disable=protected-access
assert sim._quant_scheme in (QuantScheme.post_training_tf, QuantScheme.training_range_learning_with_tf_init), \
"Use TF quant-scheme with sequential MSE."
self.sim = sim
self.model = model
self.params = params
if not isinstance(self.model, ONNXModel):
self.model = ONNXModel(self.model)
        # Workaround: Extractor calls onnx.shape_inference.infer_shapes, which does not work for
        # models larger than 2GB, so temporarily strip raw_data from initializers and restore it afterwards.
raw_data = {}
# Store and clear raw_data from initializers
for initializer in self.model.model.graph.initializer:
if initializer.HasField('raw_data'):
raw_data[initializer.name] = initializer.raw_data
initializer.ClearField("raw_data")
self._float_extractor = Extractor(self.model.model)
# Restore raw_data to initializers
for initializer in self.model.model.graph.initializer:
if initializer.name in raw_data:
initializer.raw_data = raw_data[initializer.name]
for initializer in self._float_extractor.model.graph.initializer:
if initializer.name in raw_data:
initializer.raw_data = raw_data[initializer.name]
del raw_data
self._sim_extractor = copy.deepcopy(self._float_extractor)
self._update_value_info()
self._sim_extractor.model = self.sim.model.model
self._sim_extractor.graph = self.sim.model.model.graph
self.data_loader = data_loader
self.dependency_graph = DependencyGraph(self.model, data_loader, params.num_batches)
def _update_value_info_for_output(self, node):
"""
Updates the value info for output of a node in sim model.
Value info for QcQuantizeOp is not present in _sim_extractor
:param node: onnx node
"""
input_name = node.input[0]
output_name = node.output[0]
if input_name in self._sim_extractor.vimap and output_name not in self._sim_extractor.vimap:
value_info_for_output = copy.deepcopy(self._sim_extractor.vimap[input_name])
value_info_for_output.name = node.output[0]
self._sim_extractor.vimap[node.output[0]] = value_info_for_output
def _update_value_info_for_input(self, node):
"""
Updates the value info for input of a node in sim model.
Value info for QcQuantizeOp is not present in _sim_extractor
:param node: onnx node
"""
input_name = node.input[0]
output_name = node.output[0]
if output_name in self._sim_extractor.vimap and input_name not in self._sim_extractor.vimap:
value_info_for_input = copy.deepcopy(self._sim_extractor.vimap[output_name])
value_info_for_input.name = node.input[0]
self._sim_extractor.vimap[node.input[0]] = value_info_for_input
    def _update_value_info_for_graph_output(self):
        """
        Updates the value info for the graph outputs of the float and sim models.
        Value info for QcQuantizeOp is not present in _sim_extractor
        """
for value_info in self.model.model.graph.output:
self._float_extractor.vimap[value_info.name] = value_info
for value_info in self.sim.model.model.graph.output:
self._sim_extractor.vimap[value_info.name] = value_info
def _update_value_info(self):
"""
Updates the value info for sim model.
Value info for QcQuantizeOp is not present in _sim_extractor
"""
self._update_value_info_for_graph_output()
for node in self.sim.model.nodes():
if node.op_type == "QcQuantizeOp":
self._update_value_info_for_output(node)
self._update_value_info_for_input(node)
@staticmethod
def apply_seq_mse(model, sim: QuantizationSimModel, params: SeqMseParams, data_loader):
"""
It performs following steps:
1) creates seq_mse object
2) call apply_seq_algo() member function
:param model: float model
:param sim: QuantizationSimModel object
:param data_loader: Data loader
:param params: Sequential MSE parameters
"""
seq_mse = SequentialMse(model, sim, params, data_loader)
seq_mse.apply_seq_mse_algo()
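
    # Illustrative usage sketch (assumes `model`, `sim` and `data_loader` were created beforehand):
    #
    #     params = SeqMseParams(num_batches=4)
    #     SequentialMse.apply_seq_mse(model, sim, params, data_loader)
    #
    # Afterwards, parameter encodings of supported modules are computed and frozen.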
def apply_seq_mse_algo(self):
"""
It performs following steps:
1) disable the quantizer for unsupported modules
2) create the dependency graph
3) run the onnx graph and compute encoding using seq mse algorithm
4) re-enable the quantizer disabled in first step
"""
with self.temporarily_disable_quantizers():
self._topological_traversal()
def _get_quantizers_to_be_disabled(self) -> List[QcQuantizeOp]:
"""
Get list of quantizers to be disabled in sim model before applying seq mse.
        NOTE: Disable all activation quantizers and the param quantizers of unsupported modules

        :return: Quantizers to be disabled in the sim model.
"""
enabled_quantizer_names = []
# Get list of all the enabled activation + param quantizer names
for name, quantizer in self.sim.qc_quantize_op_dict.items():
if quantizer.enabled:
enabled_quantizer_names.append(name)
# Get list of all the enabled param quantizers of supported ops
param_quantizer_names = []
for cg_op in self.dependency_graph.conn_graph.ordered_ops:
if cg_op.type in SUPPORTED_MODULES:
for param_name in cg_op.parameters:
if param_name in self.sim.qc_quantize_op_dict and self.sim.qc_quantize_op_dict[param_name].enabled:
param_quantizer_names.append(param_name)
# Get list of all the quantizers that are not part of param quantizers of supported ops
quantizers_to_be_disabled = []
for name in enabled_quantizer_names:
if name not in param_quantizer_names:
quantizers_to_be_disabled.append(self.sim.qc_quantize_op_dict[name])
return quantizers_to_be_disabled
@contextmanager
def temporarily_disable_quantizers(self):
"""
Disable quantizers needed to be disabled before applying sequential MSE.
"""
quantizers_to_be_disabled = self._get_quantizers_to_be_disabled()
try:
for quantizer in quantizers_to_be_disabled:
quantizer.enabled = False
yield
finally:
for quantizer in quantizers_to_be_disabled:
quantizer.enabled = True
def _get_min_max_from_weights(self, dependency_node: DependencyNode):
"""
        Get per-channel min/max values across the output channel axis.

        :param dependency_node: Dependency node to be optimized
        :return: per_channel_min and per_channel_max
"""
# pylint: disable=protected-access
channel_axis = QuantizationSimModel._get_quantization_axes(dependency_node.cg_op)[0]
weight_data = self.dependency_graph.get_param_value(dependency_node)
# Handle negative indexing
if channel_axis < 0:
channel_axis += len(weight_data.shape)
axis = tuple(i for i in range(len(weight_data.shape)) if i != channel_axis)
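        # Symmetric per-channel range: reduce |W| over every axis except the channel axis,
        # yielding [-max(|W_c|), +max(|W_c|)] for each output channel c.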
per_channel_max = np.max(abs(weight_data), axis=axis)
return [-per_channel_max, per_channel_max]
def _get_candidates(self, per_channel_max, per_channel_min):
"""
        Generate candidate (max, min) pairs for the grid search.

        :param per_channel_max: Per channel max values
        :param per_channel_min: Per channel min values
        :return: List of (cand_max, cand_min) candidate pairs
"""
candidates = []
num_candidates = self.params.num_candidates
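        # Linear grid over scale factors (i + 1) / num_candidates; with the default
        # num_candidates=20 the candidates sweep 0.05x .. 1.0x of the per-channel extrema.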
for i in range(num_candidates):
cand_max = per_channel_max / num_candidates * (i + 1)
cand_min = per_channel_min / num_candidates * (i + 1)
candidates.append((cand_max, cand_min))
return candidates
def _compute_encoding_from_candidate(self, candidate, dependency_node: DependencyNode):
"""
        Computes and sets the encoding on the param quantizer using the candidate min and candidate max

        :param candidate: Tuple containing the candidate max and min values
        :param dependency_node: Corresponding dependency node
"""
cand_max = candidate[0]
cand_min = candidate[1]
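        # Stack per-channel (max, min) pairs along the last axis so they can be reshaped to
        # the quantizer's encoding shape below.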
cand = np.stack((cand_max, cand_min), axis=-1)
weight_name = self.dependency_graph.get_param_name(dependency_node)
quantize_op = self.sim.qc_quantize_op_dict[weight_name]
quantize_op.reset_encoding_stats()
# pylint: disable=protected-access
quantizer_shape = quantize_op._encoding_shape()
num_encodings = np.prod(quantizer_shape)
        if num_encodings != len(cand) and num_encodings != 1:
            raise ValueError(f"Quantizer for '{weight_name}' must be per-tensor, or its number of "
                             "encodings must match the number of output channels.")
if quantizer_shape:
quantize_op.update_encoding_stats(np.reshape(cand, (*quantizer_shape[0:-1], 2 * quantizer_shape[-1])))
else:
quantize_op.update_encoding_stats(cand)
quantize_op.compute_encodings()
quantize_op.op_mode = TensorQuantizerOpMode.quantizeDequantize
def _freeze_encodings(self, dependency_node: DependencyNode):
"""
Freezes the encoding after the node is optimized
:param dependency_node: Optimized dependency node
"""
weight_name = self.dependency_graph.get_param_name(dependency_node)
quantize_op = self.sim.qc_quantize_op_dict[weight_name]
quantize_op.freeze_encodings()
@staticmethod
def neg_sqnr(pred: torch.Tensor, target: torch.Tensor, eps=1e-10, reduction="none"):
"""
Loss function to minimize negative SQNR which is equivalent to maximizing SQNR.
:param pred: X^Q^ quantized-dequantized values
:param target: XW FP32 values
:param eps: epsilon
:param reduction: unused arg added only to have the same signature as that of functional losses of pytorch library
:return: Negative SQNR
"""
# pylint: disable=unused-argument
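        # SQNR = E[signal^2] / E[noise^2]; returning -10*log10(SQNR) lets a minimizer maximize SQNR in dB.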
quant_error = target - pred
exp_noise = torch.mean(quant_error ** 2, 0, keepdim=True) + eps
exp_signal = torch.mean(target ** 2, 0, keepdim=True)
sqnr = exp_signal / exp_noise
sqnr_db = 10 * torch.log10(sqnr)
return -sqnr_db
def _compute_recon_loss(self, sim_output, float_output, dependency_node):
"""
        Compute the reconstruction loss, reduced (summed) over all dimensions except the last (channel) dimension.

        :param sim_output: X^Q^ quantized-dequantized values
        :param float_output: XW FP32 values
        :param dependency_node: Corresponding dependency node
        :return: loss
"""
xqwq = torch.from_numpy(sim_output)
xw = torch.from_numpy(float_output)
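        # Conv outputs are channels-first (N, C, ...); move channels to the last axis so the
        # per-channel loss reduction below is uniform across op types.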
if dependency_node.cg_op.type == "Conv":
permute_order = [0] + list(range(2, xw.dim())) + [1]
xqwq = xqwq.permute(permute_order)
xw = xw.permute(permute_order)
if self.params.loss_fn == "mse":
loss_fn = torch.nn.functional.mse_loss
elif self.params.loss_fn == "l1":
loss_fn = torch.nn.functional.l1_loss
elif self.params.loss_fn == "sqnr":
loss_fn = SequentialMse.neg_sqnr
else:
raise ValueError(f"Invalid loss function: {self.params.loss_fn}")
channel_dim = xqwq.shape[-1]
xqwq = xqwq.reshape(-1, channel_dim)
xw = xw.reshape(-1, channel_dim)
loss = loss_fn(xqwq, xw, reduction="none").sum(0)
assert loss.size() == torch.Size([channel_dim])
return np.array(loss)
def _run_seq_mse(self, dep_nodes_to_parallelize: List[DependencyNode]):
"""
Run Sequential MSE for all the dep_nodes_to_parallelize at same level.
:param dep_nodes_to_parallelize: Dependency nodes to be parallelized.
"""
def _set_candidates(index: int):
"""
Helper function to set candidate based on index for ops at same level.
Internally computes the encoding using candidate min and candidate max
:param index: Index of candidate
"""
for dep_node in dep_nodes_to_parallelize:
candidate = min_max_candidates[dep_node.cg_op.name]
self._compute_encoding_from_candidate(candidate[index], dep_node)
def _compute_loss(all_fp_outputs: List, all_sim_outputs: List):
"""
Helper function to compute reconstruction loss for ops at same level.
:param all_fp_outputs: FP Outputs of all ops at same level
:param all_sim_outputs: Sim Outputs of all ops at same level
"""
for i, dep_node in enumerate(dep_nodes_to_parallelize):
fp_output = np.concatenate(all_fp_outputs[i], axis=0)
sim_output = np.concatenate(all_sim_outputs[i], axis=0)
                # NOTE: argument order matters for the asymmetric 'sqnr' loss.
                loss = self._compute_recon_loss(sim_output, fp_output, dep_node)
total_loss[dep_node.cg_op.name].append(loss)
def _get_dep_node_io_names(dep_nodes: List[DependencyNode]):
"""
Helper function to get the input and output names of subgraph.
:param dep_nodes: List of dependency nodes to be parallelized.
:return: Subgraph input and output names.
"""
subgraph_inputs = []
subgraph_outputs = []
for dep_node in dep_nodes:
subgraph_inputs.append(dep_node.op_input_names[0])
subgraph_outputs.append(dep_node.op_output_names[0])
subgraph_inputs = list(set(subgraph_inputs))
return subgraph_inputs, subgraph_outputs
total_loss = defaultdict(list)
min_max_candidates = {}
# Perform grid search
for dep_node in dep_nodes_to_parallelize:
per_channel_min, per_channel_max = self._get_min_max_from_weights(dep_node)
min_max_candidates[dep_node.cg_op.name] = self._get_candidates(per_channel_max, per_channel_min)
subgraph_inp_names, subgraph_outs_names = _get_dep_node_io_names(dep_nodes_to_parallelize)
# For now, we only expose "symqt" input symmetry.
assert self.params.inp_symmetry == "symqt", "Only symmetric quantsim inputs ('symqt') are supported."
sim_inputs = self.dependency_graph.get_sim_data(dep_nodes_to_parallelize)
# Create inference session for subgraph from float model
fp_subgraph_model = self._split_onnx_graph(self._float_extractor, subgraph_inp_names, subgraph_outs_names)
with self._create_session(fp_subgraph_model) as session:
fp_outputs = self._run_onnx_graph(session, sim_inputs)
# Create inference session for subgraph from sim model
sim_subgraph_model = self._split_onnx_graph(self._sim_extractor, subgraph_inp_names, subgraph_outs_names)
with self._create_session(sim_subgraph_model) as session:
for i in range(self.params.num_candidates):
_set_candidates(i)
sim_outputs = self._run_onnx_graph(session, sim_inputs)
_compute_loss(fp_outputs, sim_outputs)
_logger.debug(f"Finished candidate: {i}")
del fp_outputs, sim_outputs, sim_inputs
# Postprocessing (not vectorized)
for dep_node in dep_nodes_to_parallelize:
loss = total_loss[dep_node.cg_op.name]
stacked_loss = np.stack(loss, axis=0)
arg_min_ = np.argmin(stacked_loss, axis=0, keepdims=True)
            best_max = torch.stack(
                [torch.tensor(cand_max) for cand_max, _ in min_max_candidates[dep_node.cg_op.name]]
            ).gather(0, torch.tensor(arg_min_))[0]
            best_min = torch.stack(
                [torch.tensor(cand_min) for _, cand_min in min_max_candidates[dep_node.cg_op.name]]
            ).gather(0, torch.tensor(arg_min_))[0]
best_candidate = (best_max, best_min)
self._compute_encoding_from_candidate(best_candidate, dep_node)
self._freeze_encodings(dep_node)
dep_node_names = [dep_node.cg_op.name for dep_node in dep_nodes_to_parallelize]
_logger.info(f"Computed optimal parameter encodings for ops: {', '.join(dep_node_names)}")
@staticmethod
def _split_onnx_graph(extractor: Extractor, input_names: List[str], output_names: List[str]) -> onnx.ModelProto:
"""
Splits the onnx graph from input names to output names using extractor
:param input_names: input names of split graph
:param output_names: output names of split graph
:return: float split model and sim split model
"""
return extractor.extract_model(list(input_names), list(output_names))
def _run_onnx_graph(self, session: onnxruntime.InferenceSession, inputs: Dict) -> List[List[np.ndarray]]:
"""
Run the onnx graph using onnx runtime
:param session: Onnxruntime session
:param inputs: inputs to the model
        :return: Outputs grouped per graph output; each entry is a list of per-batch arrays
"""
outputs = []
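        # outputs[j] accumulates the j-th graph output across all num_batches runs,
        # i.e. outputs has shape [num_graph_outputs][num_batches].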
for i in range(self.params.num_batches):
input_batch = {}
for name, data in inputs.items():
input_batch[name] = data[i]
output = session.run(None, input_batch)
if len(outputs) == 0:
outputs = [[] for _ in range(len(output))]
for idx, out in enumerate(output):
outputs[idx].append(out)
return outputs
def _cache_subgraph_input_data(self, dep_nodes: List[DependencyNode]):
"""
For given dependency nodes at the same level, cache intermediate activation data
- Extract a subgraph using the parent nodes,
- Collect the intermediate activations by executing the subgraph,
- Cache these data to provide them to the next subgraph.
:param dep_nodes: List of dependency nodes at same level
"""
dep_node_names = [dep_node.cg_op.name for dep_node in dep_nodes]
_logger.debug(f"Started caching inputs for dep nodes: {', '.join(dep_node_names)}")
subgraph_inp_names, subgraph_out_names = self.dependency_graph.get_subgraph_inp_out_names(dep_nodes)
subgraph_inps = self.dependency_graph.dependency_node_inputs(dep_nodes)
assert len(subgraph_inp_names) == len(subgraph_inps)
_logger.debug(f"Subgraph input names: {subgraph_inp_names}, Subgraph output names: {subgraph_out_names}")
sim_split_model = self._split_onnx_graph(self._sim_extractor, subgraph_inp_names, subgraph_out_names)
with self._create_session(sim_split_model) as session:
subgraph_outs = self._run_onnx_graph(session, subgraph_inps)
self.dependency_graph.update_sim_data(subgraph_out_names, subgraph_outs)
_logger.debug(f"Collected intermediate data for output names: {subgraph_out_names}")
del subgraph_inps, subgraph_outs
# Decrease the reference count for the input data.
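        # Once the last consumer of an inward node's output has been processed, release its cached activations.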
for dep_node in dep_nodes:
for inward_node in dep_node.inward_nodes:
inward_node.out_degree = inward_node.out_degree - 1
if inward_node.out_degree == 0:
self.dependency_graph.dec_ref_count(inward_node)
def _topological_traversal(self):
"""
        Start the topological traversal from the starting ops, i.e. ops with in_degree equal to zero.
Flow:
- Cache intermediate activations input data before applying Seq MSE at a given level in topological order.
- Use cached intermediate activations and run Seq MSE in parallel.
NOTE: For the first iteration, no need to cache subgraph input data since model graph inputs are already saved.
"""
sorted_order = self.dependency_graph.get_topologically_sorted_nodes()
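        # sorted_order maps each topological level to the nodes at that level; level 0 consumes
        # the already-cached model graph inputs directly.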
for i, sorted_nodes in sorted_order.items():
if i != 0:
                self._cache_subgraph_input_data(list(sorted_nodes))
dep_nodes_to_parallelize = [node for node in sorted_nodes if node.cg_op.type in SUPPORTED_MODULES]
if dep_nodes_to_parallelize:
self._run_seq_mse(dep_nodes_to_parallelize)
@contextmanager
def _create_session(self, model: onnx.ModelProto):
"""
Build and return onnxruntime inference session
:param model: onnx model
:return: Session
"""
        # pylint: disable=protected-access
        session = QuantizationSimModel.build_session(model, self.sim.providers,
                                                     user_onnx_libs=self.sim._user_onnx_libs, path=self.sim._path)
        try:
            yield session
        finally:
            del session