# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2022-2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
""" Quantsim for Keras """
from __future__ import annotations
import contextlib
from dataclasses import dataclass
import json
import os
from typing import Union, Dict, Tuple, Optional, List
import tensorflow as tf
from aimet_common import libpymo
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.utils import AimetLogger, save_json_yaml
from aimet_common import quantsim
from aimet_common.quantsim import extract_global_quantizer_args
from aimet_tensorflow.keras.connectedgraph import ConnectedGraph
from aimet_tensorflow.keras.graphsearchtuils import GraphSearchUtils
from aimet_tensorflow.keras.quant_sim.qc_quantize_wrapper import QcQuantizeWrapper, QuantizerSettings
from aimet_tensorflow.keras.quant_sim.qc_mha_wrapper import QcQuantizableMultiHeadAttention
from aimet_tensorflow.keras.rnn.qc_quant_LSTM import QuantizedLSTM
from aimet_tensorflow.keras.quant_sim.tensor_quantizer import TensorQuantizer, ActivationTensorQuantizer, \
    ParamPerTensorQuantizer, StaticGridPerChannelQuantizer, ParamPerChannelQuantizer
from aimet_tensorflow.keras.quantsim_config.quantsim_config import QuantSimConfigurator, INPUT_QUANTIZERS, \
    OUTPUT_QUANTIZERS, PARAM_QUANTIZERS
from aimet_tensorflow.keras.utils.common import convert_h5_model_to_pb_model
from aimet_tensorflow.keras.defs import AxisHandling
import aimet_tensorflow.keras.utils.common as keras_common_utils
_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
unquantizable_modules = (tf.keras.layers.InputLayer, QcQuantizeWrapper)
substitutable_modules = {
    tf.keras.layers.MultiHeadAttention: QcQuantizableMultiHeadAttention,
    tf.keras.layers.LSTM: QuantizedLSTM
}
@dataclass
class QuantizationSimModelParams:
    """
    Data class that holds parameters for QuantizationSimModel. Used specifically to rebuild after converting to TF frozen pb
    """
    quant_scheme: Union[QuantScheme, str] = 'tf_enhanced'
    rounding_mode: str = 'nearest'
    default_output_bw: int = 8
    default_param_bw: int = 8
    in_place: bool = False
    config_file: str = None
    default_data_type: QuantizationDataType = QuantizationDataType.int
# pylint: disable=too-many-ancestors
# pylint: disable=too-many-instance-attributes
[docs]class QuantizationSimModel(tf.keras.Model):
    """
    Implements mechanism to add quantization simulations ops to a model. This allows for off-target simulation of
    inference accuracy. Also allows the model to be fine-tuned to counter the effects of quantization.
    """
    # pylint: disable=too-many-arguments
    # pylint: disable=unused-argument
    def __init__(self, model, quant_scheme: Union[QuantScheme, str] = 'tf_enhanced', rounding_mode: str = 'nearest',
                 default_output_bw: int = 8, default_param_bw: int = 8, in_place: bool = False,
                 config_file: str = None, default_data_type: QuantizationDataType = QuantizationDataType.int):
        """
        :param model: Model to quantize
        :param quant_scheme: Quantization Scheme, currently supported schemes are post_training_tf and
               post_training_tf_enhanced, defaults to post_training_tf_enhanced
        :param rounding_mode: The round scheme to used. One of: 'nearest' or 'stochastic', defaults to 'nearest'.
        :param default_output_bw: bitwidth to use for activation tensors, defaults to 8
        :param default_param_bw: bitwidth to use for parameter tensors, defaults to 8
        :param in_place: If True, then the given 'model' is modified in-place to add quant-sim nodes.
                Only suggested use of this option is when the user wants to avoid creating a copy of the model
        :param config_file: Path to a config file to use to specify rules for placing quant ops in the model
        :param default_data_type: Default data type to use for quantizing all layer parameters.
                                 Possible options are QuantizationDataType.int and QuantizationDataType.float.
                                 Note that the mode default_data_type=QuantizationDataType.float is only supported with
                                 default_output_bw=16 and default_param_bw=16
        """
        super().__init__()
        self._model_without_wrappers = model
        if not in_place:
            self._model_without_wrappers = tf.keras.models.clone_model(model)
            n_weights = len(self._model_without_wrappers.weights)
            self._model_without_wrappers.set_weights(model.get_weights()[:n_weights])
        self._layer_name_to_quant_wrapper = {}
        self._substituted_layer = {}  # to hold the substituted layers
        self._validate_model()
        self.connected_graph = ConnectedGraph(self._model_without_wrappers)
        self._quantsim_configurator = self._initialize_quantsim_configurator(quant_scheme, rounding_mode,
                                                                             default_output_bw, default_param_bw,
                                                                             default_data_type, config_file)
        self.quant_scheme = quant_scheme
        self._percentile_value = 100  # default percentile value
        self.per_channel_quantization_enabled = self._quantsim_configurator.per_channel_quantization_flag
        self.model = self._add_quantization_wrappers(quant_scheme, rounding_mode,
                                                     default_output_bw, default_param_bw, default_data_type)
        self.quant_args = extract_global_quantizer_args(quant_scheme, self._quantsim_configurator)
        self._params = QuantizationSimModelParams(quant_scheme, rounding_mode, default_output_bw, default_param_bw,
                                                  in_place, config_file, default_data_type)
    def _validate_model(self):
        """
        Check that model is appropriate for quantsim.
        """
        multiple_inbound_node_layers = []
        for layer in self._model_without_wrappers.layers:
            if len(layer.inbound_nodes) > 1:
                multiple_inbound_node_layers.append(layer.name)
        if multiple_inbound_node_layers:
            error_msg = (f'Layers with more than one inbound nodes are unsupported. This may occur if a layer is '
                         f'reused multiple times in the model definition.\n'
                         f'Layers with multiple inbound nodes: {multiple_inbound_node_layers}')
            _logger.error(error_msg)
            raise NotImplementedError(error_msg)
        sep_conv_found = self.check_separable_conv(self._model_without_wrappers)
        if sep_conv_found:
            # Raising an assertion error incase there's SeparableConv2D in the model because in this case we have two sets of weights: Depthwise
            # and Pointwise. For depthwise kernels, LAST TWO AXIS should be considered and for pointwise kernels LAST AXIS
            # should be considered, which is not handled here. Running model preparer beforehand will resolve this as there the
            # SeparableConv2D is splitted into two layers Depthwise and Pointwise seperately.
            raise AssertionError("SeparableConv2D found in the model. Please run model preparer before calling QuantizationSimModel")
    def check_separable_conv(self, model: tf.keras.models.Model | tf.keras.Sequential) -> bool:
        """
        Checks for SeparableConv2D layer in the model
        :param model: Keras Model
        :return: Boolean value, True if SeperableConv layer is found else False
        """
        for layer in model.layers:
            if isinstance(layer, tf.keras.Sequential):
                if self.check_separable_conv(layer):
                    return True
            elif isinstance(layer, tf.keras.layers.SeparableConv2D):
                return True
        return False
    def _get_quantizer_list(self) -> Tuple[List, List, List]:
        """
        Method to provide a list of input, output and parameter quantizers
        :return: Three lists containing input, paramater and output quantizers respectively
        """
        input_quantizers = []
        parameter_quantizers = []
        output_quantizers = []
        for wrapper in self.quant_wrappers():
            for quantizer in wrapper.input_quantizers:
                input_quantizers.append(quantizer)
            for quantizer in wrapper.param_quantizers:
                parameter_quantizers.append(quantizer)
            for quantizer in wrapper.output_quantizers:
                output_quantizers.append(quantizer)
        return input_quantizers, parameter_quantizers, output_quantizers
    def set_percentile_value(self, percentile_value: float):
        """
        Set the percentile value to be used while computing encodings for quantizers having percentile quant scheme.
        :param percentile_value: Percentile value to be set to
        """
        if percentile_value < 90 or percentile_value > 100:
            raise ValueError("Percentile value must be in range [90, 100]")
        self._percentile_value = percentile_value
        # Set the percentile value to the activation quantizers
        input_quantizers, _, output_quantizers = self._get_quantizer_list()
        for quantizer in input_quantizers + output_quantizers:
            if quantizer.quant_scheme == QuantScheme.post_training_percentile:
                quantizer.set_percentile_value(self._percentile_value)
    def _initialize_quantsim_configurator(self, quant_scheme: Union[QuantScheme, str], rounding_mode: str,
                                          default_output_bw: int, default_param_bw: int,
                                          default_data_type: QuantizationDataType = QuantizationDataType.int,
                                          config_file: str = None) -> QuantSimConfigurator:
        """
        Initialize quantsim configurator
        :param quant_scheme: Quantization Scheme
        :param rounding_mode: The round scheme to used
        :param default_output_bw: bitwidth to use for activation tensors
        :param default_param_bw: bitwidth to use for parameter tensors
        :param default_data_type: data type to use for the parameter tensors
        :param config_file: Path to a config file to use to specify rules for placing quant ops in the model
        :return: QuantSimConfigurator
        """
        return QuantSimConfigurator(self.connected_graph, quant_scheme, rounding_mode,
                                    default_output_bw, default_param_bw, default_data_type, config_file)
    def _add_quantization_wrappers(self, quant_scheme, rounding_mode,
                                   default_output_bw, default_param_bw, default_data_type):
        """
        Add quantization wrappers to the model and return a new model with the wrappers inserted.
        :param quant_scheme: Quantization scheme to use
        :param rounding_mode: Rounding mode to use
        :param default_output_bw: Default bitwidth for activation quantizers
        :param default_param_bw: Default bitwidth for param quantizers
        :param default_data_type: data type to use for param quantizers
        """
        def wrap_layer(layer) -> tf.keras.layers.Layer:
            """
            Function to wrap layers with QcQuantizeWrappers, used by keras clone_model()
            :param layer: Layer to wrap
            :return: Wrapped layer, or original layer if layer is not to be wrapped
            """
            if isinstance(layer, tuple(substitutable_modules.keys())):
                new_class = substitutable_modules[type(layer)]
                config = layer.get_config()
                config["copy_source_weights"] = layer.get_weights()
                if isinstance(layer, tf.keras.layers.LSTM):
                    if isinstance(self._model_without_wrappers, tf.keras.Sequential):
                        config["is_sequential_model"] = True
                    # pylint: disable=protected-access
                    if self._quantsim_configurator._layer_to_config_dict[layer]["is_input_quantized"]["setting"]:
                        config["is_input_quantized"] = True
                    config["quant_scheme"] = quant_scheme
                    config["rounding_mode"] = rounding_mode
                    config["default_output_bw"] = default_output_bw
                    config["default_param_bw"] = default_param_bw
                    config["default_data_type"] = default_data_type
                wrapped_layer = new_class.from_config(config)
                self._substituted_layer[layer] = wrapped_layer
                return wrapped_layer
            if isinstance(layer, tf.keras.Sequential):
                return tf.keras.models.clone_model(layer, clone_function=wrap_layer)
            if isinstance(layer, unquantizable_modules) or layer.submodules:
                return layer
            activation_quant_settings = QuantizerSettings(default_output_bw, default_data_type, rounding_mode,
                                                          quant_scheme, False, False, False)
            param_quant_settings = QuantizerSettings(default_param_bw, default_data_type, rounding_mode,
                                                     quant_scheme, False, False, False)
            input_quantizers, output_quantizers, param_quantizers = self._get_quantizers_by_layer(layer)
            wrapper = QcQuantizeWrapper(layer, activation_quant_settings, param_quant_settings,
                                        num_inputs=len(layer.inbound_nodes[0].keras_inputs),
                                        input_quantizers=input_quantizers,
                                        output_quantizers=output_quantizers,
                                        param_quantizers=param_quantizers,
                                        per_channel_quantization_enabled=self.per_channel_quantization_enabled)
            self._layer_name_to_quant_wrapper[layer.name] = wrapper
            return wrapper
        return tf.keras.models.clone_model(self._model_without_wrappers, clone_function=wrap_layer)
    def _get_quantizers_by_layer(self, layer: tf.keras.layers.Layer) -> Tuple[Optional[ActivationTensorQuantizer],
                                                                              Optional[ActivationTensorQuantizer],
                                                                              Union[ParamPerTensorQuantizer,
                                                                                    ParamPerChannelQuantizer]]:
        """
        Get input/output/param quantizers from quantizers dictionary or initialize quantizers if layer is not found
        :param layer: Target layer
        :return: tuple of input, output, param quantizers
        """
        quantizers_dict = self._quantsim_configurator.get_quantizers_dict(layer)
        if quantizers_dict is None:
            _logger.warning("%s not found in quantizers dict, will generate quantizers automatically", layer.name)
            input_quantizers = None
            output_quantizers = None
            param_quantizers = None
        else:
            input_quantizers = quantizers_dict.get(INPUT_QUANTIZERS)
            output_quantizers = quantizers_dict.get(OUTPUT_QUANTIZERS)
            param_quantizers = quantizers_dict.get(PARAM_QUANTIZERS)
        return input_quantizers, output_quantizers, param_quantizers
    @staticmethod
    def _quantizer_to_name_tuple(quantizers: List[TensorQuantizer]) -> Tuple[Optional[List[str]]]:
        """
        Converts a list of quantizers to a tuple of quantizer names
        :param quantizers: quantizers
        :return: tuple of quantizer names
        """
        quant_list = []
        if not quantizers:
            return None
        for quantizer in quantizers:
            quant_list.append(quantizer.name)
        return tuple(quant_list)
    def get_quantizer_name_by_layer(self, layer: tf.keras.layers.Layer) -> Tuple[Optional[List[str]],
                                                                                 Optional[List[str]],
                                                                                 Optional[List[str]]]:
        """
        Get the names of input, output and param quantizers
        :param layer: the keras layer
        :return: Tuple of quantizer names
        """
        input_quantizers, output_quantizers, param_quantizers = self._get_quantizers_by_layer(layer)
        output_quantizers_names = self._quantizer_to_name_tuple(output_quantizers)
        input_quantizers_names = self._quantizer_to_name_tuple(input_quantizers)
        parameter_quantizers_names = self._quantizer_to_name_tuple(param_quantizers)
        return input_quantizers_names, output_quantizers_names, parameter_quantizers_names
    def _disable_quantizers_in_folded_batchnorm(self):
        """
        Disable input/output/param quantizers if layer is folded batch normalization
        """
        for quantsim_wrapper in self._layer_name_to_quant_wrapper.values():
            if GraphSearchUtils.is_folded_batch_normalization(quantsim_wrapper.original_layer):
                for q in quantsim_wrapper.input_quantizers:
                    q.disable()
                for q in quantsim_wrapper.output_quantizers:
                    q.disable()
                for q in quantsim_wrapper.param_quantizers:
                    q.disable()
    @staticmethod
    def _get_encoding_dict_for_quantizer(quantizer: TensorQuantizer) -> Union[List[Dict[str, Union[str, int, float]]],
                                                                              Dict[str, Union[str, int, float]]]:
        """
        Get encoding dict for a tensor quantizer.
        :param quantizer: Quantizer to get encoding info from
        :return: Dictionary or List of dictionaries containing encodings info for the tensor quantizer
        """
        if not isinstance(quantizer, ParamPerChannelQuantizer) or quantizer.data_type == QuantizationDataType.float:
            quantizer_encodings = [quantizer.encoding]
        else:
            quantizer_encodings = quantizer.encoding
        return [
            {
                'min': encoding.min,
                'max': encoding.max,
                'scale': encoding.delta,
                'offset': int(encoding.offset),
                'bitwidth': encoding.bw,
                'is_symmetric': str(quantizer.is_symmetric),
                'dtype': 'int'
            } if quantizer.data_type == QuantizationDataType.int
            else {'dtype': 'float', 'bitwidth': int(quantizer.bitwidth)}
            for encoding in quantizer_encodings
        ]
    def get_encodings_dict(self) -> Dict[str, Union[str, Dict]]:
        """
        Get encodings dict containing all activation and parameter encodings info in the model
        :return: Dictionary containing all activation and parameter encodings info in the model
        """
        # pylint: disable=protected-access, too-many-branches
        model_input_tensor_names = [inp.name for inp in self.model.inputs]
        activation_encodings = {}
        param_encodings = {}
        for wrapper in self.quant_wrappers():
            for idx, input_quantizer in enumerate(wrapper.input_quantizers):
                if input_quantizer.is_encoding_valid() or input_quantizer.data_type == QuantizationDataType.float:
                    # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their
                    # inbound_nodes parameter populated, so the name of the quantizer is used instead
                    if not wrapper._layer_to_wrap.inbound_nodes:
                        tensor_name = wrapper.name + "/" + input_quantizer.name + ":0"
                    else:
                        tensor_name = wrapper._layer_to_wrap.inbound_nodes[0].keras_inputs[idx].name
                    encoding_dict = self._get_encoding_dict_for_quantizer(input_quantizer)
                    if tensor_name in model_input_tensor_names:
                        tensor_name += ":0"
                    activation_encodings[tensor_name] = encoding_dict
            for idx, param_quantizer in enumerate(wrapper.param_quantizers):
                if param_quantizer.is_encoding_valid() or param_quantizer.data_type == QuantizationDataType.float:
                    param_name = wrapper._layer_to_wrap.weights[idx].name
                    encoding_dict = self._get_encoding_dict_for_quantizer(param_quantizer)
                    param_encodings[param_name] = encoding_dict
            for idx, output_quantizer in enumerate(wrapper.output_quantizers):
                if output_quantizer.is_encoding_valid() or output_quantizer.data_type == QuantizationDataType.float:
                    # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their
                    # inbound_nodes parameter populated, so the name of the quantizer is used instead
                    if not wrapper._layer_to_wrap.inbound_nodes:
                        tensor_name = wrapper.name + ":0"
                    elif isinstance(wrapper._layer_to_wrap.output, List):
                        tensor_name = wrapper._layer_to_wrap.output[idx].name
                    else:
                        tensor_name = wrapper._layer_to_wrap.output.name
                    encoding_dict = self._get_encoding_dict_for_quantizer(output_quantizer)
                    activation_encodings[tensor_name] = encoding_dict
        return {
            'version': quantsim.encoding_version,
            'activation_encodings': activation_encodings,
            'param_encodings': param_encodings,
            'quantizer_args': self.quant_args if hasattr(self, "quant_args") else {}
        }
[docs]    def compute_encodings(self, forward_pass_callback, forward_pass_callback_args):
        """
        Computes encodings for all quantization sim nodes in the model.
        :param forward_pass_callback: A callback function that is expected to runs forward passes on a model.
               This callback function should use representative data for the forward pass, so the calculated
               encodings work for all data samples.
        :param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. Up to
               the user to determine the type of this parameter. E.g. could be simply an integer representing the number
               of data samples to use. Or could be a tuple of parameters or an object representing something more
               complex.
        """
        ops_with_invalid_encodings = []
        self._compute_and_set_parameter_encodings(ops_with_invalid_encodings)
        self._set_op_mode_parameters(libpymo.TensorQuantizerOpMode.quantizeDequantize)
        forward_pass_callback(self.model, forward_pass_callback_args)
        for quant_wrapper in self.quant_wrappers():
            quant_wrapper.compute_encoding(ops_with_invalid_encodings)
        op_mode = self._param_op_mode_after_analysis(self.quant_scheme)
        self._set_op_mode_parameters(op_mode)
        if ops_with_invalid_encodings:
            _logger.info('The following quantizers did not have valid encodings and have been set to passThrough mode: '
                         '%s', ops_with_invalid_encodings)
            _logger.info('This can be due to the quantizers not having been evaluated during the forward pass in '
                         'compute encodings. Evaluation is required to collect statistics needed to compute valid '
                         'encodings.\n'
                         'As a result, the quantizers have been set to passThrough mode, meaning no quantization noise '
                         'will be simulated for these ops if they are evaluated in the future.\n'
                         'If this is not desired, amend the forward pass to evaluate tensors which require these ops '
                         'to be evaluated, and recompute encodings.') 
    def _set_op_mode_parameters(self, op_mode: libpymo.TensorQuantizerOpMode):
        """
        Sets quant mode for parameters and if the encodings are invalid, then adds those wrappers
        to wrappers_with_invalid_encodings
        :param op_mode: Quant mode to set to
        """
        for quantizer_info in self.quant_wrappers():
            for param_quantizer in quantizer_info.param_quantizers:
                if param_quantizer.is_enabled():
                    param_quantizer.quant_mode = op_mode
    @staticmethod
    @contextlib.contextmanager
    def _set_encoding_version_to_0_6_1():
        assert quantsim.encoding_version in {'0.6.1', '1.0.0'}
        if quantsim.encoding_version == '1.0.0':
            _logger.info('Exporting to encoding version 1.0.0 is not yet supported. Exporting using version 0.6.1 '
                         'instead.')
        old_encoding_version = quantsim.encoding_version
        quantsim.encoding_version = '0.6.1'
        yield
        quantsim.encoding_version = old_encoding_version
[docs]    def export(self, path, filename_prefix, custom_objects=None, convert_to_pb=True):
        """
        This method exports out the quant-sim model so it is ready to be run on-target.
        Specifically, the following are saved
        1. The sim-model is exported to a regular Keras model without any simulation ops
        2. The quantization encodings are exported to a separate JSON-formatted file that can
           then be imported by the on-target runtime (if desired)
        :param path: path where to store model pth and encodings
        :param filename_prefix: Prefix to use for filenames of the model pth and encodings files
        :param custom_objects: If there are custom objects to load, Keras needs a dict of them to map them
        """
        with self._set_encoding_version_to_0_6_1():
            model_path = os.path.join(path, filename_prefix)
            #TF Version 2.4 has bug i.e. save() in tf format don't work for unrolled LSTM.
            for layer in self._model_without_wrappers.layers:
                if isinstance(layer, tf.keras.layers.LSTM):
                    break
            else:
                self._model_without_wrappers.save(model_path)
            self._model_without_wrappers.save(model_path + '.h5', save_format='h5')
            # Conversion of saved h5 model to pb model for consumption by SNPE/QNN
            try:
                if convert_to_pb:
                    convert_h5_model_to_pb_model(f'{model_path}.h5', custom_objects=custom_objects)
            except ValueError:
                _logger.error("Could not convert h5 to frozen pb. "
                              "Please call export() again with custom_objects defined.")
                raise
            finally:
                encodings_dict = self.get_encodings_dict()
                encoding_file_path = os.path.join(path, filename_prefix + '.encodings')
                save_json_yaml(encoding_file_path, encodings_dict) 
    def _compute_and_set_parameter_encodings(self, ops_with_invalid_encodings: List):
        # pylint: disable=too-many-nested-blocks
        for quantizer_wrapper in self.quant_wrappers():
            for idx, param_quantizer in enumerate(quantizer_wrapper.param_quantizers):
                if param_quantizer.is_enabled() and param_quantizer.data_type == QuantizationDataType.int:
                    # 0th input to our quant wrapper is the tensor being quantized
                    weight_tensor = quantizer_wrapper.original_layer.get_weights()[idx]
                    # Per-channel
                    if isinstance(param_quantizer, StaticGridPerChannelQuantizer):
                        for index, tensor_quantizer in enumerate(param_quantizer.tensor_quantizer):
                            if param_quantizer.axis_handling == AxisHandling.LAST_TWO_AXES.value:
                                last_two_axes_combined_shape = list(weight_tensor.shape[:-2]) + [-1]
                                channel_slice = weight_tensor.reshape(*last_two_axes_combined_shape)
                                channel_slice = channel_slice.take(index, channel_slice.ndim - 1)
                            elif isinstance(quantizer_wrapper.original_layer, tf.keras.layers.Conv2DTranspose):
                                if weight_tensor.ndim == 4:
                                    channel_slice = weight_tensor.take(index, weight_tensor.ndim - 2)
                                else:
                                    # For bias in Transpose layers
                                    channel_slice = weight_tensor.take(index, weight_tensor.ndim - 1)
                            else:
                                channel_slice = weight_tensor.take(index, weight_tensor.ndim - 1)
                            tensor_quantizer.updateStats(channel_slice, False)
                    # Per-tensor
                    else:
                        tensor_quantizer = param_quantizer.tensor_quantizer
                        tensor_quantizer.updateStats(weight_tensor, False)
                    param_quantizer.compute_encoding(ops_with_invalid_encodings)
    def set_and_freeze_param_encodings(self, encoding_path: str):
        """
        Set and freeze parameter encodings from encodings JSON file
        :param encoding_path: path from where to load parameter encodings file
        """
        # Load parameter encodings file
        with open(encoding_path) as json_file:
            param_encodings = json.load(json_file)
        for quant_wrapper in self.quant_wrappers():
            quant_wrapper.set_and_freeze_param_encoding(param_encodings)
    # pylint: disable=too-many-nested-blocks
    def load_encodings_to_sim(self, encoding_file_path: str):
        """
        Loads the saved encodings to quant sim model
        :param encoding_file_path: path from where to load encodings file
        :return:
        """
        # pylint: disable=protected-access, too-many-branches, too-many-locals, too-many-statements
        # Load encodings file
        with open(encoding_file_path) as json_file:
            encodings = json.load(json_file)
        param_encodings = encodings['param_encodings']
        activation_encodings = encodings['activation_encodings']
        model_input_tensor_names = [inp.name for inp in self.model.inputs]
        for wrapper in self.quant_wrappers():
            for idx, input_quantizer in enumerate(wrapper.input_quantizers):
                # because dense layers in quantizable MHA and RNN are not explicitly sublayers, they don't have their
                # inbound_nodes parameter populated, so the name of the quantizer is used instead
                if not wrapper._layer_to_wrap.inbound_nodes:
                    tensor_name = wrapper.name + "/" + input_quantizer.name + ":0"
                else:
                    tensor_name = wrapper._layer_to_wrap.inbound_nodes[0].keras_inputs[idx].name
                if tensor_name in model_input_tensor_names:
                    tensor_name += ":0"
                if tensor_name in activation_encodings:
                    if not input_quantizer.is_enabled():
                        _logger.info("Not loading encodings for quantizer: %s as it is disabled", tensor_name)
                        continue
                    encoding_dict = activation_encodings[tensor_name][0]
                    if encoding_dict['dtype'] == 'int':
                        encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(encoding_dict)
                        input_quantizer.tensor_quantizer.isEncodingValid = True
                        input_quantizer.set_quantizer_encodings(encoding.bw, is_symmetric, encoding,
                                                                libpymo.TensorQuantizerOpMode.quantizeDequantize)
                        _logger.info("Setting encodings for : %s", tensor_name)
                    elif encoding_dict['dtype'] == 'float':
                        input_quantizer.data_type = QuantizationDataType.float
                        input_quantizer.bitwidth = encoding_dict['bitwidth']
                        _logger.info("Setting quantizer dtype to float for : %s", tensor_name)
                    else:
                        raise RuntimeError("Unrecognized dtype %s for: %s" % (encoding_dict['dtype'], tensor_name))
                else:
                    if input_quantizer.is_enabled():
                        input_quantizer.disable()
                        _logger.info("Encoding for quantizer: %s is not present thus disabling it.", tensor_name)
            for idx, param_quantizer in enumerate(wrapper.param_quantizers):
                param_name = wrapper._layer_to_wrap.weights[idx].name
                if param_name in param_encodings:
                    if not param_quantizer.is_enabled():
                        _logger.info("Not loading encodings for parameter: %s as quantizer is disabled", param_name)
                        continue
                    if isinstance(param_quantizer, StaticGridPerChannelQuantizer):
                        if param_encodings[param_name][0]['dtype'] == 'float':
                            wrapper.param_quantizers[idx] = ParamPerTensorQuantizer(layer=param_quantizer._original_layer,
                                                                                    name=param_quantizer.name,
                                                                                    quant_scheme=param_quantizer._quant_scheme,
                                                                                    round_mode='nearest',
                                                                                    bitwidth=param_encodings[param_name][0]['bitwidth'],
                                                                                    data_type=QuantizationDataType.float,
                                                                                    is_symmetric=param_quantizer.is_symmetric,
                                                                                    use_strict_symmetric=param_quantizer.use_strict_symmetric,
                                                                                    use_unsigned_symmetric=param_quantizer.use_unsigned_symmetric,
                                                                                    enabled=False)
                        else:
                            encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(
                                param_encodings[param_name])
                            for tensor_quantizer in param_quantizer.tensor_quantizer:
                                tensor_quantizer.isEncodingValid = True
                            bw = encoding[0].bw
                            param_quantizer.set_quantizer_encodings(bw, is_symmetric, encoding,
                                                                    libpymo.TensorQuantizerOpMode.oneShotQuantizeDequantize)
                        _logger.info("Setting encodings for : %s", param_name)
                    else:
                        encoding_dict = param_encodings[param_name][0]
                        if encoding_dict['dtype'] == 'int':
                            encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(encoding_dict)
                            param_quantizer.tensor_quantizer.isEncodingValid = True
                            bw = encoding.bw
                            param_quantizer.set_quantizer_encodings(bw, is_symmetric, encoding,
                                                                    libpymo.TensorQuantizerOpMode.oneShotQuantizeDequantize)
                            _logger.info("Setting encodings for : %s", param_name)
                        elif encoding_dict['dtype'] == 'float':
                            param_quantizer.data_type = QuantizationDataType.float
                            param_quantizer.bitwidth = encoding_dict['bitwidth']
                            _logger.info("Setting quantizer to float for : %s", param_name)
                        else:
                            raise RuntimeError("Unrecognized dtype %s for: %s" % (encoding_dict['dtype'], tensor_name))
                else:
                    if param_quantizer.is_enabled():
                        param_quantizer.disable()
                        _logger.info("Encoding for parameter: %s not present thus disabling this quantizer.",
                                     param_name)
            # Loading encodings means that compute encodings was called. Therefore, these two lines set the correct
            # op mode for the correct quant scheme and if the quantization was per channel or not.
            op_mode = self._param_op_mode_after_analysis(self.quant_scheme)
            self._set_op_mode_parameters(op_mode)
            for idx, output_quantizer in enumerate(wrapper.output_quantizers):
                # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their
                # inbound_nodes parameter populated, so the name of the quantizer is used instead
                if not wrapper._layer_to_wrap.inbound_nodes:
                    tensor_names = [wrapper.name + ":0"]
                else:
                    # There can be multiple outputs if there is a
                    # `tf.split` in the model.
                    if isinstance(wrapper._layer_to_wrap.output, list):
                        tensor_names = [
                            output.name
                            for output in wrapper._layer_to_wrap.output
                        ]
                    else:
                        tensor_names = [wrapper._layer_to_wrap.output.name]
                for tensor_name in tensor_names:
                    if tensor_name in activation_encodings:
                        if not output_quantizer.is_enabled():
                            _logger.info("Not loading encodings for quantizer: %s as it is disabled", tensor_name)
                            continue
                        encoding_dict = activation_encodings[tensor_name][0]
                        if encoding_dict['dtype'] == 'int':
                            encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(encoding_dict)
                            output_quantizer.tensor_quantizer.isEncodingValid = True
                            output_quantizer.set_quantizer_encodings(encoding.bw, is_symmetric, encoding,
                                                                     libpymo.TensorQuantizerOpMode.quantizeDequantize)
                            _logger.info("Setting encodings for : %s", tensor_name)
                        elif encoding_dict['dtype'] == 'float':
                            output_quantizer.data_type = QuantizationDataType.float
                            output_quantizer.bitwidth = encoding_dict['bitwidth']
                            _logger.info("Setting quantizer dtype to float for : %s", tensor_name)
                        else:
                            raise RuntimeError("Unrecognized dtype %s for: %s" % (encoding_dict['dtype'], tensor_name))
                    else:
                        if output_quantizer.is_enabled():
                            output_quantizer.disable()
                            _logger.info("Encoding for quantizer: %s is not present thus disabling it.", tensor_name)
    def _param_op_mode_after_analysis(self, quant_scheme) -> libpymo.TensorQuantizerOpMode:
        """
        Returns quant mode to use for parameters after encodings have been computed
        :param quant_scheme: Quantization scheme to use
        :return: Quant mode to use
        """
        if quant_scheme in [QuantScheme.training_range_learning_with_tf_init,
                            QuantScheme.training_range_learning_with_tf_enhanced_init] \
                
or self.per_channel_quantization_enabled:
            return libpymo.TensorQuantizerOpMode.quantizeDequantize
        return libpymo.TensorQuantizerOpMode.oneShotQuantizeDequantize
    def quant_wrappers(self):
        """
        Generator for yielding all quantization wrappers
        """
        for layer in self.model.layers:
            if isinstance(layer, QcQuantizeWrapper):
                yield layer
            if isinstance(layer, tuple(substitutable_modules.values())):
                yield from layer.quant_wrappers()
            # For Getting Quantizers from Sequantial Block
            if isinstance(layer, tf.keras.Sequential):
                yield from quant_wrappers_for_sequential_block(layer)
    def get_quant_wrapper_for_layer_name(self, layer_name: str) -> QcQuantizeWrapper:
        """
        Return qc quant wrapper corresponding to a layer name
        :param layer_name: Layer name to get quantize wrapper for
        :return: Qc quant wrapper corresponding to a layer name
        """
        return self._layer_name_to_quant_wrapper.get(layer_name)
    # pylint: disable=too-many-locals
    def _fill_missing_encoding_min_max_gradients(self, gradients: list):
        """
        Computes the encoding min/max gradients and populates the gradients list
        :param gradients: gradients computed using GradientTape(gradients for encoding min/max will be `None`)
        """
        def _find_weight_in_layer(weight_name: str, model_layer: tf.keras.layers.Layer):
            for weight in model_layer.weights:
                if weight.name.split(":")[0] == weight_name:
                    return weight
            return None
        # Mapping used to get the gradients of weights(kernel, bias etc)
        weight_name_to_gradient = dict(zip([weight.name.split(":")[0] for weight in self.model.trainable_weights],
                                           gradients))
        # Mapping used to get index of encoding min/max gradients (which would be `None`) and fill them
        weight_name_to_index = dict(zip([weight.name for weight in self.model.trainable_weights],
                                        range(len(self.model.trainable_weights))))
        # Only process layers where 'param_quantizers' is defined (i.e. QcQuantizeWrapper layers)
        for layer in filter(lambda _layer: hasattr(_layer, 'param_quantizers'), self.model.layers):
            for param_quantizer in layer.param_quantizers:
                if param_quantizer.name in weight_name_to_gradient:
                    # Value of weight associated with this param quantizer
                    weight_tensor = _find_weight_in_layer(param_quantizer.name, layer.original_layer)
                    # Gradients of the weights
                    grad = weight_name_to_gradient[param_quantizer.name]
                    # Using the weights and it's gradients, compute gradients for encoding min/max
                    dloss_by_dmin, dloss_by_dmax = param_quantizer.get_gradients_for_encoding_min_max(weight_tensor,
                                                                                                      grad)
                    enc_min_index = weight_name_to_index[param_quantizer.encoding_min.name]
                    enc_max_index = weight_name_to_index[param_quantizer.encoding_max.name]
                    gradients[enc_min_index] = dloss_by_dmin
                    gradients[enc_max_index] = dloss_by_dmax
        # TODO: Remove this logic once this has been resolved in QNN/SNPE
        # Go through activation quantizers (Input/Output) and set any ReLU's encoding min to 0
        relu_quantize_wrappers = [
            _layer for _layer in self.model.layers
            if isinstance(_layer, QcQuantizeWrapper) and isinstance(_layer.original_layer, tf.keras.layers.ReLU)
        ]
        def _set_encoding_min_grad_to_None(quantizer):
            enc_min_index = weight_name_to_index[quantizer.encoding_min.name]
            gradients[enc_min_index] = None
        for relu_quantizer in relu_quantize_wrappers:
            for output_quantizer in relu_quantizer.output_quantizers:
                _set_encoding_min_grad_to_None(output_quantizer)
    # pylint: disable=useless-super-delegation
    def get_config(self):
        return super().get_config()
    def call(self, inputs, training=None, mask=None):
        return self.model.call(inputs, training, mask)
    def train_step(self, data):
        """
        Custom training loop, equivalent to overriding `keras.Model.fit` function
        Reference: https://keras.io/guides/customizing_what_happens_in_fit/
        Only relevant when using range-learning, otherwise equivalent to `keras.Model.fit`
        Param quantizers are disconnected in the op graph of the wrapped model
        Because of this, the gradients are not computed for encoding min/max(when range learning is enabled)
        This custom train_step function computes the missing gradients for encoding min/max of param quantizers
        """
        x, y = data
        with tf.GradientTape() as tape:
            predictions = self(x, training=True)
            loss = self.compiled_loss(y, predictions)
        gradients = tape.gradient(loss, self.model.trainable_weights)
        # Manually compute missing gradients for encoding min/max when using range learning
        if self.quant_scheme in [QuantScheme.training_range_learning_with_tf_init,
                                 QuantScheme.training_range_learning_with_tf_enhanced_init]:
            self._fill_missing_encoding_min_max_gradients(gradients)
        gradients_to_apply = [(gradient, weight) for gradient, weight in zip(gradients, self.model.trainable_weights)
                              if gradient is not None]
        self.optimizer.apply_gradients(gradients_to_apply)
        self.compiled_metrics.update_state(y, predictions)
        return {m.name: m.result() for m in self.metrics} 
def quant_wrappers_for_sequential_block(seq_block: tf.keras.Sequential):
    """
        Generator for yielding all quantization wrappers for a Sequantial Block
    """
    for layer in seq_block.layers:
        if isinstance(layer, QcQuantizeWrapper):
            yield layer
        if isinstance(layer, tuple(substitutable_modules.values())):
            yield from layer.quant_wrappers()
        # in cases of nested Sequential Block
        if isinstance(layer, tf.keras.Sequential):
            yield from quant_wrappers_for_sequential_block(layer)