# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: skip-file
""" Top level API for Adaptive Rounding - Post-Training Quantization (PTQ) """
import copy
import os
import tempfile
import json
from typing import Tuple, Dict, List, Callable
from onnx import onnx_pb
from onnxruntime.quantization.onnx_quantizer import ONNXModel
from tqdm import tqdm
# Import AIMET specific modules
from aimet_common import quantsim
from aimet_common.utils import AimetLogger
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_onnx.adaround.adaround_loss import AdaroundHyperParameters
from aimet_onnx.adaround.adaround_tensor_quantizer import AdaroundTensorQuantizer
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_onnx.qc_quantize_op import OpMode
from aimet_onnx.meta.utils import get_module_act_func_pair, get_ordered_ops
from aimet_onnx import utils
from aimet_onnx.adaround.adaround_optimizer import AdaroundOptimizer
from aimet_onnx.adaround.utils import ModelData, ModuleInfo
logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
# ONNX op types whose weights can be optimized by Adaround. Conv/ConvTranspose are additionally
# restricted to 4-D (2d) weight tensors by Adaround._is_supported_layer_type.
AdaroundSupportedModules = ['Conv', 'ConvTranspose', 'MatMul', 'Gemm']

class AdaroundParameters:
    """
    Configuration parameters for Adaround
    """
    def __init__(self, data_loader, num_batches: int,
                 default_num_iterations: int = None, default_reg_param: float = 0.01,
                 default_beta_range: Tuple = (20, 2), default_warm_start: float = 0.2,
                 forward_fn: Callable = None, forward_pass_callback_args=None):
        """
        :param data_loader: Data loader. Must support len(); it is validated against num_batches.
        :param num_batches: Number of batches to be used for Adaround.
         A commonly recommended value for this parameter is the smaller value among (1) len(data_loader) and (2) ceil(2000/batch_size)
        :param default_num_iterations: Number of iterations to adaround each layer.
         If None (the default), it is chosen when Adaround runs: 15K if any enabled integer parameter quantizer
         has a bitwidth below 8, otherwise 10K.
        :param default_reg_param: Regularization parameter, trading off between rounding loss vs reconstruction loss.
         Default 0.01
        :param default_beta_range: Start and stop beta parameter for annealing of rounding loss (start_beta, end_beta).
         Default (20, 2)
        :param default_warm_start: warm up period, during which rounding loss has zero effect. Default 20% (0.2)
        :param forward_fn: Callback used to compute the initial parameter encodings on the QuantSim session.
            Adaround calls it as forward_fn(session, forward_pass_callback_args), so it must be provided
            before calling Adaround.apply_adaround.
        :param forward_pass_callback_args: Passed as-is as the second argument of forward_fn. Up to
            the user to determine the type of this parameter. E.g. could be simply an integer representing the number
            of data samples to use. Or could be a tuple of parameters or an object representing something more complex.
            If left as None, forward_fn is still called with two arguments, the second being None.
        :raises ValueError: If data_loader holds fewer than num_batches batches.
        """
        if len(data_loader) < num_batches:
            raise ValueError(f'Can not fetch {num_batches} batches from '
                             f'a data loader of length {len(data_loader)}.')
        self.data_loader = data_loader
        self.num_batches = num_batches
        # None means "pick automatically" when Adaround runs
        self.num_iterations = default_num_iterations
        self.reg_param = default_reg_param
        self.beta_range = default_beta_range
        self.warm_start = default_warm_start
        self.forward_fn = forward_fn
        self.forward_pass_callback_args = forward_pass_callback_args

class Adaround:
    """
    Weight-rounding mechanism for Post Training Quantization (PTQ)
    """
    @classmethod
    def apply_adaround(cls, model: onnx_pb.ModelProto, params: AdaroundParameters,
                       path: str, filename_prefix: str, default_param_bw: int = 4,
                       param_bw_override_list: List[Tuple[str, int]] = None,
                       ignore_quant_ops_list: List[str] = None,
                       default_quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced,
                       default_config_file: str = None, use_cuda: bool = True, device: int = 0,
                       user_onnx_libs: List[str] = None) -> onnx_pb.ModelProto:
        """
        Returns model with optimized weight rounding of every supported module (see AdaroundSupportedModules) and
        also saves the corresponding quantization encodings to a separate JSON-formatted file that can then be
        imported by QuantSim for inference or QAT

        :param model: Model to Adaround. A ModelProto is wrapped in an ONNXModel; an ONNXModel is used as-is.
        :param params: Parameters for Adaround. params.forward_fn must be set; it is used to compute the initial
                       parameter encodings.
        :param path: path where to store parameter encodings
        :param filename_prefix: Prefix to use for filename of the encodings file (written as <prefix>.encodings)
        :param default_param_bw: Default bitwidth (4-31) to use for quantizing layer parameters
        :param param_bw_override_list: List of Tuples. Each Tuple is a param name and the corresponding parameter
                                       bitwidth to be used for that param.
        :param ignore_quant_ops_list: Ops listed here are meant to be skipped during quantization needed for
                                      AdaRounding. Do not specify Conv and Linear modules in this list. Doing so will
                                      affect accuracy. NOTE: not yet honored for ONNX; a warning is logged instead.
        :param default_quant_scheme: Quantization scheme. Supported options are using Quant Scheme Enum
                                     QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced
        :param default_config_file: Default configuration file for model quantizers
        :param use_cuda: If we should use cuda
        :param device: CUDA device ID
        :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
        :return: Model with Adarounded weights and saves corresponding parameter encodings JSON file at provided path
        """
        # pylint: disable=too-many-arguments
        if not isinstance(model, ONNXModel):
            model = ONNXModel(model)
        # QuantSim is built on a deep copy; the un-quantized `model` is kept as the fp32 reference
        # that _apply_adaround / _adaround_model compare against.
        quant_sim = QuantizationSimModel(copy.deepcopy(model), quant_scheme=default_quant_scheme,
                                         default_param_bw=default_param_bw,
                                         config_file=default_config_file,
                                         user_onnx_libs=user_onnx_libs,
                                         use_cuda=use_cuda)
        # For the params in the param_bw_override_list, override the default parameter bitwidths in the QuantSim.
        # This must happen before the param encodings are computed below.
        if param_bw_override_list:
            cls._override_param_bitwidth(quant_sim, param_bw_override_list)
        if ignore_quant_ops_list:
            cls._exclude_modules(quant_sim, ignore_quant_ops_list)
        # Compute only param encodings (activation quantizers get disabled here)
        cls._compute_param_encodings(quant_sim, params)
        return cls._apply_adaround(quant_sim, model, params, path, filename_prefix, use_cuda, device, user_onnx_libs)

    @classmethod
    def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: onnx_pb.ModelProto, params: AdaroundParameters,
                        path: str, filename_prefix: str, use_cuda: bool = True, device: int = 0,
                        user_onnx_libs: List[str] = None) -> onnx_pb.ModelProto:
        """
        Returns model with optimized weight rounding of every supported module and also saves the
        corresponding quantization encodings to a separate JSON-formatted file that can then be imported by
        QuantSim for inference or QAT

        :param quant_sim: QuantizationSimModel object to optimize weight rounding.
                          The activation quantizers are expected to have been disabled.
        :param model: Original fp32 model (ONNXModel wrapper) from which quant_sim was created.
        :param params: Parameters for Adaround
        :param path: path where to store parameter encodings
        :param filename_prefix: Prefix to use for filename of the encodings file
        :param use_cuda: If we should use cuda
        :param device: CUDA device ID
        :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
        :return: quant_sim.model with Adarounded weights and the quantization nodes removed; the parameter
                 encodings JSON file is saved at the provided path
        """
        # Sanity check: All the input/output quantizers should be disabled
        for quantizer_name in quant_sim.activation_names:
            assert not quant_sim.qc_quantize_op_dict[quantizer_name].enabled, \
                f'Activation quantizer {quantizer_name} is expected to be disabled'
        # Map each module to the activation function that immediately follows it
        module_act_func_pair = get_module_act_func_pair(model)
        cls._adaround_model(model, quant_sim, module_act_func_pair, params, use_cuda, device, user_onnx_libs)
        # Export quantization encodings to JSON-formatted file
        cls._export_encodings_to_json(path, filename_prefix, quant_sim)
        quant_sim.remove_quantization_nodes()
        logger.info('Completed Adarounding Model')
        return quant_sim.model

    @classmethod
    def _adaround_model(cls, model: onnx_pb.ModelProto, quant_sim: QuantizationSimModel, module_act_func_pair: Dict,
                        params: AdaroundParameters, use_cuda: bool = True, device: int = 0, user_onnx_libs: List[str] = None):
        """
        Optimize weight rounding of every module (AdaroundSupportedModules) of model in sequential manner
        based on occurrence

        :param model: Original fp32 model (ONNXModel wrapper; its .model proto is read) from which quant_sim was created.
        :param quant_sim: QuantizationSimModel object to optimize weight rounding.
                          The activation quantizers are expected to have been disabled.
        :param module_act_func_pair: Dictionary of module to immediate following activation function
        :param params: Adaround parameters
        :param use_cuda: If we should use cuda
        :param device: CUDA device ID
        :param user_onnx_libs: List of paths to all compiled ONNX custom ops libraries
        """
        # pylint: disable=too-many-locals, protected-access
        num_iterations = params.num_iterations
        if num_iterations is None:
            lowest_weight_bw = 32
            for param_name in quant_sim.param_names:
                quantizer = quant_sim.qc_quantize_op_dict[param_name]
                if quantizer.enabled and quantizer.data_type == QuantizationDataType.int:
                    lowest_weight_bw = min(lowest_weight_bw, quantizer.bitwidth)
            # If the lowest weight bitwidth is < 8, then set num_iterations to 15K by default
            num_iterations = 15000 if lowest_weight_bw < 8 else 10000
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Cache model input data to temporary directory
            cached_dataset = utils.CachedDataset(params.data_loader, params.num_batches, tmp_dir)
            # Optimization Hyper parameters
            opt_params = AdaroundHyperParameters(num_iterations, params.reg_param, params.beta_range,
                                                 params.warm_start)
            param_to_tensor_quantizer_dict = cls._create_param_to_tensor_quantizer_dict(quant_sim)
            model_data = ModelData(model.model)
            quantized_layer_to_input_tensor_name = cls._get_quantized_layer_input_tensor_name(quant_sim)
            # AdaRound must be applied to modules in the order of occurrence
            modules = get_ordered_ops(model)
            for module in tqdm(modules):
                name = module.name
                if not cls._is_supported_layer_type(model_data.module_to_info[name]):
                    continue
                # Get module's next following activation function
                act_func = module_act_func_pair[name]
                quantized_input_name = quantized_layer_to_input_tensor_name[name]
                logger.info("Started Optimizing weight rounding of module: %s", name)
                AdaroundOptimizer.adaround_module(model_data.module_to_info[name], quantized_input_name,
                                                  model, quant_sim.model, act_func,
                                                  cached_dataset, opt_params, param_to_tensor_quantizer_dict,
                                                  use_cuda, device, user_onnx_libs)

    @classmethod
    def _is_supported_layer_type(cls, module_info: ModuleInfo) -> bool:
        """
        Check whether Adaround can optimize the given module.

        :param module_info: Module info of the op to check
        :return: True if the op type is in AdaroundSupportedModules and it has a "weight" param; Conv and
                 ConvTranspose additionally need a 4-D (2d conv) weight.
        """
        if module_info.type not in AdaroundSupportedModules or "weight" not in module_info.params:
            return False
        if module_info.type in ("Conv", "ConvTranspose"):
            # Only 2d conv/convtranspose is supported
            return len(module_info.params["weight"].shape) == 4
        return True

    @staticmethod
    def _compute_param_encodings(quant_sim: QuantizationSimModel, params: AdaroundParameters):
        """
        Compute encodings for parameters, needed for initializing Adaround quantizers.
        Activation quantizers are disabled as a side effect.

        :param quant_sim: Quant sim
        :param params: Adaround params; params.forward_fn is called as
                       forward_fn(quant_sim.session, params.forward_pass_callback_args)
        :raises TypeError: If params.forward_fn is not callable (e.g. left at its default of None).
        """
        # Fail before touching any quantizer state, with a clearer message than "'NoneType' object is not callable"
        if not callable(params.forward_fn):
            raise TypeError('AdaroundParameters.forward_fn must be a callable taking '
                            '(session, forward_pass_callback_args) to compute param encodings; '
                            f'got {params.forward_fn!r}')
        # Hoisted into sets: membership is tested once per quantizer below
        activation_names = set(quant_sim.activation_names)
        param_names = set(quant_sim.param_names)
        for op_name, qc_op in quant_sim.qc_quantize_op_dict.items():
            if op_name in activation_names:
                qc_op.enabled = False
            else:
                qc_op.op_mode = OpMode.oneShotQuantizeDequantize
        params.forward_fn(quant_sim.session, params.forward_pass_callback_args)
        for op_name, qc_op in quant_sim.qc_quantize_op_dict.items():
            if op_name in param_names:
                qc_op.compute_encodings()
                qc_op.op_mode = OpMode.quantizeDequantize

    @staticmethod
    def _create_param_to_tensor_quantizer_dict(quant_sim: QuantizationSimModel) -> Dict[str, AdaroundTensorQuantizer]:
        """
        Create Adaround tensor quantizers for weight tensors, mirroring the settings and encodings of the
        corresponding QuantSim param quantizers.

        :param quant_sim: Quant sim
        :return: Dict of param name to AdaroundTensorQuantizer
        """
        param_to_tq_dict = {}
        for param_name in quant_sim.param_names:
            quantizer = quant_sim.qc_quantize_op_dict[param_name]
            # -1 signals per-tensor quantization; otherwise use the configured per-channel axis
            ch_axis = quantizer.quant_info.channelAxis if quantizer.quant_info.usePerChannelMode else -1
            adaround_quantizer = AdaroundTensorQuantizer(quantizer.bitwidth, 'Adaptive', quantizer.quant_scheme,
                                                         quantizer.use_symmetric_encodings, quantizer.enabled, ch_axis)
            adaround_quantizer.use_strict_symmetric = quantizer.use_strict_symmetric
            adaround_quantizer.use_unsigned_symmetric = quantizer.use_unsigned_symmetric
            # Set the encodings and replace by Adaround tensor quantizer
            adaround_quantizer.encoding = quantizer.encodings
            param_to_tq_dict[param_name] = adaround_quantizer
        return param_to_tq_dict

    @classmethod
    def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: QuantizationSimModel):
        """
        Save Adarounded module's parameter encodings to JSON file <path>/<filename_prefix>.encodings

        :param path: path where to store param encodings (created if missing)
        :param filename_prefix: filename to store exported weight encodings in JSON format
        :param quant_sim: QuantSim that contains the model and Adaround tensor quantizers
        """
        # pylint: disable=protected-access
        param_encodings = quant_sim._get_encodings(quant_sim.param_names, quantsim.encoding_version)
        # export encodings to JSON file
        os.makedirs(os.path.abspath(path), exist_ok=True)
        encoding_file_path = os.path.join(path, filename_prefix + '.encodings')
        with open(encoding_file_path, 'w', encoding='utf-8') as encoding_fp:
            json.dump(param_encodings, encoding_fp, sort_keys=True, indent=4)

    @staticmethod
    def _override_param_bitwidth(quant_sim: QuantizationSimModel,
                                 param_bw_override_list: List[Tuple[str, int]]):
        """
        For the QuantSim, for the list of params in the param_bw_override_list,
        overrides the default parameter bitwidths with the provided bitwidth.

        :param quant_sim: The QuantSim that was created using a deepcopy of the original model.
        :param param_bw_override_list: List of Tuples. Each Tuple is a param name (a key of
                                       quant_sim.qc_quantize_op_dict) and the corresponding parameter bitwidth
                                       to be used for that param.
        """
        # Set the bitwidth on the existing quantizer; assigning to the dict entry itself would replace the
        # quantizer object with a plain int and break every later access to it.
        for (param_name, bw) in param_bw_override_list:
            quant_sim.qc_quantize_op_dict[param_name].bitwidth = bw

    @classmethod
    def _exclude_modules(cls, quant_sim: QuantizationSimModel,
                         ignore_quant_ops_list: List[str]):
        """
        Intended to remove the quantizers of the ops in ignore_quant_ops_list from the QuantSim and to exclude
        those ops from Adaround optimization. This is not implemented for ONNX yet: the QuantSim is left
        unchanged and a warning is logged so the request is not silently dropped.

        :param quant_sim: The QuantSim that was created using a deepcopy of the original model.
        :param ignore_quant_ops_list: The list of ops whose quantizers should be excluded.
        """
        logger.warning('ignore_quant_ops_list is not yet supported by ONNX Adaround; '
                       'these ops will NOT be excluded: %s', ignore_quant_ops_list)

    @staticmethod
    def _get_quantized_layer_input_tensor_name(sim):
        """
        Map each Adaround-supported node of the QuantSim graph to the name of its first input tensor.

        :param sim: QuantizationSimModel whose graph (sim.model.model.graph) is scanned
        :return: Dict of node name to the name of node.input[0]
        """
        return {node.name: node.input[0]
                for node in sim.model.model.graph.node
                if node.op_type in AdaroundSupportedModules}