# /usr/bin/env python
# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
""" Sequential MSE implementation """
# pylint: disable=no-name-in-module, ungrouped-imports, too-many-lines
import copy
from typing import List
import numpy as np
import torch
from onnxruntime.quantization.onnx_quantizer import ONNXModel
from onnx import numpy_helper
from onnx.utils import Extractor
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_onnx.qc_quantize_op import QcQuantizeOp
from aimet_onnx.sequential_mse.dependency_graph_utils import DependencyGraphUtils
from aimet_onnx.sequential_mse.dependency_graph import DependencyGraph
from aimet_onnx.sequential_mse.dependency_graph import DependencyNode
from aimet_common.libpymo import TensorQuantizerOpMode
from aimet_common.defs import QuantScheme
from aimet_onnx.meta.connectedgraph import ConnectedGraph
from aimet_common.utils import AimetLogger
from dataclasses import dataclass
from onnx import TensorProto
# Op types whose weight encodings are optimized by sequential MSE.
SUPPORTED_MODULES = ("Conv", "Gemm", "MatMul")
_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.SeqMse)
@dataclass
class SeqMseParams:
    """
    Sequential MSE parameters

    :param num_batches: Number of batches. Default 4.
    :param num_candidates: Number of candidates to perform grid search. Default 20.
    :param inp_symmetry: Input symmetry. Available options are 'asym', 'symfp' and 'symqt'. Default 'symqt'.
    :param loss_fn: Loss function. Available options are 'mse', 'l1' and 'sqnr'. Default 'mse'.
    :raises ValueError: If a parameter has an unsupported value.
    """
    num_batches: int = 4
    num_candidates: int = 20
    inp_symmetry: str = 'symqt'
    loss_fn: str = 'mse'

    def __post_init__(self):
        """
        Validate the parameters eagerly, so a bad configuration is reported at construction time
        instead of only after the dependency graph has already been built by SequentialMse.
        :raises ValueError: if a parameter has an unsupported value
        """
        if self.num_batches < 1:
            raise ValueError(f"num_batches must be >= 1, got {self.num_batches}")
        if self.num_candidates < 1:
            raise ValueError(f"num_candidates must be >= 1, got {self.num_candidates}")
        if self.inp_symmetry not in ('asym', 'symfp', 'symqt'):
            raise ValueError(f"Invalid inp_symmetry: {self.inp_symmetry}")
        if self.loss_fn not in ('mse', 'l1', 'sqnr'):
            raise ValueError(f"Invalid loss function: {self.loss_fn}")
# pylint: disable=too-many-instance-attributes
class SequentialMse:
    """
    Sequentially minimizing activation MSE loss in layer-wise way to decide optimal param quantization encodings.

    Ops are visited in dependency-graph order; for every op whose type is in SUPPORTED_MODULES a grid of
    candidate weight encodings is evaluated on a split sub-graph and, per output channel, the candidate with
    the lowest reconstruction loss is applied and frozen.
    """
    def __init__(self,
                 model,
                 sim: QuantizationSimModel,
                 params: SeqMseParams,
                 data_loader):
        """
        Initialize the sequential mse object
        :param model: float model (wrapped in ONNXModel if it is not one already)
        :param sim: QuantizationSimModel object (must use a TF quant-scheme)
        :param params: Sequential MSE parameters
        :param data_loader: Torch Dataloader
        """
        # pylint: disable=protected-access
        assert sim._quant_scheme in (QuantScheme.post_training_tf, QuantScheme.training_range_learning_with_tf_init), \
            "Use TF quant-scheme with sequential MSE."
        self.sim = sim
        self.model = model
        self.params = params
        self.node_name_to_input_names = {}
        self.static_tensor_name_to_proto = {}
        if not isinstance(self.model, ONNXModel):
            self.model = ONNXModel(self.model)
        self._fill_node_name_to_input_names()
        self._fill_static_tensor_name_to_proto()
        self._float_extractor = self._build_float_extractor()
        # The sim extractor starts as a copy of the float extractor (inheriting its value-info map) and is
        # then re-pointed at the sim model; QcQuantizeOp tensors get value info in _update_value_info().
        self._sim_extractor = copy.deepcopy(self._float_extractor)
        self._update_value_info()
        self._sim_extractor.model = self.sim.model.model
        self._sim_extractor.graph = self.sim.model.model.graph
        self.connected_graph = ConnectedGraph(self.model)
        self.data_loader = data_loader
        self.dependency_graph = DependencyGraph()
        self.dependency_graph_utils = DependencyGraphUtils(self.connected_graph, self.node_name_to_input_names,
                                                           self.static_tensor_name_to_proto)
        # Must run after dependency_graph_utils exists: it uses is_supported_module() to select quantizers.
        self.quantizers_to_be_disabled = self._get_quantizers_to_be_disabled() # check this

    def _build_float_extractor(self) -> Extractor:
        """
        Builds the onnx Extractor for the float model.

        Initializer raw_data is cleared before constructing the Extractor and restored afterwards, both on
        self.model and on the Extractor's own model copy. The restore on self.model runs in a ``finally``
        so the caller's model never stays stripped of its weights if Extractor construction fails.
        :return: Extractor over the float model
        """
        raw_data = {}
        for initializer in self.model.model.graph.initializer:
            if initializer.HasField('raw_data'):
                raw_data[initializer.name] = initializer.raw_data
                initializer.ClearField("raw_data")
        try:
            extractor = Extractor(self.model.model)
        finally:
            for initializer in self.model.model.graph.initializer:
                if initializer.name in raw_data:
                    initializer.raw_data = raw_data[initializer.name]
        for initializer in extractor.model.graph.initializer:
            if initializer.name in raw_data:
                initializer.raw_data = raw_data[initializer.name]
        return extractor

    def _update_value_info_for_output(self, node):
        """
        Copies the value info of a node's first input onto its first output in the sim extractor,
        if the output has none yet. Value info for QcQuantizeOp is not present in _sim_extractor.
        :param node: onnx node
        """
        input_name = node.input[0]
        output_name = node.output[0]
        if input_name in self._sim_extractor.vimap and output_name not in self._sim_extractor.vimap:
            value_info_for_output = copy.deepcopy(self._sim_extractor.vimap[input_name])
            value_info_for_output.name = output_name
            self._sim_extractor.vimap[output_name] = value_info_for_output

    def _update_value_info_for_input(self, node):
        """
        Copies the value info of a node's first output onto its first input in the sim extractor,
        if the input has none yet. Value info for QcQuantizeOp is not present in _sim_extractor.
        :param node: onnx node
        """
        input_name = node.input[0]
        output_name = node.output[0]
        if output_name in self._sim_extractor.vimap and input_name not in self._sim_extractor.vimap:
            value_info_for_input = copy.deepcopy(self._sim_extractor.vimap[output_name])
            value_info_for_input.name = input_name
            self._sim_extractor.vimap[input_name] = value_info_for_input

    def _update_value_info_for_graph_output(self):
        """
        Registers the value info of the graph outputs of the float model and of the sim model in the
        value-info maps (vimap) of the float extractor and the sim extractor respectively.
        """
        for value_info in self.model.model.graph.output:
            self._float_extractor.vimap[value_info.name] = value_info
        for value_info in self.sim.model.model.graph.output:
            self._sim_extractor.vimap[value_info.name] = value_info

    def _update_value_info(self):
        """
        Updates the value info for sim model.
        Value info for QcQuantizeOp is not present in _sim_extractor, so it is propagated across each
        QcQuantizeOp from whichever side (input or output) already has it.
        """
        self._update_value_info_for_graph_output()
        for node in self.sim.model.nodes():
            if node.op_type == "QcQuantizeOp":
                self._update_value_info_for_output(node)
                self._update_value_info_for_input(node)

    def _fill_static_tensor_name_to_proto(self):
        """
        Fills the mapping from static tensor name to its proto buf: initializers map to their TensorProto,
        outputs of Constant nodes map to the Constant NodeProto.
        """
        for initializer in self.model.model.graph.initializer:
            self.static_tensor_name_to_proto[initializer.name] = initializer
        for node in self.model.model.graph.node:
            if node.op_type == "Constant":
                self.static_tensor_name_to_proto[node.output[0]] = node

    def _extract_float_data_from_proto(self, name):
        """
        returns the tensor value of the given name using static_tensor_name_to_proto dictionary
        :param name: name of the static tensor
        :return: tensor value as numpy array
        :raises ValueError: if name is not a static tensor, or is produced by a Constant node without
            a tensor 'value' attribute
        """
        if name not in self.static_tensor_name_to_proto:
            raise ValueError(f"{name} is neither a constant nor an initializer")
        proto_buf = self.static_tensor_name_to_proto[name]
        if isinstance(proto_buf, TensorProto):
            return numpy_helper.to_array(proto_buf)
        for attr in proto_buf.attribute:
            if attr.name == "value":
                return numpy_helper.to_array(attr.t)
        # Previously this fell through and returned None, which failed later with an unrelated AttributeError.
        raise ValueError(f"Constant node producing {name} has no tensor 'value' attribute")

    def _fill_node_name_to_input_names(self):
        """
        Fills the mapping from node name to input names
        """
        for node in self.model.nodes():
            self.node_name_to_input_names[node.name] = node.input

    @staticmethod
    def apply_seq_mse(model, sim: QuantizationSimModel, params: SeqMseParams, data_loader):
        """
        It performs following steps:
        1) creates seq_mse object
        2) call apply_seq_mse_algo() member function
        :param model: float model
        :param sim: QuantizationSimModel object
        :param params: Sequential MSE parameters
        :param data_loader: Data loader
        """
        seq_mse = SequentialMse(model, sim, params, data_loader)
        seq_mse.apply_seq_mse_algo()

    def apply_seq_mse_algo(self):
        """
        It performs following steps:
        1) disable the quantizer for unsupported modules
        2) create the dependency graph
        3) run the onnx graph and compute encoding using seq mse algorithm
        4) re-enable the quantizer disabled in first step (always, even if a previous step raised)
        """
        try:
            self.temporarily_disable_quantizers()
            self.dependency_graph = self.dependency_graph_utils.create_dependency_graph(self.data_loader,
                                                                                        self.params.num_batches)
            self._run_onnx_graph_dependency_graph_order()
        finally:
            self.re_enable_quantizers()

    def _get_quantizers_to_be_disabled(self) -> List[QcQuantizeOp]:
        """
        Collects the currently enabled quantizers that are NOT the weight quantizer of a supported module.
        :return: Returns the quantizers of unsupported modules, in qc_quantize_op_dict order
        """
        # Set for O(1) membership instead of scanning a list for every quantizer.
        supported_weight_names = {node.input[1] for node in self.model.nodes()
                                  if self.dependency_graph_utils.is_supported_module(node)}
        return [qc_quantize_op for name, qc_quantize_op in self.sim.qc_quantize_op_dict.items()
                if qc_quantize_op.enabled and name not in supported_weight_names]

    def temporarily_disable_quantizers(self):
        """
        Disable quantizers needed to be disabled before applying sequential MSE.
        """
        for quantizer in self.quantizers_to_be_disabled:
            quantizer.enabled = False

    def re_enable_quantizers(self):
        """
        Re-enable quantizers that were disabled by temporarily_disable_quantizers method
        """
        for quantizer in self.quantizers_to_be_disabled:
            quantizer.enabled = True

    def _get_min_max_from_weights(self, dependency_node: DependencyNode):
        """
        Get per channel min/max values across output channel.
        The range is symmetric: min is the negated per-channel absolute maximum.
        :param dependency_node: Dependency node which is to be optimized
        :return: per_channel_min and per_channel_max
        """
        weight_name = self.node_name_to_input_names[dependency_node.op_name][1]
        weight_data = self._extract_float_data_from_proto(weight_name)
        connected_op = self.connected_graph.get_op_from_module_name(dependency_node.op_name)
        # pylint: disable=protected-access
        channel_axis = QuantizationSimModel._get_quantization_axes(connected_op)[0]
        # Handle negative indexing
        if channel_axis < 0:
            channel_axis += len(weight_data.shape)
        axis = tuple(i for i in range(len(weight_data.shape)) if i != channel_axis)
        per_channel_max = np.max(np.abs(weight_data), axis=axis)
        return [-per_channel_max, per_channel_max]

    def _get_candidates(self, per_channel_max, per_channel_min):
        """
        Perform grid search: candidate i scales the per-channel range by (i + 1) / num_candidates.
        :param per_channel_max: Per channel max values
        :param per_channel_min: Per channel min values
        :return: candidates as a list of (cand_max, cand_min) tuples
        """
        candidates = []
        num_candidates = self.params.num_candidates
        for i in range(num_candidates):
            cand_max = per_channel_max / num_candidates * (i + 1)
            cand_min = per_channel_min / num_candidates * (i + 1)
            candidates.append((cand_max, cand_min))
        return candidates

    def _compute_encoding_from_candidate(self, candidate, dependency_node: DependencyNode):
        """
        computes the encoding using candidate min and candidate max and sets the weight quantizer to
        quantizeDequantize mode
        :param candidate: (max, min) pair of per-channel values
        :param dependency_node: Corresponding Dependency node
        :raises ValueError: if the quantizer is neither per-tensor nor has one encoding per channel
        """
        cand_max = candidate[0]
        cand_min = candidate[1]
        cand = np.stack((cand_max, cand_min), axis=-1)
        weight_name = self.node_name_to_input_names[dependency_node.op_name][1]
        quantize_op = self.sim.qc_quantize_op_dict[weight_name]
        quantize_op.reset_encoding_stats()
        # pylint: disable=protected-access
        quantizer_shape = quantize_op._encoding_shape()
        num_encodings = np.prod(quantizer_shape)
        if num_encodings != len(cand) and num_encodings != 1:
            raise ValueError(f"{weight_name} should be per-tensor or number of "
                             f"quantizer must match with number of channels")
        if quantizer_shape:
            quantize_op.update_encoding_stats(np.reshape(cand, (*quantizer_shape[0:-1], 2 * quantizer_shape[-1])))
        else:
            quantize_op.update_encoding_stats(cand)
        quantize_op.compute_encodings()
        quantize_op.op_mode = TensorQuantizerOpMode.quantizeDequantize

    def _freeze_encodings(self, dependency_node: DependencyNode):
        """
        Freezes the encoding after the node is optimized
        :param dependency_node: Optimized dependency node
        """
        weight_name = self.node_name_to_input_names[dependency_node.op_name][1]
        quantize_op = self.sim.qc_quantize_op_dict[weight_name]
        quantize_op.freeze_encodings()

    @staticmethod
    def neg_sqnr(pred: torch.Tensor, target: torch.Tensor, eps=1e-10, reduction="none"):
        """
        Loss function to minimize negative SQNR which is equivalent to maximizing SQNR.
        :param pred: X^Q^ quantized-dequantized values
        :param target: XW FP32 values
        :param eps: epsilon
        :param reduction: unused arg added only to have the same signature as that of functional losses of pytorch library
        :return: Negative SQNR
        """
        # pylint: disable=unused-argument
        quant_error = target - pred
        exp_noise = torch.mean(quant_error ** 2, 0, keepdim=True) + eps
        exp_signal = torch.mean(target ** 2, 0, keepdim=True)
        sqnr = exp_signal / exp_noise
        sqnr_db = 10 * torch.log10(sqnr)
        return -sqnr_db

    def _compute_recon_loss(self, sim_output, float_output, dependency_node):
        """
        Compute reconstruction loss and return the sum by reducing over all the dimensions except last channel dimension.
        For Conv, the channel dimension (dim 1) is first moved to the end.
        :param sim_output: X^Q^ quantized-dequantized values (numpy array)
        :param float_output: XW FP32 values (numpy array)
        :param dependency_node: Corresponding dependency node (its op_type selects the layout handling)
        :return: per-channel loss as numpy array of shape (channels,)
        :raises ValueError: if params.loss_fn is not one of 'mse', 'l1', 'sqnr'
        """
        xqwq = torch.from_numpy(sim_output)
        xw = torch.from_numpy(float_output)
        if dependency_node.op_type == "Conv":
            permute_order = [0] + list(range(2, xw.dim())) + [1]
            xqwq = xqwq.permute(permute_order)
            xw = xw.permute(permute_order)
        if self.params.loss_fn == "mse":
            loss_fn = torch.nn.functional.mse_loss
        elif self.params.loss_fn == "l1":
            loss_fn = torch.nn.functional.l1_loss
        elif self.params.loss_fn == "sqnr":
            loss_fn = SequentialMse.neg_sqnr
        else:
            raise ValueError(f"Invalid loss function: {self.params.loss_fn}")
        channel_dim = xqwq.shape[-1]
        xqwq = xqwq.reshape(-1, channel_dim)
        xw = xw.reshape(-1, channel_dim)
        loss = loss_fn(xqwq, xw, reduction="none").sum(0)
        assert loss.size() == torch.Size([channel_dim])
        return np.array(loss)

    # pylint: disable-msg=too-many-locals
    def _do_seq_mse(self, dependency_node: DependencyNode):
        """
        Find and freeze optimal parameter encodings candidate for given dependency node.

        Each candidate is applied to the weight quantizer, the split sub-graph is run, and the per-channel
        reconstruction loss against the float reference is recorded. Per channel, the lowest-loss candidate
        is then applied and frozen.
        :param dependency_node: Corresponding Dependency node
        :raises ValueError: if params.inp_symmetry is not one of 'asym', 'symfp', 'symqt'
        """
        per_channel_min, per_channel_max = self._get_min_max_from_weights(dependency_node)
        candidates = self._get_candidates(per_channel_max, per_channel_min)
        total_loss = []
        float_split_model, sim_split_model = self._split_onnx_graph(dependency_node.op_input_names,
                                                                    dependency_node.op_output_names)
        _logger.info("Finding and freezing optimal param encodings candidate of op: %s", dependency_node.op_name)
        # for different modes only inputs will change:
        #   asym  -> float reference on float data, sim on sim data
        #   symfp -> both on float data
        #   symqt -> both on sim data
        if self.params.inp_symmetry == "asym":
            float_inputs = self.dependency_graph.get_float_data(dependency_node)
            sim_inputs = self.dependency_graph.get_sim_data(dependency_node)
        elif self.params.inp_symmetry == "symfp":
            float_inputs = self.dependency_graph.get_float_data(dependency_node)
            sim_inputs = self.dependency_graph.get_float_data(dependency_node)
        elif self.params.inp_symmetry == "symqt":
            float_inputs = self.dependency_graph.get_sim_data(dependency_node)
            sim_inputs = self.dependency_graph.get_sim_data(dependency_node)
        else:
            raise ValueError(f"Invalid inp_symmetry: {self.params.inp_symmetry}")
        float_outputs = self._run_onnx_graph(float_split_model, float_inputs)
        float_outputs = np.concatenate(float_outputs[0], axis=0)
        for candidate in candidates:
            self._compute_encoding_from_candidate(candidate, dependency_node)
            sim_outputs = self._run_onnx_graph(sim_split_model, sim_inputs)
            sim_outputs = np.concatenate(sim_outputs[0], axis=0)
            loss = self._compute_recon_loss(sim_outputs, float_outputs, dependency_node)
            total_loss.append(loss)
        # stacked_loss is (num_candidates, channels); the best candidate is chosen independently per channel.
        stacked_loss = np.stack(total_loss, axis=0)
        arg_min_ = np.argmin(stacked_loss, axis=0, keepdims=True)
        best_max = torch.stack([torch.tensor(cand_max) for cand_max, _ in candidates]).gather(0, torch.tensor(arg_min_))[0]
        best_min = torch.stack([torch.tensor(cand_min) for _, cand_min in candidates]).gather(0, torch.tensor(arg_min_))[0]
        best_candidate = (best_max, best_min)
        self._compute_encoding_from_candidate(best_candidate, dependency_node)
        self._freeze_encodings(dependency_node)

    def _get_input_names_from_dependencies(self, dependency_node: DependencyNode):
        """
        Returns the input names for the op corresponding to dependency node
        :param dependency_node: Corresponding Dependency node
        :return: de-duplicated input names of all inward nodes (order not significant)
        """
        input_names = set()
        for inward_node in dependency_node.inward_nodes:
            input_names.update(inward_node.op_input_names)
        return list(input_names)

    def _get_inputs_from_dependencies(self, dependency_node: DependencyNode):
        """
        Returns the input needed for the op corresponding to the dependency node
        :param dependency_node: Corresponding dependency node
        :return: float inputs and sim inputs
        """
        float_inputs = {}
        sim_inputs = {}
        for inward_node in dependency_node.inward_nodes:
            float_inputs.update(self.dependency_graph.get_float_data(inward_node))
            sim_inputs.update(self.dependency_graph.get_sim_data(inward_node))
        return float_inputs, sim_inputs

    def _split_onnx_graph(self, input_names, output_names):
        """
        Splits the onnx graph from input names to output names using extractor
        :param input_names: input names of split graph
        :param output_names: output names of split graph
        :return: float split model and sim split model
        """
        float_split_model = self._float_extractor.extract_model(list(input_names), list(output_names))
        sim_split_model = self._sim_extractor.extract_model(list(input_names), list(output_names))
        return float_split_model, sim_split_model

    def _run_onnx_graph(self, model, inputs):
        """
        Run the onnx graph using onnx runtime
        :param model: Corresponding model
        :param inputs: mapping of input name to a per-batch indexable sequence of input data
        :return: outputs as a list (one entry per model output) of per-batch output lists
        """
        # pylint: disable=protected-access
        session = QuantizationSimModel.build_session(model, self.sim.providers,
                                                     user_onnx_libs=self.sim._user_onnx_libs, path=self.sim._path)
        outputs = []
        # Only full batches are counted (integer division).
        num_batches = min(self.params.num_batches, len(self.data_loader.dataset) // self.data_loader.batch_size)
        for i in range(num_batches):
            input_batch = {name: data[i] for name, data in inputs.items()}
            output = session.run(None, input_batch)
            if not outputs:
                outputs = [[] for _ in range(len(output))]
            for idx, out in enumerate(output):
                outputs[idx].append(out)
        return outputs

    def _process_dependency_nodes(self, dependency_node: DependencyNode):
        """
        1) Get input names, output names using dependency graph
        2) Split the graph using input names and output names
        3) Run the split graph
        4) Decrease the out-degree of the inward nodes by 1, if outdegree becomes zero, then delete the data
        5) Optimize the dependency node
        :param dependency_node: Corresponding dependency node
        """
        input_names = self._get_input_names_from_dependencies(dependency_node=dependency_node)
        graph_inputs = {node.name for node in self.model.model.graph.input}
        # Graph inputs are already provided data, so they are not outputs of the split graph.
        output_names = [name for name in dependency_node.op_input_names if name not in graph_inputs]
        float_split_model, sim_split_model = self._split_onnx_graph(input_names=input_names, output_names=output_names)
        float_inputs, sim_inputs = self._get_inputs_from_dependencies(dependency_node=dependency_node)
        float_outputs = self._run_onnx_graph(model=float_split_model, inputs=float_inputs)
        self.dependency_graph.update_float_data(output_names, float_outputs)
        sim_outputs = self._run_onnx_graph(model=sim_split_model, inputs=sim_inputs)
        self.dependency_graph.update_sim_data(output_names, sim_outputs)
        for inward_node in dependency_node.inward_nodes:
            inward_node.outdegree = inward_node.outdegree - 1
            if inward_node.outdegree == 0:
                self.dependency_graph.dec_ref_count(inward_node)
        if dependency_node.op_type in SUPPORTED_MODULES:
            self._do_seq_mse(dependency_node=dependency_node)

    def _do_topo_sort_helper(self, dependency_node: DependencyNode):
        """
        1) Decrease indegree of the child ops by 1, if the indegree becomes zero, then process the node
        2) run _do_topo_sort_helper for the child node
        :param dependency_node: Corresponding dependency node
        """
        for child_node in dependency_node.outward_nodes:
            child_node.indegree = child_node.indegree - 1
            # A child is processed only once all of its inward nodes have been processed.
            if child_node.indegree == 0:
                self._process_dependency_nodes(dependency_node=child_node)
                self._do_topo_sort_helper(dependency_node=child_node)

    def _run_onnx_graph_dependency_graph_order(self):
        """
        Start the topo sort from the starting ops i.e. ops having indegree equal to zero
        """
        for start_op in self.dependency_graph.starting_ops:
            if start_op.op_type in SUPPORTED_MODULES:
                self._do_seq_mse(dependency_node=start_op)
            self._do_topo_sort_helper(dependency_node=start_op)