Source code for aimet_tensorflow.keras.amp.quantizer_groups
# -*- mode: python -*-# =============================================================================# @@-COPYRIGHT-START-@@## Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.## Redistribution and use in source and binary forms, with or without# modification, are permitted provided that the following conditions are met:## 1. Redistributions of source code must retain the above copyright notice,# this list of conditions and the following disclaimer.## 2. Redistributions in binary form must reproduce the above copyright notice,# this list of conditions and the following disclaimer in the documentation# and/or other materials provided with the distribution.## 3. Neither the name of the copyright holder nor the names of its contributors# may be used to endorse or promote products derived from this software# without specific prior written permission.## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE# ARE DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Find quantizer groups in a model """
import itertools
from typing import Dict, List, Tuple
from collections import defaultdict
from dataclasses import dataclass, field

import tensorflow as tf

from aimet_common.connected_graph.operation import Op
from aimet_common.utils import AimetLogger
from aimet_common.amp.utils import CANDIDATE_WITH_DTYPE
from aimet_common.amp.quantizer_groups import QuantizerGroupBase

from aimet_tensorflow.keras.connectedgraph import ConnectedGraph
from aimet_tensorflow.keras.quantsim import QuantizationSimModel, substitutable_modules
from aimet_tensorflow.keras.quant_sim.tensor_quantizer import TensorQuantizer

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.MixedPrecision)

# Op types skipped when forming quantizer groups; their first non-skipped
# ancestor is recorded instead (see find_output_quantizer_groups).
ops_to_skip = ['view',
               'NumToTensor',
               'Split',
               'PythonOp']

# Op types whose consumers are not traversed at all.
ops_not_to_traverse = ['size']

# Keys in the parent/child op-group dict for model inputs and outputs.
INPUT_OPS_STR = 'input_ops'
OUTPUT_OPS_STR = 'output_ops'
@dataclass(frozen=True)
class QuantizerGroup(QuantizerGroupBase):
    """ Group of modules and quantizers """
    # Each field holds quantizer names (or name tuples) keyed into the
    # name -> wrapper dict produced by get_module_name_to_module_dict.
    input_quantizers: Tuple[str, ...] = field(default_factory=tuple)
    output_quantizers: Tuple[str, ...] = field(default_factory=tuple)
    parameter_quantizers: Tuple[str, ...] = field(default_factory=tuple)

    def get_candidate(self, name_to_quantizer_dict: Dict) -> CANDIDATE_WITH_DTYPE:
        """
        Gets Activation & parameter bitwidth

        :param name_to_quantizer_dict: Gets module from module name
        :return: Tuple of Activation, parameter bitwidth and data type
        """
        activation_bw, parameter_bw = None, None
        activation_data_type, parameter_data_type = None, None

        # Only the first quantizer layer of the first module is inspected:
        # the unconditional breaks stop both loops after a single read.
        for module_name in self.input_quantizers:
            module = self.lookup_quantizer(module_name, name_to_quantizer_dict)
            for layer in module.input_quantizers.layers:
                activation_bw = layer.bitwidth
                activation_data_type = layer.data_type
                break
            break

        # NOTE: when the group has both input and output quantizers, the
        # output quantizer's settings overwrite the activation values above.
        for module_name in self.output_quantizers:
            module = self.lookup_quantizer(module_name, name_to_quantizer_dict)
            for layer in module.output_quantizers.layers:
                activation_bw = layer.bitwidth
                activation_data_type = layer.data_type
                break
            break

        if self.parameter_quantizers:
            for module_name in self.parameter_quantizers:
                module = self.lookup_quantizer(module_name, name_to_quantizer_dict)
                for layer in module.param_quantizers.layers:
                    parameter_bw = layer.bitwidth
                    parameter_data_type = layer.data_type
                    break
                break

        return (activation_bw, activation_data_type), (parameter_bw, parameter_data_type)

    @staticmethod
    def lookup_quantizer(quantizer_name: str, name_to_quantizer_dict: Dict) -> tf.keras.layers.Layer:
        """
        Returns the quantizer layer corresponding to the name

        :quantizer_name: Name of the quantizer
        :name_to_quantizer_dict: Dictionary of mappings from quantizer name to quantizer layer
        """
        # A tuple-valued name carries the actual quantizer name in slot 0.
        if isinstance(quantizer_name, tuple):
            quantizer_name = quantizer_name[0]
        module = name_to_quantizer_dict[quantizer_name]
        return module

    def set_quantizers_to_candidate(self,
                                    name_to_quantizer_dict: Dict,
                                    candidate: CANDIDATE_WITH_DTYPE) -> None:
        """
        Sets a quantizer group to a given candidate bitwidth

        :param name_to_quantizer_dict: Gets module from module name
        :param candidate: candidate with act and param bw and data types
        """
        (activation_bw, activation_data_type), (param_bw, param_data_type) = candidate

        # Unlike get_candidate, every quantizer layer of every module in the
        # group is updated (no breaks).
        for module_name in self.input_quantizers:
            module = self.lookup_quantizer(module_name, name_to_quantizer_dict)
            for layer in module.input_quantizers.layers:
                layer.bitwidth = activation_bw
                layer.data_type = activation_data_type

        for module_name in self.output_quantizers:
            module = self.lookup_quantizer(module_name, name_to_quantizer_dict)
            for layer in module.output_quantizers.layers:
                layer.bitwidth = activation_bw
                layer.data_type = activation_data_type

        if not self.parameter_quantizers:
            return

        for module_name in self.parameter_quantizers:
            module = self.lookup_quantizer(module_name, name_to_quantizer_dict)
            # NOTE(review): iterates module.param_quantizers directly while the
            # other methods use .layers — confirm both address the same collection.
            for layer in module.param_quantizers:
                layer.bitwidth = param_bw
                layer.data_type = param_data_type

    def to_list(self) -> List[Tuple[str, str]]:
        """
        Converts quantizer group to a list

        :return: List containing input/output quantizers & weight quantizers
        """
        if self.parameter_quantizers:
            ret_list = list(itertools.chain(
                (("input", module_name) for module_name in self.input_quantizers),
                (("output", module_name) for module_name in self.output_quantizers),
                (("weight", module_name) for module_name in self.parameter_quantizers),
            ))
        else:
            ret_list = list(itertools.chain(
                (("input", module_name) for module_name in self.input_quantizers),
                (("output", module_name) for module_name in self.output_quantizers),
            ))
        return ret_list

    def get_active_quantizers(self, name_to_quantizer_dict: Dict) -> List[TensorQuantizer]:
        """ Find all active tensor quantizers associated with this quantizer group """
        quantizers = []
        for module_name in self.input_quantizers:
            module = self.lookup_quantizer(module_name, name_to_quantizer_dict)
            quantizers += list(module.input_quantizers.layers)

        for module_name in self.output_quantizers:
            module = self.lookup_quantizer(module_name, name_to_quantizer_dict)
            quantizers += list(module.output_quantizers.layers)

        if self.parameter_quantizers:
            for module_name in self.parameter_quantizers:
                module = self.lookup_quantizer(module_name, name_to_quantizer_dict)
                quantizers += list(module.param_quantizers.layers)

        # Deduplicate, keeping only quantizers that are currently enabled.
        return list(set(quantizer for quantizer in quantizers if quantizer.is_enabled()))

    def get_active_param_quantizers(self, name_to_quantizer_dict: Dict) -> List[TensorQuantizer]:
        """
        Find all active param tensor quantizers associated with this quantizer group

        :param name_to_quantizer_dict: Contains mapping of module name to sim.quantizer_config object
        """
        quantizers = []
        if self.parameter_quantizers:
            for module_name in self.parameter_quantizers:
                module = self.lookup_quantizer(module_name, name_to_quantizer_dict)
                quantizers += list(module.param_quantizers.layers)
        return list(set(quantizer for quantizer in quantizers if quantizer.is_enabled()))
def find_output_quantizer_groups(op: Op, parent_child_op_groups: Dict, map_for_skipped_ops: Dict):
    """
    Finds quantizer groups along the parent to child flow

    :param op: Op in the connected graph
    :param parent_child_op_groups: parent child relationships in graph
    :param map_for_skipped_ops: map to find first skipped parents of skipped ops
    """
    if op.outputs:
        for consumer in op.output_ops:
            name = op.name
            if consumer.type in ops_not_to_traverse:
                continue
            # NOTE(review): membership is checked with op.dotted_name but the map
            # is indexed with op.name — confirm the two coincide for skipped ops.
            if op.dotted_name in map_for_skipped_ops:
                name = map_for_skipped_ops[op.name]

            if consumer.type in ops_to_skip:
                # Record the first non-skipped ancestor of the skipped consumer,
                # then keep traversing through it.
                map_for_skipped_ops[consumer.name] = name
                find_output_quantizer_groups(consumer, parent_child_op_groups, map_for_skipped_ops)
            # If there is a one to one connection between quantizers
            else:
                if name in map_for_skipped_ops:
                    name = map_for_skipped_ops[name]
                parent_child_op_groups[name].append(consumer.name)
    else:
        # Leaf op with no outputs: ensure an (empty) entry exists for its
        # mapped ancestor so it still shows up as a parent.
        if op.dotted_name in map_for_skipped_ops:
            parent_child_op_groups[map_for_skipped_ops[op.dotted_name]] = []


def find_op_groups(graph: ConnectedGraph) -> Dict:
    """
    Finds parent children relationship based on the following rules:

    1) If there is a direct connection between two ops, op1 and op2, then op1 is
       parent of op2 and they form a group
    2) If the input to an op (op1) is shared with another op (op2), the op
       producing the input (op0) is the parent, and op1 and op2 are the children

    :param graph: connected graph
    :return: Dict of parent (key) and children (value) relationship
    """
    parent_child_op_groups = defaultdict(list)
    map_for_skipped_ops = {}

    for op in graph.ordered_ops:
        # Add 1st op as child
        if not op.input_ops:
            parent_child_op_groups[INPUT_OPS_STR].append(op.name)
        # Add output op as child to put output of model as a quantizer group
        if not op.outputs:
            parent_child_op_groups[OUTPUT_OPS_STR].append(op.name)

    for op in graph.get_all_ops().values():
        if op.type in ops_to_skip or op.type in ops_not_to_traverse:
            continue
        find_output_quantizer_groups(op, parent_child_op_groups, map_for_skipped_ops)

    return parent_child_op_groups


def get_module_name_to_module_dict(sim: QuantizationSimModel) -> Dict:
    """
    Creates a dictionary of wrapped module's name to quantizer module

    :param sim: quantization sim
    :return: Dict key: name of wrapped module value: quantization wrapper
    """
    module_name_to_quantizer_dict = {}
    for layer in sim.quant_wrappers():
        # Every quantizer (input, output and param) of a wrapper maps back to
        # that wrapper layer.
        for quantizer in layer.input_quantizers:
            module_name_to_quantizer_dict[quantizer.name] = layer
        for quantizer in layer.output_quantizers:
            module_name_to_quantizer_dict[quantizer.name] = layer
        for quantizer in layer.param_quantizers:
            module_name_to_quantizer_dict[quantizer.name] = layer
    return module_name_to_quantizer_dict


# pylint: disable-msg=too-many-locals
# pylint: disable-msg=too-many-branches
def find_quantizer_group(sim: QuantizationSimModel) -> Tuple[Dict, List[QuantizerGroup]]:
    """
    Finds quantizer groups in a quantization sim model

    :param sim: Quantization sim
    :return: Tuple of (dict mapping quantizer name to its wrapper layer,
             list of quantizer groups found in the model)
    :raises AssertionError: if the sim has no connected graph
    """
    # Get connected graph from quantsim for model without wrappers
    connected_graph = sim.connected_graph

    if connected_graph is None:
        raise AssertionError('Aborting Auto Mixed Precision, connected graph needs to exist for Auto Mixed precision')

    quantizer_groups = []

    parent_child_op_groups = find_op_groups(connected_graph)

    quantized_op_name_to_quantizer_dict = get_module_name_to_module_dict(sim)

    if INPUT_OPS_STR in parent_child_op_groups:
        for child in parent_child_op_groups[INPUT_OPS_STR]:
            # Add one quantizer group for each input and it's weight param
            layer = connected_graph.get_layer_from_op_name(child)
            # Substituted layers (e.g. replaced during quantsim creation) expand
            # into one group per inner quant wrapper.
            if isinstance(layer, tuple(substitutable_modules.keys())):
                sub_quantizer_groups = get_quantizers_groups_substituted_layer(sim, layer)
                quantizer_groups.extend(sub_quantizer_groups)
                continue
            input_quantizer_names, output_quantizer_names, param_quantizer_names = \
                sim.get_quantizer_name_by_layer(layer)
            if input_quantizer_names or param_quantizer_names:
                quantizer_group = QuantizerGroup(
                    input_quantizers=input_quantizer_names,
                    parameter_quantizers=param_quantizer_names
                )
                quantizer_groups.append(quantizer_group)
                logger.debug('\n Quantizer Group Added: %s', quantizer_group)

    # Based on which quantizers are enabled, create a list of quantizer_groups
    for parents in parent_child_op_groups:
        if parents in [INPUT_OPS_STR, OUTPUT_OPS_STR]:
            continue

        if not isinstance(parents, tuple):
            parents = [parents]

        for parent in parents:
            layer = connected_graph.get_layer_from_op_name(parent)
            if isinstance(layer, tuple(substitutable_modules.keys())):
                sub_quantizer_groups = get_quantizers_groups_substituted_layer(sim, layer)
                quantizer_groups.extend(sub_quantizer_groups)
                continue
            input_quantizer_names, output_quantizer_names, param_quantizer_names = \
                sim.get_quantizer_name_by_layer(layer)

            # Don't add quantizer group if it is empty
            if input_quantizer_names or output_quantizer_names or param_quantizer_names:
                quantizer_group = QuantizerGroup(
                    input_quantizers=input_quantizer_names,
                    output_quantizers=output_quantizer_names,
                    parameter_quantizers=param_quantizer_names
                )
                quantizer_groups.append(quantizer_group)
                logger.debug('\n Quantizer Group added: %s', quantizer_group)

    if OUTPUT_OPS_STR in parent_child_op_groups:
        for parent in parent_child_op_groups[OUTPUT_OPS_STR]:
            # Add one quantizer group for each input and it's weight param
            layer = connected_graph.get_layer_from_op_name(parent)
            if isinstance(layer, tuple(substitutable_modules.keys())):
                sub_quantizer_groups = get_quantizers_groups_substituted_layer(sim, layer)
                quantizer_groups.extend(sub_quantizer_groups)
                continue
            input_quantizer_names, output_quantizer_names, param_quantizer_names = \
                sim.get_quantizer_name_by_layer(layer)
            if output_quantizer_names:
                quantizer_group = QuantizerGroup(
                    input_quantizers=input_quantizer_names,
                    output_quantizers=output_quantizer_names,
                    parameter_quantizers=param_quantizer_names
                )
                quantizer_groups.append(quantizer_group)
                logger.debug('\n Quantizer Group added: %s', quantizer_group)

    return quantized_op_name_to_quantizer_dict, quantizer_groups


# pylint: disable=protected-access
def get_quantizers_groups_substituted_layer(sim: QuantizationSimModel, layer) -> List[QuantizerGroup]:
    """
    Helper function to return the quantizer groups for the substituted layers

    :param sim: Quantization sim
    :param layer: original layer that was substituted during quantsim creation
    :return: List of quantizer groups, one per quant wrapper of the substitute
    """
    layer = sim._substituted_layer[layer]
    quantizer_groups = []
    for quant_wrapper in layer.quant_wrappers():
        input_quantizers = quant_wrapper.input_quantizers
        output_quantizers = quant_wrapper.output_quantizers
        param_quantizers = quant_wrapper.param_quantizers

        input_quantizer_names = QuantizationSimModel._quantizer_to_name_tuple(input_quantizers)
        output_quantizer_names = QuantizationSimModel._quantizer_to_name_tuple(output_quantizers)
        param_quantizer_names = QuantizationSimModel._quantizer_to_name_tuple(param_quantizers)

        # Skip wrappers that carry no quantizers at all
        if input_quantizer_names or output_quantizer_names or param_quantizer_names:
            quantizer_group = QuantizerGroup(
                input_quantizers=input_quantizer_names,
                output_quantizers=output_quantizer_names,
                parameter_quantizers=param_quantizer_names
            )
            quantizer_groups.append(quantizer_group)
            logger.debug('\n Quantizer Group added: %s', quantizer_group)

    return quantizer_groups


def find_wrapper_module(op_name: str, module_name_to_quantizer_dict: Dict) -> Tuple[str, tf.keras.layers.Layer]:
    """
    Finds quantization (wrapping) module corresponding to the wrapper module's dotted name

    :param op_name: Dotted name of op as represented in connected graph
    :param module_name_to_quantizer_dict: Dict key: name of wrapped module value: quantization wrapper
    :return: Module name and the corresponding quant-wrapper module in the sim
    :raises KeyError: if no wrapper exists for the op (e.g. it is a functional op)
    """
    # pylint:disable = protected-access
    # Strip the model prefix before the first '.' to get the module name.
    module_name = op_name[op_name.find('.') + 1:]
    if module_name in module_name_to_quantizer_dict:
        return module_name, module_name_to_quantizer_dict[module_name]
    # Else it is a functional op
    raise KeyError(f'No quantization wrapper found for op {op_name!r}; it is likely a functional op')