Source code for aimet_tensorflow.keras.quantsim

# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2022-2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
""" Quantsim for Keras """

from dataclasses import dataclass
import json
import os
from typing import Union, Dict, Tuple, Optional, List

import tensorflow as tf
from aimet_common import libpymo

from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.utils import AimetLogger, save_json_yaml
from aimet_common.quantsim import encoding_version, extract_global_quantizer_args
from aimet_tensorflow.keras.connectedgraph import ConnectedGraph
from aimet_tensorflow.keras.graphsearchtuils import GraphSearchUtils
from aimet_tensorflow.keras.quant_sim.qc_quantize_wrapper import QcQuantizeWrapper, QuantizerSettings
from aimet_tensorflow.keras.quant_sim.qc_mha_wrapper import QcQuantizableMultiHeadAttention
from aimet_tensorflow.keras.quant_sim.tensor_quantizer import TensorQuantizer, ActivationTensorQuantizer, \
    ParamPerTensorQuantizer, StaticGridPerChannelQuantizer, ParamPerChannelQuantizer
from aimet_tensorflow.keras.quantsim_config.quantsim_config import QuantSimConfigurator, INPUT_QUANTIZERS, \
    OUTPUT_QUANTIZERS, PARAM_QUANTIZERS
from aimet_tensorflow.keras.utils.common import convert_h5_model_to_pb_model

from aimet_tensorflow.keras.defs import AxisHandling
import aimet_tensorflow.keras.utils.common as keras_common_utils

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)

unquantizable_modules = (tf.keras.layers.InputLayer, QcQuantizeWrapper)
substitutable_modules = {
    tf.keras.layers.MultiHeadAttention: QcQuantizableMultiHeadAttention
}

@dataclass
class QuantizationSimModelParams:
    """
    Data class that holds parameters for QuantizationSimModel. Used specifically to rebuild after converting to TF frozen pb
    """
    quant_scheme: Union[QuantScheme, str] = 'tf_enhanced'
    rounding_mode: str = 'nearest'
    default_output_bw: int = 8
    default_param_bw: int = 8
    in_place: bool = False
    config_file: str = None
    default_data_type: QuantizationDataType = QuantizationDataType.int

# pylint: disable=too-many-ancestors
# pylint: disable=too-many-instance-attributes
[docs]class QuantizationSimModel(tf.keras.Model): """ Implements mechanism to add quantization simulations ops to a model. This allows for off-target simulation of inference accuracy. Also allows the model to be fine-tuned to counter the effects of quantization. """ # pylint: disable=too-many-arguments # pylint: disable=unused-argument def __init__(self, model, quant_scheme: Union[QuantScheme, str] = 'tf_enhanced', rounding_mode: str = 'nearest', default_output_bw: int = 8, default_param_bw: int = 8, in_place: bool = False, config_file: str = None, default_data_type: QuantizationDataType = QuantizationDataType.int): """ :param model: Model to quantize :param quant_scheme: Quantization Scheme, currently supported schemes are post_training_tf and post_training_tf_enhanced, defaults to post_training_tf_enhanced :param rounding_mode: The round scheme to used. One of: 'nearest' or 'stochastic', defaults to 'nearest'. :param default_output_bw: bitwidth to use for activation tensors, defaults to 8 :param default_param_bw: bitwidth to use for parameter tensors, defaults to 8 :param in_place: If True, then the given 'model' is modified in-place to add quant-sim nodes. Only suggested use of this option is when the user wants to avoid creating a copy of the model :param config_file: Path to a config file to use to specify rules for placing quant ops in the model :param default_data_type: Default data type to use for quantizing all layer parameters. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that the mode default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 and default_param_bw=16 """ super(QuantizationSimModel, self).__init__() self._model_without_wrappers = model if not in_place: self._model_without_wrappers = tf.keras.models.clone_model(model) n_weights = len(self._model_without_wrappers.weights) self._model_without_wrappers.set_weights(model.get_weights()[:n_weights]) self._layer_name_to_quant_wrapper = {} self._substituted_layer = {} # to hold the substituted layers self._validate_model() self.connected_graph = ConnectedGraph(self._model_without_wrappers) self._quantsim_configurator = self._initialize_quantsim_configurator(quant_scheme, rounding_mode, default_output_bw, default_param_bw, default_data_type, config_file) self.quant_scheme = quant_scheme self._percentile_value = 100 # default percentile value self.per_channel_quantization_enabled = self._quantsim_configurator.per_channel_quantization_flag self.model = self._add_quantization_wrappers(quant_scheme, rounding_mode, default_output_bw, default_param_bw, default_data_type) self.quant_args = extract_global_quantizer_args(quant_scheme, self._quantsim_configurator) self._params = QuantizationSimModelParams(quant_scheme, rounding_mode, default_output_bw, default_param_bw, in_place, config_file, default_data_type) def _validate_model(self): """ Check that model is appropriate for quantsim. """ multiple_inbound_node_layers = [] for layer in self._model_without_wrappers.layers: if len(layer.inbound_nodes) > 1: multiple_inbound_node_layers.append(layer.name) if multiple_inbound_node_layers: error_msg = (f'Layers with more than one inbound nodes are unsupported. This may occur if a layer is ' f'reused multiple times in the model definition.\n' f'Layers with multiple inbound nodes: {multiple_inbound_node_layers}') _logger.error(error_msg) raise NotImplementedError(error_msg) def _get_quantizer_list(self) -> Tuple[List, List, List]: """ Method to provide a list of input, output and parameter quantizers :return: Three lists containing input, paramater and output quantizers respectively """ input_quantizers = [] parameter_quantizers = [] output_quantizers = [] for wrapper in self.quant_wrappers(): for quantizer in wrapper.input_quantizers: input_quantizers.append(quantizer) for quantizer in wrapper.param_quantizers: parameter_quantizers.append(quantizer) for quantizer in wrapper.output_quantizers: output_quantizers.append(quantizer) return input_quantizers, parameter_quantizers, output_quantizers def set_percentile_value(self, percentile_value: float): """ Set the percentile value to be used while computing encodings for quantizers having percentile quant scheme. :param percentile_value: Percentile value to be set to """ if percentile_value < 90 or percentile_value > 100: raise ValueError("Percentile value must be in range [90, 100]") self._percentile_value = percentile_value # Set the percentile value to the activation quantizers input_quantizers, _, output_quantizers = self._get_quantizer_list() for quantizer in input_quantizers + output_quantizers: if quantizer.quant_scheme == QuantScheme.post_training_percentile: quantizer.set_percentile_value(self._percentile_value) def _initialize_quantsim_configurator(self, quant_scheme: Union[QuantScheme, str], rounding_mode: str, default_output_bw: int, default_param_bw: int, default_data_type: QuantizationDataType = QuantizationDataType.int, config_file: str = None) -> QuantSimConfigurator: """ Initialize quantsim configurator :param quant_scheme: Quantization Scheme :param rounding_mode: The round scheme to used :param default_output_bw: bitwidth to use for activation tensors :param default_param_bw: bitwidth to use for parameter tensors :param default_data_type: data type to use for the parameter tensors :param config_file: Path to a config file to use to specify rules for placing quant ops in the model :return: QuantSimConfigurator """ return QuantSimConfigurator(self.connected_graph, quant_scheme, rounding_mode, default_output_bw, default_param_bw, default_data_type, config_file) def _add_quantization_wrappers(self, quant_scheme, rounding_mode, default_output_bw, default_param_bw, default_data_type): """ Add quantization wrappers to the model and return a new model with the wrappers inserted. :param quant_scheme: Quantization scheme to use :param rounding_mode: Rounding mode to use :param default_output_bw: Default bitwidth for activation quantizers :param default_param_bw: Default bitwidth for param quantizers :param default_data_type: data type to use for param quantizers """ def wrap_layer(layer) -> tf.keras.layers.Layer: """ Function to wrap layers with QcQuantizeWrappers, used by keras clone_model() :param layer: Layer to wrap :return: Wrapped layer, or original layer if layer is not to be wrapped """ activation_quant_settings = QuantizerSettings(default_output_bw, default_data_type, rounding_mode, quant_scheme, False, False, False) param_quant_settings = QuantizerSettings(default_param_bw, default_data_type, rounding_mode, quant_scheme, False, False, False) if isinstance(layer, tuple(substitutable_modules.keys())): new_class = substitutable_modules[type(layer)] config = layer.get_config() config["copy_source_weights"] = layer.get_weights() wrapped_layer = new_class.from_config(config) self._substituted_layer[layer] = wrapped_layer return wrapped_layer if isinstance(layer, tf.keras.Sequential): return tf.keras.models.clone_model(layer, clone_function=wrap_layer) if isinstance(layer, unquantizable_modules) or layer.submodules: return layer input_quantizers, output_quantizers, param_quantizers = self._get_quantizers_by_layer(layer) wrapper = QcQuantizeWrapper(layer, activation_quant_settings, param_quant_settings, num_inputs=len(layer.inbound_nodes[0].keras_inputs), input_quantizers=input_quantizers, output_quantizers=output_quantizers, param_quantizers=param_quantizers, per_channel_quantization_enabled=self.per_channel_quantization_enabled) self._layer_name_to_quant_wrapper[layer.name] = wrapper return wrapper return tf.keras.models.clone_model(self._model_without_wrappers, clone_function=wrap_layer) def _get_quantizers_by_layer(self, layer: tf.keras.layers.Layer) -> Tuple[Optional[ActivationTensorQuantizer], Optional[ActivationTensorQuantizer], Union[ParamPerTensorQuantizer, ParamPerChannelQuantizer]]: """ Get input/output/param quantizers from quantizers dictionary or initialize quantizers if layer is not found :param layer: Target layer :return: tuple of input, output, param quantizers """ quantizers_dict = self._quantsim_configurator.get_quantizers_dict(layer) if quantizers_dict is None: _logger.warning("%s not found in quantizers dict, will generate quantizers automatically", layer.name) input_quantizers = None output_quantizers = None param_quantizers = None else: input_quantizers = quantizers_dict.get(INPUT_QUANTIZERS) output_quantizers = quantizers_dict.get(OUTPUT_QUANTIZERS) param_quantizers = quantizers_dict.get(PARAM_QUANTIZERS) return input_quantizers, output_quantizers, param_quantizers @staticmethod def _quantizer_to_name_tuple(quantizers: List[TensorQuantizer]) -> Tuple[Optional[List[str]]]: """ Converts a list of quantizers to a tuple of quantizer names :param quantizers: quantizers :return: tuple of quantizer names """ quant_list = [] if not quantizers: return None for quantizer in quantizers: quant_list.append(quantizer.name) return tuple(quant_list) def get_quantizer_name_by_layer(self, layer: tf.keras.layers.Layer) -> Tuple[Optional[List[str]], Optional[List[str]], Optional[List[str]]]: """ Get the names of input, output and param quantizers :param layer: the keras layer :return: Tuple of quantizer names """ input_quantizers, output_quantizers, param_quantizers = self._get_quantizers_by_layer(layer) output_quantizers_names = self._quantizer_to_name_tuple(output_quantizers) input_quantizers_names = self._quantizer_to_name_tuple(input_quantizers) parameter_quantizers_names = self._quantizer_to_name_tuple(param_quantizers) return input_quantizers_names, output_quantizers_names, parameter_quantizers_names def _disable_quantizers_in_folded_batchnorm(self): """ Disable input/output/param quantizers if layer is folded batch normalization """ for quantsim_wrapper in self._layer_name_to_quant_wrapper.values(): if GraphSearchUtils.is_folded_batch_normalization(quantsim_wrapper.original_layer): for q in quantsim_wrapper.input_quantizers: q.disable() for q in quantsim_wrapper.output_quantizers: q.disable() for q in quantsim_wrapper.param_quantizers: q.disable() @staticmethod def _get_encoding_dict_for_quantizer(quantizer: TensorQuantizer) -> Union[List[Dict[str, Union[str, int, float]]], Dict[str, Union[str, int, float]]]: """ Get encoding dict for a tensor quantizer. :param quantizer: Quantizer to get encoding info from :return: Dictionary or List of dictionaries containing encodings info for the tensor quantizer """ if not isinstance(quantizer, ParamPerChannelQuantizer) or quantizer.data_type == QuantizationDataType.float: quantizer_encodings = [quantizer.encoding] else: quantizer_encodings = quantizer.encoding return [ { 'min': encoding.min, 'max': encoding.max, 'scale': encoding.delta, 'offset': int(encoding.offset), 'bitwidth': encoding.bw, 'is_symmetric': str(quantizer.is_symmetric), 'dtype': 'int' } if quantizer.data_type == QuantizationDataType.int else {'dtype': 'float', 'bitwidth': int(quantizer.bitwidth)} for encoding in quantizer_encodings ] def get_encodings_dict(self) -> Dict[str, Union[str, Dict]]: """ Get encodings dict containing all activation and parameter encodings info in the model :return: Dictionary containing all activation and parameter encodings info in the model """ # pylint: disable=protected-access activation_encodings = {} param_encodings = {} for wrapper in self.quant_wrappers(): for idx, input_quantizer in enumerate(wrapper.input_quantizers): if input_quantizer.is_encoding_valid() or input_quantizer.data_type == QuantizationDataType.float: # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their # inbound_nodes parameter populated, so the name of the quantizer is used instead if not wrapper._layer_to_wrap.inbound_nodes: tensor_name = "multi_head_attention/" + wrapper.name + "/" + input_quantizer.name else: tensor_name = wrapper._layer_to_wrap.inbound_nodes[0].keras_inputs[idx].name encoding_dict = self._get_encoding_dict_for_quantizer(input_quantizer) activation_encodings[tensor_name] = encoding_dict for idx, param_quantizer in enumerate(wrapper.param_quantizers): if param_quantizer.is_encoding_valid() or param_quantizer.data_type == QuantizationDataType.float: param_name = wrapper._layer_to_wrap.weights[idx].name encoding_dict = self._get_encoding_dict_for_quantizer(param_quantizer) param_encodings[param_name] = encoding_dict for idx, output_quantizer in enumerate(wrapper.output_quantizers): if output_quantizer.is_encoding_valid() or output_quantizer.data_type == QuantizationDataType.float: # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their # inbound_nodes parameter populated, so the name of the quantizer is used instead if not wrapper._layer_to_wrap.inbound_nodes: tensor_name = "multi_head_attention/" + wrapper.name + "/" + output_quantizer.name else: tensor_name = wrapper._layer_to_wrap.output.name encoding_dict = self._get_encoding_dict_for_quantizer(output_quantizer) activation_encodings[tensor_name] = encoding_dict encodings_dict = {'version': encoding_version, 'activation_encodings': activation_encodings, 'param_encodings': param_encodings, 'quantizer_args': self.quant_args if hasattr(self, "quant_args") else {}} return encodings_dict
[docs] def compute_encodings(self, forward_pass_callback, forward_pass_callback_args): """ Computes encodings for all quantization sim nodes in the model. :param forward_pass_callback: A callback function that is expected to runs forward passes on a model. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples. :param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. Up to the user to determine the type of this parameter. E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of parameters or an object representing something more complex. """ ops_with_invalid_encodings = [] self._compute_and_set_parameter_encodings(ops_with_invalid_encodings) self._set_op_mode_parameters(libpymo.TensorQuantizerOpMode.quantizeDequantize) forward_pass_callback(self.model, forward_pass_callback_args) for quant_wrapper in self.quant_wrappers(): quant_wrapper.compute_encoding(ops_with_invalid_encodings) op_mode = self._param_op_mode_after_analysis(self.quant_scheme) self._set_op_mode_parameters(op_mode) if ops_with_invalid_encodings: _logger.info('The following quantizers did not have valid encodings and have been set to passThrough mode: ' '%s', ops_with_invalid_encodings) _logger.info('This can be due to the quantizers not having been evaluated during the forward pass in ' 'compute encodings. Evaluation is required to collect statistics needed to compute valid ' 'encodings.\n' 'As a result, the quantizers have been set to passThrough mode, meaning no quantization noise ' 'will be simulated for these ops if they are evaluated in the future.\n' 'If this is not desired, amend the forward pass to evaluate tensors which require these ops ' 'to be evaluated, and recompute encodings.')
def _set_op_mode_parameters(self, op_mode: libpymo.TensorQuantizerOpMode): """ Sets quant mode for parameters and if the encodings are invalid, then adds those wrappers to wrappers_with_invalid_encodings :param op_mode: Quant mode to set to """ for quantizer_info in self.quant_wrappers(): for param_quantizer in quantizer_info.param_quantizers: if param_quantizer.is_enabled(): param_quantizer.quant_mode = op_mode
[docs] def export(self, path, filename_prefix, custom_objects=None, convert_to_pb=True): """ This method exports out the quant-sim model so it is ready to be run on-target. Specifically, the following are saved 1. The sim-model is exported to a regular Keras model without any simulation ops 2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired) :param path: path where to store model pth and encodings :param filename_prefix: Prefix to use for filenames of the model pth and encodings files :param custom_objects: If there are custom objects to load, Keras needs a dict of them to map them """ model_path = os.path.join(path, filename_prefix) self._model_without_wrappers.save(model_path) self._model_without_wrappers.save(model_path + '.h5', save_format='h5') # Conversion of saved h5 model to pb model for consumption by SNPE/QNN try: if convert_to_pb: convert_h5_model_to_pb_model(f'{model_path}.h5', custom_objects=custom_objects) except ValueError: _logger.error("Could not convert h5 to frozen pb. " "Please call export() again with custom_objects defined.") raise finally: encodings_dict = self.get_encodings_dict() encoding_file_path = os.path.join(path, filename_prefix + '.encodings') save_json_yaml(encoding_file_path, encodings_dict)
def _compute_and_set_parameter_encodings(self, ops_with_invalid_encodings: List): # pylint: disable=too-many-nested-blocks for quantizer_wrapper in self.quant_wrappers(): for idx, param_quantizer in enumerate(quantizer_wrapper.param_quantizers): if param_quantizer.is_enabled() and param_quantizer.data_type == QuantizationDataType.int: # 0th input to our quant wrapper is the tensor being quantized weight_tensor = quantizer_wrapper.original_layer.get_weights()[idx] # Per-channel if isinstance(param_quantizer, StaticGridPerChannelQuantizer): for index, tensor_quantizer in enumerate(param_quantizer.tensor_quantizer): if param_quantizer.axis_handling == AxisHandling.LAST_TWO_AXES.value: last_two_axes_combined_shape = list(weight_tensor.shape[:-2]) + [-1] channel_slice = weight_tensor.reshape(*last_two_axes_combined_shape) channel_slice = channel_slice.take(index, channel_slice.ndim - 1) elif isinstance(quantizer_wrapper.original_layer, tf.keras.layers.Conv2DTranspose): if weight_tensor.ndim == 4: channel_slice = weight_tensor.take(index, weight_tensor.ndim - 2) else: # For bias in Transpose layers channel_slice = weight_tensor.take(index, weight_tensor.ndim - 1) else: channel_slice = weight_tensor.take(index, weight_tensor.ndim - 1) tensor_quantizer.updateStats(channel_slice, False) # Per-tensor else: tensor_quantizer = param_quantizer.tensor_quantizer tensor_quantizer.updateStats(weight_tensor, False) param_quantizer.compute_encoding(ops_with_invalid_encodings) def set_and_freeze_param_encodings(self, encoding_path: str): """ Set and freeze parameter encodings from encodings JSON file :param encoding_path: path from where to load parameter encodings file """ # Load parameter encodings file with open(encoding_path) as json_file: param_encodings = json.load(json_file) for quant_wrapper in self.quant_wrappers(): quant_wrapper.set_and_freeze_param_encoding(param_encodings) def load_encodings_to_sim(self, encoding_file_path: str): """ Loads the saved encodings to quant sim model :param encoding_file_path: path from where to load encodings file :return: """ # pylint: disable=protected-access, too-many-branches, too-many-locals, too-many-statements # Load encodings file with open(encoding_file_path) as json_file: encodings = json.load(json_file) param_encodings = encodings['param_encodings'] activation_encodings = encodings['activation_encodings'] for wrapper in self.quant_wrappers(): for idx, input_quantizer in enumerate(wrapper.input_quantizers): # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their # inbound_nodes parameter populated, so the name of the quantizer is used instead if not wrapper._layer_to_wrap.inbound_nodes: tensor_name = "multi_head_attention/" + wrapper.name + "/" + input_quantizer.name else: tensor_name = wrapper._layer_to_wrap.inbound_nodes[0].keras_inputs[idx].name if tensor_name in activation_encodings: if not input_quantizer.is_enabled(): _logger.info("Not loading encodings for quantizer: %s as it is disabled", tensor_name) continue encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(activation_encodings[tensor_name][0]) input_quantizer.tensor_quantizer.isEncodingValid = True input_quantizer.set_quantizer_encodings(encoding.bw, is_symmetric, encoding, libpymo.TensorQuantizerOpMode.quantizeDequantize) _logger.info("Setting encodings for : %s", tensor_name) else: if input_quantizer.is_enabled(): input_quantizer.disable() _logger.info("Encoding for quantizer: %s is not present thus disabling it.", tensor_name) for idx, param_quantizer in enumerate(wrapper.param_quantizers): param_name = wrapper._layer_to_wrap.weights[idx].name if param_name in param_encodings: if not param_quantizer.is_enabled(): _logger.info("Not loading encodings for parameter: %s as quantizer is disabled", param_name) continue if isinstance(param_quantizer, StaticGridPerChannelQuantizer): encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(param_encodings[param_name]) for tensor_quantizer in param_quantizer.tensor_quantizer: tensor_quantizer.isEncodingValid = True bw = encoding[0].bw else: encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(param_encodings[param_name][0]) param_quantizer.tensor_quantizer.isEncodingValid = True bw = encoding.bw param_quantizer.set_quantizer_encodings(bw, is_symmetric, encoding, libpymo.TensorQuantizerOpMode.oneShotQuantizeDequantize) _logger.info("Setting encodings for : %s", param_name) else: if param_quantizer.is_enabled(): param_quantizer.disable() _logger.info("Encoding for parameter: %s not present thus disabling this quantizer.", param_name) # Loading encodings means that compute encodings was called. Therefore, these two lines set the correct # op mode for the correct quant scheme and if the quantization was per channel or not. op_mode = self._param_op_mode_after_analysis(self.quant_scheme) self._set_op_mode_parameters(op_mode) for idx, output_quantizer in enumerate(wrapper.output_quantizers): # because dense layers in quantizable MHA are not explicitly sublayers, they don't have their # inbound_nodes parameter populated, so the name of the quantizer is used instead if not wrapper._layer_to_wrap.inbound_nodes: tensor_name = "multi_head_attention/" + wrapper.name + "/" + output_quantizer.name else: tensor_name = wrapper._layer_to_wrap.output.name if tensor_name in activation_encodings: if not output_quantizer.is_enabled(): _logger.info("Not loading encodings for quantizer: %s as it is disabled", tensor_name) continue encoding, is_symmetric = keras_common_utils.create_encoding_from_dict(activation_encodings[tensor_name][0]) output_quantizer.tensor_quantizer.isEncodingValid = True output_quantizer.set_quantizer_encodings(encoding.bw, is_symmetric, encoding, libpymo.TensorQuantizerOpMode.quantizeDequantize) _logger.info("Setting encodings for : %s", tensor_name) else: if output_quantizer.is_enabled(): output_quantizer.disable() _logger.info("Encoding for quantizer: %s is not present thus disabling it.", tensor_name) def _param_op_mode_after_analysis(self, quant_scheme) -> libpymo.TensorQuantizerOpMode: """ Returns quant mode to use for parameters after encodings have been computed :param quant_scheme: Quantization scheme to use :return: Quant mode to use """ if quant_scheme in [QuantScheme.training_range_learning_with_tf_init, QuantScheme.training_range_learning_with_tf_enhanced_init] \ or self.per_channel_quantization_enabled: return libpymo.TensorQuantizerOpMode.quantizeDequantize return libpymo.TensorQuantizerOpMode.oneShotQuantizeDequantize def quant_wrappers(self): """ Generator for yielding all quantization wrappers """ for layer in self.model.layers: if isinstance(layer, QcQuantizeWrapper): yield layer if isinstance(layer, QcQuantizableMultiHeadAttention): yield from layer.quant_wrappers() # For Getting Quantizers from Sequantial Block if isinstance(layer, tf.keras.Sequential): yield from quant_wrappers_for_sequential_block(layer) def get_quant_wrapper_for_layer_name(self, layer_name: str) -> QcQuantizeWrapper: """ Return qc quant wrapper corresponding to a layer name :param layer_name: Layer name to get quantize wrapper for :return: Qc quant wrapper corresponding to a layer name """ return self._layer_name_to_quant_wrapper.get(layer_name) def _fill_missing_encoding_min_max_gradients(self, gradients: list): """ Computes the encoding min/max gradients and populates the gradients list :param gradients: gradients computed using GradientTape(gradients for encoding min/max will be `None`) """ def _find_weight_in_layer(weight_name: str, model_layer: tf.keras.layers.Layer): for weight in model_layer.weights: if weight.name.split(":")[0] == weight_name: return weight return None # Mapping used to get the gradients of weights(kernel, bias etc) weight_name_to_gradient = dict(zip([weight.name.split(":")[0] for weight in self.model.trainable_weights], gradients)) # Mapping used to get index of encoding min/max gradients (which would be `None`) and fill them weight_name_to_index = dict(zip([weight.name for weight in self.model.trainable_weights], range(len(self.model.trainable_weights)))) # Only process layers where 'param_quantizers' is defined (i.e. QcQuantizeWrapper layers) for layer in filter(lambda _layer: hasattr(_layer, 'param_quantizers'), self.model.layers): for param_quantizer in layer.param_quantizers: if param_quantizer.name in weight_name_to_gradient: # Value of weight associated with this param quantizer weight_tensor = _find_weight_in_layer(param_quantizer.name, layer.original_layer) # Gradients of the weights grad = weight_name_to_gradient[param_quantizer.name] # Using the weights and it's gradients, compute gradients for encoding min/max dloss_by_dmin, dloss_by_dmax = param_quantizer.get_gradients_for_encoding_min_max(weight_tensor, grad) enc_min_index = weight_name_to_index[param_quantizer.encoding_min.name] enc_max_index = weight_name_to_index[param_quantizer.encoding_max.name] gradients[enc_min_index] = dloss_by_dmin gradients[enc_max_index] = dloss_by_dmax # pylint: disable=useless-super-delegation def get_config(self): return super().get_config() def call(self, inputs, training=None, mask=None): return self.model.call(inputs, training, mask) def train_step(self, data): """ Custom training loop, equivalent to overriding `keras.Model.fit` function Reference: https://keras.io/guides/customizing_what_happens_in_fit/ Only relevant when using range-learning, otherwise equivalent to `keras.Model.fit` Param quantizers are disconnected in the op graph of the wrapped model Because of this, the gradients are not computed for encoding min/max(when range learning is enabled) This custom train_step function computes the missing gradients for encoding min/max of param quantizers """ x, y = data with tf.GradientTape() as tape: predictions = self(x, training=True) loss = self.compiled_loss(y, predictions) gradients = tape.gradient(loss, self.model.trainable_weights) # Manually compute missing gradients for encoding min/max when using range learning if self.quant_scheme in [QuantScheme.training_range_learning_with_tf_init, QuantScheme.training_range_learning_with_tf_enhanced_init]: self._fill_missing_encoding_min_max_gradients(gradients) gradients_to_apply = [(gradient, weight) for gradient, weight in zip(gradients, self.model.trainable_weights) if gradient is not None] self.optimizer.apply_gradients(gradients_to_apply) self.compiled_metrics.update_state(y, predictions) return {m.name: m.result() for m in self.metrics}
def quant_wrappers_for_sequential_block(seq_block: tf.keras.Sequential): """ Generator for yielding all quantization wrappers for a Sequantial Block """ for layer in seq_block.layers: if isinstance(layer, QcQuantizeWrapper): yield layer if isinstance(layer, QcQuantizableMultiHeadAttention): yield from layer.quant_wrappers() # in cases of nested Sequential Block if isinstance(layer, tf.keras.Sequential): yield from quant_wrappers_for_sequential_block(layer)