# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
""" Find quantizer groups in a model """
import itertools
from typing import Dict, List, Tuple
from collections import defaultdict
from dataclasses import dataclass, field
import torch
from aimet_common.connected_graph.connectedgraph_utils import CG_SPLIT
from aimet_common.connected_graph.operation import Op
from aimet_common.utils import AimetLogger
from aimet_common.amp.utils import CANDIDATE_WITH_DTYPE
from aimet_common.amp.quantizer_groups import QuantizerGroupBase, get_supported_candidates_for_quantizers, \
    compute_baseline_candidate_options
from aimet_torch.meta.connectedgraph import ConnectedGraph
from aimet_torch.qc_quantize_op import QcQuantizeWrapper
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch import onnx_utils
from aimet_torch.translation_mapping import aimet_op_to_backend_op_name_map
logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.MixedPrecision)
[docs]@dataclass(frozen=True)
class QuantizerGroup(QuantizerGroupBase):
    """
    Group of modules and quantizers
    """
    input_quantizers: Tuple[str, ...] = field(default_factory=tuple)
    output_quantizers: Tuple[str, ...] = field(default_factory=tuple)
    parameter_quantizers: Tuple[str, ...] = field(default_factory=tuple)
[docs]    def get_candidate(self, name_to_quantizer_dict: Dict) -> CANDIDATE_WITH_DTYPE:
        """
        Gets Activation & parameter bitwidth
        :param name_to_quantizer_dict: Gets module from module name
        :return: Tuple of Activation, parameter bitwidth and data type
        """
        activation_bw, parameter_bw = None, None
        activation_data_type, parameter_data_type = None, None
        for quantizer in self._get_input_quantizers(name_to_quantizer_dict) +\
                         
self._get_output_quantizers(name_to_quantizer_dict):
            activation_bw = quantizer.bitwidth
            activation_data_type = quantizer.data_type
            break
        for quantizer in self._get_param_quantizers(name_to_quantizer_dict):
            if quantizer.enabled:
                parameter_bw = quantizer.bitwidth
                parameter_data_type = quantizer.data_type
                break
        return (activation_bw, activation_data_type), (parameter_bw, parameter_data_type) 
[docs]    def set_quantizers_to_candidate(self,
                                    name_to_quantizer_dict: Dict,
                                    candidate: CANDIDATE_WITH_DTYPE) -> None:
        """
        Sets a quantizer group to a given candidate bitwidth
        :param name_to_quantizer_dict: Gets module from module name
        :param candidate: candidate with act and param bw and data types
        """
        if len(candidate) == 1:
            (activation_bw, activation_data_type), = candidate
            param_bw, param_data_type = None, None
        else:
            (activation_bw, activation_data_type), (param_bw, param_data_type) = candidate
        for quantizer in self._get_input_quantizers(name_to_quantizer_dict) +\
                         
self._get_output_quantizers(name_to_quantizer_dict):
            quantizer.bitwidth = activation_bw
            quantizer.data_type = activation_data_type
        if param_bw is not None:
            for quantizer in self._get_param_quantizers(name_to_quantizer_dict):
                quantizer.bitwidth = param_bw
                quantizer.data_type = param_data_type 
[docs]    def to_list(self) -> List[Tuple[str, str]]:
        """
        Converts quantizer group to a list
        :return: List containing input/output quantizers & weight quantizers
        """
        return list(itertools.chain(
            (("input", module_name) for module_name in self.input_quantizers),
            (("output", module_name) for module_name in self.output_quantizers),
            (("weight", module_name) for module_name in self.parameter_quantizers),
        )) 
[docs]    def get_active_quantizers(self, name_to_quantizer_dict):
        """ Find all active tensor quantizers associated with this quantizer group """
        quantizers = self._get_input_quantizers(name_to_quantizer_dict) +\
                     
self._get_output_quantizers(name_to_quantizer_dict) +\
                     
self._get_param_quantizers(name_to_quantizer_dict)
        return [quantizer for quantizer in quantizers if quantizer.enabled] 
    def _get_input_quantizers(self, name_to_quantizer_dict):
        result = []
        for quantizer_name in self.input_quantizers:
            out = quantizer_name.split("_input_quantizer_idx_")
            assert len(out) == 2
            module_name, quantizer_idx = out[0], int(out[1])
            module = name_to_quantizer_dict[module_name]
            result.append(module.input_quantizers[quantizer_idx])
        return result
    def _get_output_quantizers(self, name_to_quantizer_dict):
        result = []
        for module_name in self.output_quantizers:
            module = name_to_quantizer_dict[module_name]
            result += module.output_quantizers
        return result
    def _get_param_quantizers(self, name_to_quantizer_dict):
        result = []
        for module_name in self.parameter_quantizers:
            module = name_to_quantizer_dict[module_name]
            for _, param_quantizer in module.param_quantizers.items():
                result.append(param_quantizer)
        return result
 
def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> Tuple[str, torch.nn.Module]:
    """
    Finds quantization (wrapping) module corresponding to the wrapper module's dotted name
    :param op_name: Dotted name of op as represented in connected graph
    :param module_name_to_quantizer_dict: Dict key: name of wrapped module value: quantization wrapper
    :return: Module name and the corresponding torch module in the sim
    """
    # pylint:disable = protected-access
    module_name = op_name[op_name.find('.') + 1:]
    if module_name in module_name_to_quantizer_dict:
        return module_name, module_name_to_quantizer_dict[module_name]
    # Else it is a functional op
    raise KeyError
def get_module_name_to_module_dict(sim: QuantizationSimModel) -> Dict:
    """
    Creates a dictionary of wrapped module's name to quantizer module
    :param sim: quantization sim
    :return: Dict key: name of wrapped module value: quantization wrapper
    """
    module_name_to_quantizer_dict = {}
    for name, ref_module in sim.model.named_modules():
        if isinstance(ref_module, QcQuantizeWrapper):
            module_name_to_quantizer_dict[name] = ref_module
    return module_name_to_quantizer_dict
ops_to_skip = ['view',
               'NumToTensor',
               'Split',
               CG_SPLIT,
               'PythonOp',
               'Tile',
               'transpose',
               'reshape',
               'flatten',
               'permute',
               'Permute', # tensor.transpose() results in Permute. Name obtained after mpp
               'Reshape', # Name obtained after MPP
               'ChannelShuffle', # Obtained without going thru mpp. torch.nn.ChannelShuffle fails mpp
               'TopK', # Name obtained after MPP
               'PixelShuffle', # Name obtained after MPP
               'Expand', # Name obtained after MPP. Reproduce using tensor.expand
               'Pad', # Name obtained after MPP
               'Slice', # Name obtained after MPP
               'Gather', # Name obtained after MPP
               'ScatterElements', # Name obtained after MPP
               'ReduceMin',  # Name obtained after MPP
               'ReduceMax', # Name obtained after MPP
               'Upsample', # Name obtained after MPP
               'RoIPool', # Name obtained after MPP
               'MaxPool', # Name obtained after MPP
               'Transpose' # Name obtained after MPP
               ]
ops_not_to_traverse = ['size']
def find_output_quantizer_groups(op: Op, parent_child_op_groups: Dict, map_for_skipped_ops: Dict):
    """
    Finds quantizer groups along the parent to child flow
    :param op: pytorch module
    :param parent_child_op_groups: parent child relationships in graph
    :param map_for_skipped_ops: map to find first skipped parents of skipped ops
    """
    output = op.output
    if output:
        consumers = output.consumers
        for consumer in consumers:
            if consumer.type in ops_not_to_traverse:
                continue
            dotted_name = op.dotted_name
            if op.dotted_name in map_for_skipped_ops:
                dotted_name = map_for_skipped_ops[op.dotted_name]
            if consumer.type in ops_to_skip:
                map_for_skipped_ops[consumer.dotted_name] = dotted_name
                find_output_quantizer_groups(consumer, parent_child_op_groups, map_for_skipped_ops)
            # If there is a one to one connection between quantizers
            else:
                parent_child_op_groups[dotted_name].append(consumer.dotted_name)
    else:
        if op.dotted_name in map_for_skipped_ops:
            parent_child_op_groups[map_for_skipped_ops[op.dotted_name]] = []
def find_op_groups(graph: ConnectedGraph) -> Dict:
    """
    Finds parent children relationship based on three rules
    1) If there is a direct connection between two ops, op1 and op2, then op1 is parent of op2 and they form a group
    2) If the input to an op (op1) is shared with another op (op2), the op producing the input (op0) is the parent, and op1 and op2 are the children
    :param graph: connected graph
    :return: Dict of parent (key) and children (value) relationship
    """
    parent_child_op_groups = defaultdict(list)
    map_for_skipped_ops = {}
    for op in graph.ordered_ops:
        # Add 1st op as child
        if not op.input_ops:
            parent_child_op_groups['input_ops'].append(op.dotted_name)
        # Add output op as child to put output of model as a quantizer group
        if op.output is None:
            parent_child_op_groups['output_ops'].append(op.dotted_name)
    for op in graph.get_all_ops().values():
        if op.type in ops_to_skip or op.type in ops_not_to_traverse:
            continue
        find_output_quantizer_groups(op, parent_child_op_groups, map_for_skipped_ops)
    return parent_child_op_groups
# This code is not currently called anywhere but it can be used to combine two ops who feed into an elementwise op
def find_input_quantizer_groups(graph, map_for_skipped_ops, parent_child_op_groups):
    """
    Combines two groups which share the same output
    :param graph: connected graph
    :param map_for_skipped_ops: map to find first skipped parents of skipped ops
    :param parent_child_op_groups: parent child relationships in graph
    """
    for op in graph.get_all_ops().values():
        inputs = op.input_ops
        if len(inputs) > 1:
            new_parents = set()
            new_children = set()
            for input_op in inputs:
                dotted_name = input_op.dotted_name
                if input_op.type in ops_to_skip:
                    dotted_name = map_for_skipped_ops[dotted_name]
                new_parents.add(dotted_name)
                if dotted_name in parent_child_op_groups:
                    for name in parent_child_op_groups[dotted_name]:
                        new_children.add(name)
                    del parent_child_op_groups[dotted_name]
            if len(new_parents) == 1:
                parent_child_op_groups[tuple(new_parents)[0]] = new_children
            else:
                parent_child_op_groups[tuple(new_parents)] = new_children
def get_input_and_param_quantizers(
        child: str, module_name_to_module_dict: Dict
) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
    """
    Adds child's input quantizer and param quantizer to quantizer group
    :param child: name of child
    :param module_name_to_module_dict: name to module ref dict
    :param quantizer_group: quantizer group
    """
    input_quantizers = []
    parameter_quantizers = []
    try:
        module_name, module = find_wrapper_module(child, module_name_to_module_dict)
    except KeyError:
        pass
    else:
        for idx, input_quantizer in enumerate(module.input_quantizers):
            if input_quantizer.enabled:
                input_quantizers.append(module_name + '_input_quantizer_idx_' + str(idx))
        for _, param_quantizer in module.param_quantizers.items():
            if param_quantizer.enabled:
                parameter_quantizers.append(module_name)
    return tuple(input_quantizers), tuple(parameter_quantizers)
# pylint: disable=too-many-branches, too-many-locals
def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[QuantizerGroup]]:
    """
    Finds quantizer groups in a quantization sim model
    :param sim: Quantization sim
    :return: List of Quantizer groups
    """
    # Get connected graph from quantsim for model without wrappers
    connected_graph = sim.connected_graph
    if connected_graph is None:
        raise AssertionError('Aborting Auto Mixed Precision, connected graph needs to exist for Auto Mixed precision')
    quantizer_groups = []
    parent_child_op_groups = find_op_groups(connected_graph)
    module_name_to_module_dict = get_module_name_to_module_dict(sim)
    if 'input_ops' in parent_child_op_groups:
        for child in parent_child_op_groups['input_ops']:
            # Add one quantizer group for each input and it's weight param
            input_quantizers, parameter_quantizers = get_input_and_param_quantizers(child, module_name_to_module_dict)
            if input_quantizers or parameter_quantizers:
                quantizer_group = QuantizerGroup(
                    input_quantizers=input_quantizers,
                    parameter_quantizers=parameter_quantizers
                )
                quantizer_groups.append(quantizer_group)
                logger.debug('\n Quantizer Group Added: %s', quantizer_group)
    # Based on which quantizers are enabled, create a list of quantizer_groups
    for parents, children in parent_child_op_groups.items():
        input_quantizers = ()
        output_quantizers = ()
        parameter_quantizers = ()
        if parents in ['input_ops', 'output_ops']:
            continue
        if not isinstance(parents, tuple):
            parents = [parents]
        for parent in parents:
            try:
                module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
            except KeyError:
                continue
            if module is not None:
                for output_quantizer in module.output_quantizers:
                    if output_quantizer.enabled:
                        output_quantizers += (module_name,)
        for child in children:
            input_q, param_q = get_input_and_param_quantizers(child, module_name_to_module_dict)
            input_quantizers += input_q
            parameter_quantizers += param_q
        # Don't add quantizer group if it is empty
        if input_quantizers or output_quantizers or parameter_quantizers:
            quantizer_group = QuantizerGroup(
                input_quantizers=input_quantizers,
                output_quantizers=output_quantizers,
                parameter_quantizers=parameter_quantizers
            )
            quantizer_groups.append(quantizer_group)
            logger.debug('\n Quantizer Group added: %s', quantizer_group)
    if 'output_ops' in parent_child_op_groups:
        for parent in parent_child_op_groups['output_ops']:
            # Add one quantizer group for each input and it's weight param
            try:
                module_name, module = find_wrapper_module(parent, module_name_to_module_dict)
            except KeyError:
                continue
            if module is not None:
                for output_quantizer in module.output_quantizers:
                    if output_quantizer.enabled:
                        quantizer_group = QuantizerGroup(
                            output_quantizers=(module_name,),
                        )
                        quantizer_groups.append(quantizer_group)
                        logger.debug('\n Quantizer Group added: %s', quantizer_group)
    return module_name_to_module_dict, quantizer_groups
def find_supported_candidates(quantizer_groups: List[QuantizerGroup],
                              amp_candidates: List[CANDIDATE_WITH_DTYPE],
                              supported_kernels: Dict,
                              module_name_to_module_dict: Dict,
                              use_all_amp_candidates: bool) -> Tuple[Dict, List]:
    """
    Computes 1. a list of supported candidates per Quantizer and 2. List of candidate options for max_candidate
    :param quantizer_groups: List of quantizer groups computed for the given model
    :param amp_candidates: List of candidates specified by the user to be used for the AMP algorithm
    :param supported_kernels: Dict of supported kernels for a given op/defaults specified in the config file
    :param module_name_to_module_dict: Dict mapping module name to module/quantizer
    :param use_all_amp_candidates: Boolean value representing whether the unsupported candidates in the
    "candidates" list need to be considered for creating the output lists. If set to True, all the AMP candidates are
    directly used for all the Quantizers, else the candidates per Quantizers are computed.
    """
    quantizers_with_supported_candidates = defaultdict(list)
    # pylint: disable=too-many-nested-blocks
    for quantizer_group in quantizer_groups:
        quantizers = sorted(set(itertools.chain(quantizer_group.get_input_quantizer_modules(),
                                                quantizer_group.output_quantizers,
                                                quantizer_group.parameter_quantizers)))
        # quantizers are now unique ops present in the given quantizer_group
        onnx_ops = defaultdict(list)
        for quantizer in quantizers:
            if quantizer not in module_name_to_module_dict:
                raise RuntimeError('module_name_to_module_dict does not contain an entry for the quantizer:',
                                   quantizer)
            # pylint: disable=protected-access
            module = module_name_to_module_dict[quantizer]._module_to_wrap.__class__.__name__
            if module in aimet_op_to_backend_op_name_map and aimet_op_to_backend_op_name_map[module] in supported_kernels:
                onnx_ops[quantizer] = [aimet_op_to_backend_op_name_map[module]]
            else:
                onnx_types = onnx_utils.map_torch_types_to_onnx.get(
                    type(module_name_to_module_dict[quantizer]._module_to_wrap), [])
                if not onnx_types:
                    logger.warning("No mapping found for %s in the torch to onnx op type mapping dictionary.",
                                   str(type(module_name_to_module_dict[quantizer]._module_to_wrap)))
                onnx_ops[quantizer] = onnx_types
                for onnx_type in onnx_types:
                    if onnx_type not in supported_kernels.keys():
                        if module in supported_kernels:
                            supported_kernels[onnx_type] = supported_kernels[module]
        supported_kernels_for_quantizers = get_supported_candidates_for_quantizers(quantizers,
                                                                                   onnx_ops,
                                                                                   supported_kernels,
                                                                                   amp_candidates,
                                                                                   use_all_amp_candidates)
        quantizers_with_supported_candidates[quantizer_group] = supported_kernels_for_quantizers.copy()
    max_candidate_options = compute_baseline_candidate_options(quantizers_with_supported_candidates, amp_candidates,
                                                               use_all_amp_candidates)
    return quantizers_with_supported_candidates, max_candidate_options