Source code for aimet_tensorflow.keras.batch_norm_fold
# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2021-2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=too-many-lines
""" Utility for batch norm fold in tf 2.x """

from typing import Iterable, Optional, Tuple, Union, List, Dict, Set

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from packaging import version  # pylint: disable=wrong-import-order

if version.parse(tf.version.VERSION) >= version.parse("2.10"):
    # Ignore pylint errors as keras module is not available in TF 2.4
    from keras.layers.core.tf_op_layer import TFOpLambda  # pylint: disable=import-error
    from keras.engine.functional import Functional  # pylint: disable=import-error
else:
    # Ignore pylint errors due to conditional imports
    from tensorflow.python.keras.engine.functional import Functional  # pylint: disable=ungrouped-imports
    from tensorflow.python.keras.layers.core import TFOpLambda  # pylint: disable=ungrouped-imports
# pylint: disable=wrong-import-position
from aimet_common.defs import QuantScheme, MAP_ROUND_MODE_TO_PYMO
import aimet_common.libpymo as libpymo
from aimet_common.utils import AimetLogger
from aimet_tensorflow.keras.model_preparer import _KerasModelPreparer
from aimet_tensorflow.keras.quant_sim.qc_quantize_wrapper import QcQuantizeWrapper
from aimet_tensorflow.keras.quant_sim.tensor_quantizer import ParamPerTensorQuantizer
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
from aimet_tensorflow.keras.utils import common
from aimet_tensorflow.keras.utils.model_connection_utils import ModelLayerConnections, ModelLayerConnectionsProperties
from aimet_tensorflow.keras.utils.quantizer_utils import get_wrappers_bias_quantizer, get_wrappers_weight_quantizer
from aimet_tensorflow.keras.utils.weight_tensor_utils import WeightTensorUtils

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)

LayerType = Union[
    tf.keras.layers.Conv2D,
    tf.keras.layers.Dense,
    tf.keras.layers.Conv2DTranspose,
    tf.keras.layers.DepthwiseConv2D
]
_supported_layers = LayerType.__args__

PairType = Union[Tuple[LayerType, tf.keras.layers.BatchNormalization, bool],
                 Tuple[tf.keras.layers.BatchNormalization, LayerType, bool]]

BatchNormType = tf.keras.layers.BatchNormalization
_supported_batchnorms = BatchNormType

# Todo: search for more types of convolution
LinearType = tf.keras.layers.Dense
ConvType = Union[tf.keras.layers.Conv1D,
                 tf.keras.layers.Conv2D,
                 tf.keras.layers.DepthwiseConv2D,
                 tf.keras.layers.Conv2DTranspose]
_supported_convs = ConvType.__args__

FlattenType = Union[tf.keras.layers.Flatten, tf.keras.layers.Reshape]

MAP_PYMO_TO_ROUND_MODE = {v: k for k, v in MAP_ROUND_MODE_TO_PYMO.items()}


def _check_layer_to_find_pattern(cur_layer: tf.keras.layers.Layer,
                                 conv_linear_with_bn_dict: Dict[Union[ConvType, LinearType],
                                                                List[Union[None, BatchNormType]]],
                                 layer_out_node_ref: Dict,
                                 has_seen: List[Union[None, ConvType, BatchNormType, FlattenType]]):
    """
    Check the current layer for conv/linear-BN patterns while traversing all paths in the model,
    considering all inputs.

    :param cur_layer: layer to investigate for finding a pattern
    :param conv_linear_with_bn_dict: dictionary to store possible conv/linear-BN pairs,
        key: Dense or Conv layer & value: list of BNs; the first index in this list shows bn_in
        and the second index shows bn_out
    :param layer_out_node_ref: dictionary includes layer_ref as a key, outbound nodes as value
    :param has_seen: stores the layers which are useful for finding a pattern in the next layers;
        index 0 is for the conv op, index 1 is for the bn op and index 2 is for the flatten/reshape op
    """

    # pylint: disable=too-many-branches
    if isinstance(cur_layer, _supported_convs):
        if has_seen[1] is not None:
            conv_linear_with_bn_dict[cur_layer] = [has_seen[1], None]
            has_seen[1] = None
        if (cur_layer.activation is tf.keras.activations.linear) and \
                (cur_layer in layer_out_node_ref) and len(layer_out_node_ref[cur_layer]) == 1:
            has_seen[0] = cur_layer
    elif isinstance(cur_layer, BatchNormType):
        if has_seen[0] is not None:
            if has_seen[0] in conv_linear_with_bn_dict:
                conv_linear_with_bn_dict[has_seen[0]][1] = cur_layer
            else:
                conv_linear_with_bn_dict[has_seen[0]] = [None, cur_layer]
            has_seen[0] = None
        if (cur_layer in layer_out_node_ref) and len(layer_out_node_ref[cur_layer]) == 1:
            has_seen[1] = cur_layer
    elif isinstance(cur_layer, (tf.keras.layers.Flatten, tf.keras.layers.Reshape)):
        if (cur_layer in layer_out_node_ref) and len(layer_out_node_ref[cur_layer]) == 1:
            if has_seen[1]:
                has_seen[2] = cur_layer
        else:
            has_seen[1] = None
        if has_seen[0]:
            has_seen[0] = None
    elif isinstance(cur_layer, LinearType):
        if has_seen[1] is not None and has_seen[2] is not None:
            conv_linear_with_bn_dict[cur_layer] = [has_seen[1], None]
        has_seen[2] = None
        has_seen[1] = None
    else:
        has_seen[0] = None
        has_seen[1] = None
        has_seen[2] = None


def _add_children_layer_before_parent_layer(cur_layer: tf.keras.layers.Layer, node_layer_map: Dict,
                                            layer_out_node_map: Dict,
                                            visited_layers: Set[tf.keras.layers.Layer],
                                            reversed_ordered_layers: List[tf.keras.layers.Layer]):
    """
    Function to use topological sorting for finding all the layers which are accessible
    from the given input layer, in the opposite order of occurrence.

    :param cur_layer: layer that we want to find a path from
    :param node_layer_map: dictionary includes node_ref as a key, in_layers and out_layer as value
    :param layer_out_node_map: dictionary includes layer_ref as a key, outbound nodes as value
    :param visited_layers: Set of all layers that have been visited
    :param reversed_ordered_layers: List of layers in the opposite order of occurrence
        for the layers that we have visited so far
    """

    # Mark the current layer as visited.
    visited_layers.add(cur_layer)

    if cur_layer in layer_out_node_map:
        # Recur for all the layers adjacent to this layer
        for next_node in layer_out_node_map[cur_layer]:
            next_layer = node_layer_map[next_node][1]
            if next_layer not in visited_layers:
                _add_children_layer_before_parent_layer(next_layer, node_layer_map,
                                                        layer_out_node_map, visited_layers,
                                                        reversed_ordered_layers)
        reversed_ordered_layers.append(cur_layer)
    else:
        reversed_ordered_layers.append(cur_layer)


def _get_ordered_layers(node_layer_map: Dict,
                        layer_out_node_map: Dict) -> List[tf.keras.layers.Layer]:
    """
    Function to return the list of all the layers in order of occurrence, so that each layer
    comes after the layers that feed it.

    :param node_layer_map: dictionary includes node_ref as a key, in_layers and out_layer as value
    :param layer_out_node_map: dictionary includes layer_ref as a key, outbound nodes as value
    :return: ordered_layers: List of all layers in the order of occurrence
    """
    # to find the input layers of the model
    input_layers = common.find_input_layers(node_layer_map)

    # Set of all layers that have been visited (to cut short duplicate traversals)
    visited_layers = set()

    # List of all layers in the opposite order of occurrence
    reversed_ordered_layers = []

    for input_layer in input_layers:
        _add_children_layer_before_parent_layer(input_layer, node_layer_map,
                                                layer_out_node_map, visited_layers,
                                                reversed_ordered_layers)

    # reverse the list because layers are in reverse order
    ordered_layers = reversed_ordered_layers[::-1]

    # # filter ordered ops for only valid ops
    # ordered_ops = [op for op in ordered_ops if op in valid_ops]

    return ordered_layers


def _get_ordered_conv_linears(node_layer_map: Dict,
                              layer_out_node_map: Dict) -> List[Union[ConvType, LinearType]]:
    """
    Helper to select a list of conv/linear layers in the order of occurrence.

    :param node_layer_map: dictionary includes node_ref as a key, in_layers and out_layer as value
    :param layer_out_node_map: dictionary includes layer_ref as a key, outbound nodes as value
    :return: List of conv/linear layer refs
    """
    # get the ordered layers list from the node_layer map dictionary
    list_of_ordered_layers = _get_ordered_layers(node_layer_map, layer_out_node_map)

    # look for conv/linear layers
    ordered_conv_linears = []
    for layer in list_of_ordered_layers:
        if isinstance(layer, _supported_layers):
            ordered_conv_linears.append(layer)
    return ordered_conv_linears


def _fill_conv_linear_bn_dict(cur_layer: tf.keras.layers.Layer, node_layer_ref: Dict,
                              layer_out_node_ref: Dict,
                              has_seen: List[Union[None, ConvType, BatchNormType, FlattenType]],
                              visited_layer: Set[tf.keras.layers.Layer],
                              conv_linear_with_bn_dict: Dict[Union[ConvType, LinearType],
                                                             List[Union[None, BatchNormType]]]):
    """
    Fill conv_linear_bn_dict for the model.

    :param cur_layer: layer to investigate next
    :param node_layer_ref: dictionary includes node_ref as a key, in_layers and out_layer as value
    :param layer_out_node_ref: dictionary includes layer_ref as a key, outbound nodes as value
    :param has_seen: stores the layers which are useful for finding a pattern in the next layers;
        index 0 is for the conv op, index 1 is for the bn op and index 2 is for the flatten/reshape op
    :param visited_layer: to store all the layers that have been visited so far
    :param conv_linear_with_bn_dict: dictionary of all possible conv/linear-BN pairs,
        key: Dense or Conv layer & value: list of BNs; the first index in this list shows bn_in
        and the second index shows bn_out
    """

    # Mark the current layer as visited to prevent passing through one layer more than once
    visited_layer.add(cur_layer)

    _check_layer_to_find_pattern(cur_layer, conv_linear_with_bn_dict, layer_out_node_ref, has_seen)

    if cur_layer in layer_out_node_ref:
        for next_node in layer_out_node_ref[cur_layer]:
            next_layer = node_layer_ref[next_node][1]
            if next_layer not in visited_layer:
                _fill_conv_linear_bn_dict(next_layer, node_layer_ref, layer_out_node_ref,
                                          has_seen, visited_layer, conv_linear_with_bn_dict)
    else:
        has_seen[0] = None
        has_seen[1] = None
        has_seen[2] = None


def _find_possible_convs_linears_bn(node_layer_map: Dict, layer_out_node_map: Dict) \
        -> Dict[Union[ConvType, LinearType], List[Union[None, BatchNormType]]]:
    """
    Find all possible conv/linear-BN pairs by traversing all paths in the model, considering all inputs.

    :param node_layer_map: dictionary includes node_ref as a key, in_layers and out_layer as value
    :param layer_out_node_map: dictionary includes layer_ref as a key, outbound nodes as value
    :return: dictionary of all possible conv/linear-BN pairs, key: Dense or Conv layer & value: list
        of BNs; the first index in this list shows bn_in and the second index shows bn_out
    """

    input_layers = common.find_input_layers(node_layer_map)
    visited_layer = set()
    conv_linear_with_bn_dict = {}

    for input_layer in input_layers:
        _fill_conv_linear_bn_dict(input_layer, node_layer_map, layer_out_node_map,
                                  [None, None, None], visited_layer, conv_linear_with_bn_dict)

    return conv_linear_with_bn_dict


def _get_bn_params(bn: tf.keras.layers.BatchNormalization) -> libpymo.BNParams():
    """
    Helper to populate BN params from a given BN layer, required for fold.

    :param bn: BatchNorm layer
    :return: bn params in libpymo.BNParams() format.
    """
    if bn.gamma is None:
        _logger.warning("Gamma for BatchNormalization '%s' is None. Setting to ones.", bn.name)
        # Batch Normalization layers can have a missing gamma in two different cases. One is that the
        # 'gamma' attribute is set to None. The second is if `scale` is set to False upon creation of
        # the layer, which turns off gamma.
        with tf.name_scope(bn.name):
            weights_with_gamma_and_before_rebuild = [np.ones_like(bn.beta)] + bn.get_weights()
            bn.scale = True
            bn.build(bn.input.shape)
            bn.set_weights(weights_with_gamma_and_before_rebuild)
            bn.gamma = next(filter(lambda w: 'gamma' in w.name, bn.weights))

    bn_params = libpymo.BNParams()
    bn_params.gamma = bn.gamma.numpy().reshape(-1)
    bn_params.beta = bn.beta.numpy().reshape(-1)
    bn_params.runningMean = bn.moving_mean.numpy().reshape(-1)
    bn_params.runningVar = bn.moving_variance.numpy().reshape(-1)
    epsilon = bn.epsilon
    var = bn.moving_variance.numpy()
    var_with_epsilon = var + epsilon
    sigma = np.sqrt(var_with_epsilon)
    bn_params.runningVar = sigma

    return bn_params


def _get_bias_tensor(conv_linear: LayerType) -> libpymo.TensorParams():
    """
    Get bias tensor in given conv layer.

    Packs bias in the format required for BN fold (libpymo.TensorParams()).

    :param conv_linear: conv layer
    :return: bias param in libpymo.TensorParams() format.
    """
    bias_tensor = libpymo.TensorParams()
    if conv_linear.bias is not None:
        bias_tensor.data = conv_linear.bias.numpy().reshape(-1)
        bias_tensor.shape = np.array(conv_linear.bias.shape)

    return bias_tensor


def _get_weight_tensor_transpose_reshape(conv_linear: LayerType) -> libpymo.TensorParams():
    """
    Get weight tensor from conv layer.

    Converts to the right format - performs transpose and reshape.
    Packs it in the format required for BN fold (libpymo.TensorParams()).

    :param conv_linear: conv layer
    :return: weight tensor in libpymo.TensorParams() format.
"""# Weight tensor libpymo formatweight_tensor=libpymo.TensorParams()# linear array to be sent for bn foldweight=conv_linear.get_weights()[0]shape=weight.shapeifisinstance(conv_linear,tf.keras.layers.DepthwiseConv2D):# Depthwise conv layers in TF have outputs(Noc) set to 1.# we will use format [Nic, Noc, kh, kw] -# to be compatible with cpp backend.weight=np.transpose(weight,(2,3,0,1))# [Nic, Noc, kh, kw]shape=np.array([shape[2],shape[3],shape[0],shape[1]])elifisinstance(conv_linear,tf.keras.layers.Dense):shape=np.concatenate((np.array([1,1]),shape))weight=np.transpose(weight,(1,0))# [Noc, Nic, kh, kw]shape=np.array([shape[3],shape[2],shape[0],shape[1]])elifisinstance(conv_linear,tf.keras.layers.Conv2DTranspose):weight=np.transpose(weight,(2,3,0,1))# [Noc, Nic, kh, kw]shape=np.array([shape[2],shape[3],shape[0],shape[1]])elifisinstance(conv_linear,tf.keras.layers.Conv2D):weight=np.transpose(weight,(3,2,0,1))# [Noc, Nic, kh, kw]shape=np.array([shape[3],shape[2],shape[0],shape[1]])else:_logger.error("_get_weight_tensor_transpose_reshape(): Operation type unsupported")weight_tensor.data=weight.reshape(-1)weight_tensor.shape=shapereturnweight_tensorclassPassThroughOp(tf.keras.layers.Layer):""" This is a pass-through op, used for purpose of making an op a no-op """# pylint: disable=arguments-differ@staticmethoddefcall(inputs):""" This is a function to return input as an output :param inputs: input to pass through """returninputs# pylint: disable=too-many-branches, protected-access, too-many-locals, too-many-nested-blocks@common.to_functionaldef_delete_bn_from_functional(model:tf.keras.Model,bn_layers_to_remove:List[tf.keras.layers.BatchNormalization])->tf.keras.Model:""" This function is used to remove ALL batch normalization layers from a functional model passed via the bn_layers_to_remove parameter. Removing in place is not possible for functional models as the layers inbound and outbound connections are immutable. This function returns a new model with the batch normalization layers removed. :param model: Model to remove bn_layers from :param bn_layers_to_remove: List of batch normalization layers to remove from the model :return: A new model with the batch normalization layers removed """# In order to do this, we first need to know the original models inbound and outbound connections to each layer.# We then need to create a new model with the same inbound and outbound connections, but with the batch normalization# layers removed. This is done by rerouting the inbound nodes of the batch normalization layers to the inbound nodes# of the next layer. 
    # This can be seen in the following diagram:
    #
    # Original model flow ------------------------->
    #    ______________        ______________        ______________
    #   |              |      |              |      |              |
    #   |     Conv     | -X-> |  Batch Norm  | -X-> |     ReLU     |
    #   |______________|      |______________|  ^   |______________|
    #   New model flow  \                       /
    #                    \                     /
    #                     \___________________/

    def wrapped_bn_layer_in_bns_to_remove(layer: tf.keras.layers.Layer) -> bool:
        return isinstance(layer, QcQuantizeWrapper) and layer._layer_to_wrap in bn_layers_to_remove

    tf.keras.backend.clear_session()  # clear session to not have tensor name conflicts

    # Step 1: Get the inbound and outbound connections for each layer in the model
    model_layer_connections = ModelLayerConnections.get_model_layers_connection_properties(model)

    for inp in model.inputs:
        model_layer_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS].update({inp.name: inp})

    # Step 2: Create a new model with the batch normalization layers removed by iterating through the layers
    # in the model and using the inbound and outbound connections to reroute around the batch normalization
    # layers.
    batch_norms_replaced_with_names = {}
    model_outputs = []
    for current_layer in model.layers:
        if isinstance(current_layer, tf.keras.layers.InputLayer):
            continue

        # Determine the input tensors of the given layer
        layer_input = [model_layer_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS][layer_aux]
                       for layer_aux in
                       model_layer_connections[ModelLayerConnectionsProperties.INBOUND_NODES][current_layer.name]]

        layer_input = layer_input[0] if len(layer_input) == 1 else layer_input

        # Reroute around batch normalization layers if the layer is in the list of layers to remove
        if current_layer in bn_layers_to_remove or wrapped_bn_layer_in_bns_to_remove(current_layer):
            _logger.debug("Removing Batch Normalization layer %s", current_layer.name)

            for outbound_node in current_layer._outbound_nodes:  # pylint: disable=protected-access
                # Find the Batch Normalization output layers' input that holds the Batch Normalization layer
                # node and replace it with the input layers of the Batch Normalization layer.
                # For example, if ReLU's inputs are [conv1_bn] and conv1_bn's inputs are [conv1], then we
                # replace ReLU's inputs with [conv1]
                all_batch_norms_inbound_layers_names = \
                    [inbound_node.inbound_layers.name for inbound_node in current_layer._inbound_nodes]

                # Go through all the outbound layers of the batch normalization layer and replace the batch
                # normalization layer name with the input layer names of the batch normalization layer.
                batch_norms_outbound_layers_new_inbound_layers_names = \
                    [outlayer.replace(current_layer.name, *all_batch_norms_inbound_layers_names)
                     for outlayer in
                     model_layer_connections[ModelLayerConnectionsProperties.INBOUND_NODES][
                         outbound_node.outbound_layer.name]]

                # Keras Batch Norm only supports one input tensor, meaning there is one singular layer coming
                # into it. Hence, 'inbound_nodes[0]'.
                batch_norms_replaced_with_names[current_layer.name] = \
                    current_layer._inbound_nodes[0].inbound_layers.name

                model_layer_connections[ModelLayerConnectionsProperties.INBOUND_NODES].update(
                    {outbound_node.outbound_layer.name: batch_norms_outbound_layers_new_inbound_layers_names})

                # The above updates our dict for the mapping of the inputs, but we also need to update what
                # Keras thinks the inputs are. This is done by updating the inbound nodes of the output layer
                # of the Batch Normalization. THIS IS ONLY FOR MAPPING THE INPUTS TO BUILD A NEW MODEL. The
                # original model's underlying structure is not changed.
                outbound_node.outbound_layer._inbound_nodes = current_layer.inbound_nodes  # pylint: disable=protected-access

        # Otherwise, treat like a normal layer
        else:
            # For layers that have multiple inputs, order matters for what is fed into the layer. For example,
            # if we have an Add layer with inputs from a ReLU and a Batch Norm, the order they go into the Add
            # matters. Furthermore, if the Batch Norm is deleted, then it needs to be replaced with its folded
            # layer in the same order.

            KERAS_SYMBOLIC_TENSORS_INDEX = 0
            # Check if we need to change the layer_input order. If there is just one input, there is no order.
            # The special case where there is a Lambda layer with multiple inputs is handled separately.
            if isinstance(layer_input, List) and not isinstance(current_layer, TFOpLambda):
                # Original model's Keras symbolic tensor order
                original_keras_symbolic_tensors_order = model_layer_connections[
                    ModelLayerConnectionsProperties.CALL_ARGS][current_layer.name][KERAS_SYMBOLIC_TENSORS_INDEX]

                # Special case for Lambda layers. Lambda layers can be thought of as z = x + y. Unfortunately,
                # their call args for the Keras symbolic tensors will ONLY have the x portion. In our
                # layer_input we have both x and y. This statement is added to wrap the x portion of the
                # original call args and check if it's a batch norm that was folded out.
                if not isinstance(original_keras_symbolic_tensors_order, List):
                    original_keras_symbolic_tensors_order = [original_keras_symbolic_tensors_order]

                # Check if a Batch Norm that was deleted is in the original Keras symbolic order.
                name_of_bn_replaced = [
                    tensor._keras_history.layer.name
                    for tensor in original_keras_symbolic_tensors_order
                    if tensor._keras_history.layer.name in batch_norms_replaced_with_names
                ]

                # If a Batch Norm is found, then the Keras symbolic tensor order is slightly updated to
                # replace the Batch Norm with the folded layer.
                # Otherwise, we can just use the original Keras symbolic tensor order.
                if name_of_bn_replaced:
                    updated_keras_symbolic_tensors_order = []
                    for keras_symbolic_tensor in original_keras_symbolic_tensors_order:
                        if (name_of_bn := keras_symbolic_tensor._keras_history.layer.name) in name_of_bn_replaced:  # pylint: disable=superfluous-parens
                            updated_keras_symbolic_tensors_order.append(
                                model_layer_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS][
                                    batch_norms_replaced_with_names[name_of_bn]])
                        else:
                            updated_keras_symbolic_tensors_order.append(keras_symbolic_tensor)

                    # Dictionary of the Keras symbolic tensor name to the order.
                    ordered_inputs = {k.name: v for v, k in enumerate(updated_keras_symbolic_tensors_order)}

                    # Sort layer_input based on the above dictionary.
                    layer_input = sorted(layer_input,
                                         key=lambda current_input, oi=ordered_inputs: oi[current_input.name])

            # Since we are rerouting around the batch normalization layers, we need to temporarily remove the
            # inbound and outbound nodes of the batch normalization layers so that the model can be built
            # correctly and not duplicate the non batch normalization layers' inbound/outbound nodes.
            current_layer._inbound_nodes = []  # pylint: disable=protected-access
            # Special case for when there is a Lambda operation with multiple inputs. For example, z = x + y.
            if isinstance(current_layer, TFOpLambda):
                kmp = _KerasModelPreparer.get_instance_for_common_layer_passthrough_functions(
                    model_layer_connections)
                x = kmp._handle_normal_keras_layer(current_layer)  # pylint: disable=protected-access

                # Update the model layer connections
                kmp._update_output_tensors_in_model_layers_connections(  # pylint: disable=protected-access
                    current_layer, x, model)
            else:
                x = current_layer(layer_input)
            current_layer._outbound_nodes = []  # pylint: disable=protected-access

            # Set the new output tensor (in this case, it will be the same as the original model)
            model_layer_connections[ModelLayerConnectionsProperties.OUTPUT_TENSORS].update({current_layer.name: x})

        # Save the tensor in the output list if it is an output in the initial model
        if current_layer.name in model.output_names:
            model_outputs.append(x)

    return tf.keras.Model(inputs=model.inputs, outputs=model_outputs)


def _delete_bn_from_sequential(layer: tf.keras.layers.Layer,
                               bn: tf.keras.layers.BatchNormalization):
    """
    This is the function for removing batch normalization layers that are layers of a Sequential model.

    :param layer: model to obtain the bn_layer that we want to remove
    :param bn: batch normalization layer that needs to be removed
    """

    layers_after_bn = []
    visited = False
    idx = None
    # pylint: disable=protected-access
    for index, inner_layer in enumerate(layer.layers):
        if visited:
            layers_after_bn.append(inner_layer)
        elif inner_layer == bn:
            visited = True
            idx = index
        elif inner_layer.submodules:
            _delete_bn_for_non_subclassed_model(inner_layer, bn)

    if visited and idx is not None:
        # pylint: disable=protected-access
        for _ in range(len(layer.layers) - idx):
            layer.pop()
        for layer_to_add in layers_after_bn:
            layer.add(layer_to_add)


def _delete_bn_for_non_subclassed_model(model: Union[tf.keras.Model, tf.keras.layers.Layer],
                                        bn_layer: tf.keras.layers.BatchNormalization):
    """
    Remove the bn layer for those models which are not built through model subclassing.

    :param model: model to delete bn layers from
    :param bn_layer: bn layer that should be removed
    """

    if isinstance(model, tf.keras.Sequential):
        _delete_bn_from_sequential(model, bn_layer)

    # We are expecting to find a sequential model inside a functional model,
    # or model subclassing, in the elif statement
    elif isinstance(model, (tf.keras.layers.Layer, tf.keras.Model)):
        for layer in model.layers:
            if layer.submodules:
                _delete_bn_for_non_subclassed_model(layer, bn_layer)


def _delete_bn_from_model_subclassing(module_to_name_map: Dict[tf.keras.layers.Layer,
                                                               Tuple[tf.keras.Model, str]],
                                      bn_layer: tf.keras.layers.BatchNormalization):
    """
    Remove a bn layer which is part of the model subclassing API, or of a model inheriting from
    tf.keras.layers.Layer.

    :param module_to_name_map: model to remove the bn from
    :param bn_layer: bn layer that should be removed
    """

    parent_ref, module_name = module_to_name_map[bn_layer]
    op = PassThroughOp()
    setattr(parent_ref, module_name, op)


# pylint: disable=inconsistent-return-statements
def _delete_all_bns_from_model(model: Union[tf.keras.Model, tf.keras.layers.Layer],
                               bn_layers: List[tf.keras.layers.BatchNormalization]) -> Optional[tf.keras.Model]:
    """
    Remove all bn layers for a given model.

    :param model: Model to have the bn layers removed from
    :param bn_layers: bn layers that should be removed
    :return: new model with bn layers removed if the model is functional, else None
    """
    if bn_layers:
        # QuantizationSimModel's model will fall into this case.
        if isinstance(model, Functional) and not isinstance(model, tf.keras.Sequential) or \
                any(isinstance(l, QcQuantizeWrapper) for l in model.layers):
            return _delete_bn_from_functional(model, bn_layers)

        module_to_name_map = common.module_to_name_map(model)

        for bn_layer in bn_layers:
            if bn_layer in module_to_name_map:
                _delete_bn_from_model_subclassing(module_to_name_map, bn_layer)
            else:
                _delete_bn_for_non_subclassed_model(model, bn_layer)


def _find_all_batch_norms_to_fold(model: tf.keras.Model) -> Tuple[List[PairType], List[PairType],
                                                                  Set[tf.keras.layers.BatchNormalization]]:
    """
    Find all conv/linear-BN pairs that are eligible for folding.

    :param model: model to obtain conv/linear-BN pairs for
    :return: A list of (conv/linear, bn) pairs, a list of (bn, conv/linear) pairs, and a set of all
        the batch norms which are marked for folding.
    """
    node_layer_map = common.create_node_to_layer_map(model)
    layer_out_node_map = common.create_layer_to_out_node_map(model)

    possible_convs_linears_bn = _find_possible_convs_linears_bn(node_layer_map, layer_out_node_map)

    # get all ordered conv/linear layers
    ordered_conv_linears = _get_ordered_conv_linears(node_layer_map, layer_out_node_map)

    bn_picked_for_folding = set()

    def get_pairs(conv_is_first=False) -> List:
        index = 1 if conv_is_first else 0

        pairs_list = []
        for conv_linear in ordered_conv_linears:
            if conv_linear in possible_convs_linears_bn and \
                    (bn_info := possible_convs_linears_bn[conv_linear]):
                if bn_info[index] and bn_info[index] not in bn_picked_for_folding:
                    pairs_list.append((conv_linear, bn_info[index]) if conv_is_first
                                      else (bn_info[index], conv_linear))
                    bn_picked_for_folding.add(bn_info[index])

        return pairs_list

    conv_bn_pairs = get_pairs(conv_is_first=True)
    bn_conv_pairs = get_pairs(conv_is_first=False)

    return conv_bn_pairs, bn_conv_pairs, bn_picked_for_folding
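
# Note on pair orientation (an illustrative sketch, not part of the module): in a model like
#
#     x = tf.keras.layers.Conv2D(8, 3)(inp)
#     x = tf.keras.layers.BatchNormalization()(x)    # BN after conv  -> reported in conv_bn_pairs
#
#     y = tf.keras.layers.BatchNormalization()(inp)
#     y = tf.keras.layers.Conv2D(8, 3)(y)            # BN before conv -> reported in bn_conv_pairs
#
# _find_all_batch_norms_to_fold would return the first BN as a (conv, bn) pair (folded backward into
# the preceding conv) and the second as a (bn, conv) pair (folded forward), with each BN picked for
# folding at most once.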


def fold_all_batch_norms(model: tf.keras.Model) \
        -> Tuple[List[Tuple[LayerType, BatchNormType]], tf.keras.Model]:
    """
    Fold all batch_norm layers in a model into the corresponding conv/linear layers.

    :param model: model to find all batch norms for
    :return: A tuple of a list of conv/linear layers with associated bn op / activation info, and a
        new model with the Batch Normalization layers folded
    """

    conv_bn_pairs, bn_conv_pairs, folded_bns = _find_all_batch_norms_to_fold(model)

    # A potential new model is returned in case the model is a functional model
    potential_new_model = _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)
    model = potential_new_model if potential_new_model else model

    # Convert the standalone BNs which are not folded
    bn_converted = convert_standalone_batchnorms(model, folded_bns)
    if bn_converted:
        _logger.info("%d BatchNorms' weights got converted", len(bn_converted))

    model.compile()
    _logger.warning("A new model is returned with the Batch Normalization layers removed for Keras models. "
                    "Please use this new model for the rest of the AIMET flow.")

    return conv_bn_pairs + [(conv, bn) for bn, conv in bn_conv_pairs], model
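
# Example usage (a minimal sketch, not part of the module; the layer sizes are illustrative only).
# fold_all_batch_norms returns the folded pairs plus a NEW model, which must be used from then on:
#
#     model = tf.keras.Sequential([
#         tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
#         tf.keras.layers.BatchNormalization(),
#         tf.keras.layers.ReLU(),
#     ])
#     pairs, folded_model = fold_all_batch_norms(model)
#     # 'folded_model' no longer contains the BatchNormalization layer.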


def convert_standalone_batchnorms(model: tf.keras.Model,
                                  folded_bns: set) -> List[tf.keras.layers.BatchNormalization]:
    """
    Convert the weights of the standalone batch norms remaining in the model after BN folding.

    :param model: Keras model on which batch norm folding is being performed
    :param folded_bns: set of batch norms which got folded
    :return: list of BatchNorms whose weights were converted
    """

    bn_converted = []
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization) and layer not in folded_bns:
            convert_batchnorm_parameters(layer)
            _logger.debug("%s weights got converted", layer.name)
            bn_converted.append(layer)
    return bn_converted


def convert_batchnorm_parameters(bn: tf.keras.layers.BatchNormalization):
    """
    Convert the weights of a BN such that it works as y = weight * x + bias.

    :param bn: Batch Norm layer whose weights need to be converted
    """
    bn_params = _get_bn_params(bn)

    # inv :: 1 / sqrt(var + eps)
    inv = tf.math.rsqrt(bn.moving_variance.numpy() + bn.epsilon)
    weight = np.array(bn_params.gamma) * np.array(inv)
    bias = np.array(bn_params.beta) - np.array(bn_params.runningMean) * weight

    new_bn_weights = [weight.data, bias.data,
                      np.zeros(shape=bn.moving_mean.shape, dtype=np.float32),
                      np.ones(shape=bn.moving_variance.shape, dtype=np.float32)]
    bn.trainable = False
    bn.set_weights(new_bn_weights)
    bn.epsilon = 0
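
# The conversion above follows from rewriting inference-mode batch norm as an affine transform
# (a reference sketch of the algebra, using the quantities read in _get_bn_params):
#
#     y = gamma * (x - mean) / sqrt(var + eps) + beta
#       = (gamma / sqrt(var + eps)) * x + (beta - mean * gamma / sqrt(var + eps))
#       = weight * x + bias
#
# Storing 'weight' in gamma, 'bias' in beta, and resetting moving_mean to zeros, moving_variance to
# ones and epsilon to 0 leaves a layer that computes exactly y = weight * x + bias.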


# pylint: disable=protected-access
def fold_all_batch_norms_to_scale(sim: QuantizationSimModel) -> List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]:
    """
    Fold all batch_norm layers in a model into the quantization scale parameter
    of the corresponding conv layers.

    :param sim: QuantizationSimModel to be folded
    :return: A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
    """
    assert sim.model is not None, "QuantizationSimModel attribute 'model' is None."

    model = sim._model_without_wrappers

    quant_wrappers = {
        quant_wrapper._layer_to_wrap: quant_wrapper
        for quant_wrapper in sim.quant_wrappers()
    }

    conv_bn_pairs, bn_conv_pairs, _ = _find_all_batch_norms_to_fold(model)
    conv_bn_pairs = [(quant_wrappers[conv], quant_wrappers[bn]) for conv, bn in conv_bn_pairs]
    bn_conv_pairs = [(quant_wrappers[bn], quant_wrappers[conv]) for bn, conv in bn_conv_pairs]

    old_model_without_wrappers = tf.keras.models.clone_model(model)
    conv_bn_pairs_without_wrappers, _, _ = _find_all_batch_norms_to_fold(old_model_without_wrappers)
    old_model_without_wrappers.set_weights(
        WeightTensorUtils.get_all_sim_models_layer_to_wrap_weights(sim.model))

    # We fold both the sim.model and sim._model_without_wrappers because we rebuild the
    # QuantizationSimModel during export, and this utilizes the sim._model_without_wrappers to achieve it.
    bn_fold_sim_model = _fold_given_batch_norms(sim.model, conv_bn_pairs, bn_conv_pairs)
    sim.model = bn_fold_sim_model if bn_fold_sim_model else sim.model

    bn_fold_model = _fold_given_batch_norms(old_model_without_wrappers, conv_bn_pairs_without_wrappers, [])
    sim._model_without_wrappers = bn_fold_model if bn_fold_model else old_model_without_wrappers

    return conv_bn_pairs + [(conv, bn) for bn, conv in bn_conv_pairs]
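
# Example usage (a minimal sketch; 'calibration_data' is a placeholder for real calibration samples).
# Folding to scale requires a range-learning quant scheme and valid encodings, and it mutates 'sim'
# in place:
#
#     from aimet_common.defs import QuantScheme
#
#     sim = QuantizationSimModel(model, quant_scheme=QuantScheme.training_range_learning_with_tf_init)
#     sim.compute_encodings(lambda m, _: m(calibration_data), None)
#     folded_pairs = fold_all_batch_norms_to_scale(sim)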


def fold_given_batch_norms(model: tf.keras.Model, layer_pairs: List[PairType]) -> Optional[tf.keras.Model]:
    """
    Fold a given set of batch_norm layers into conv_linear layers.

    :param model: Either a Keras Model or a QuantizationSimModel's model
    :param layer_pairs: Tuple of conv, bn layers and is_batch_norm_second flag
    :return: new model with the batch norm layers folded if the model is a functional model, else None
    """
    # pylint: disable=protected-access
    conv_bn_pairs = []
    bn_conv_pairs = []

    def is_batchnorm(layer: tf.keras.layers.Layer) -> bool:
        if isinstance(layer, QcQuantizeWrapper):
            layer = layer._layer_to_wrap
        return isinstance(layer, _supported_batchnorms)

    def is_conv_linear(layer: tf.keras.layers.Layer) -> bool:
        if isinstance(layer, QcQuantizeWrapper):
            layer = layer._layer_to_wrap
        return isinstance(layer, _supported_layers)

    for x, y in layer_pairs:
        if is_batchnorm(x):
            assert is_conv_linear(y)
            bn = x
            conv = y
            bn_conv_pairs.append((bn, conv))
        else:
            assert is_conv_linear(x)
            assert is_batchnorm(y)
            conv = x
            bn = y
            conv_bn_pairs.append((conv, bn))

    return _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)


def _fold_given_batch_norms(model: tf.keras.Model,
                            conv_bn_pairs: Iterable[Tuple[tf.keras.layers.Layer, tf.keras.layers.Layer]],
                            bn_conv_pairs: Iterable[Tuple[tf.keras.layers.Layer, tf.keras.layers.Layer]]) -> \
        Optional[tf.keras.Model]:
    """
    Fold a given set of batch_norm layers into conv layers.

    :param model: Model
    :param conv_bn_pairs: List of (conv, bn) pairs to fold
    :param bn_conv_pairs: List of (bn, conv) pairs to fold
    """
    for bn, conv in bn_conv_pairs:
        if isinstance(conv, QcQuantizeWrapper):
            raise RuntimeError(f"Forward folding to scale is not possible. Got {conv.name}")

    bn_layers = []

    def _fold(conv, bn, fold_backward):
        is_wrapped = isinstance(conv, QcQuantizeWrapper) or isinstance(bn, QcQuantizeWrapper)
        try:
            if is_wrapped:
                assert isinstance(conv, QcQuantizeWrapper) and isinstance(bn, QcQuantizeWrapper)
                bn._layer_to_wrap.trainable = False
                _fold_to_scale(conv, bn)
                bn_layers.append(bn._layer_to_wrap)
            else:
                bn.trainable = False
                _fold_to_weight(conv, bn, fold_backward=fold_backward)
        except _BatchNormFoldingNotSupported as e:
            bn_name = bn._layer_to_wrap.name if is_wrapped else bn.name
            conv_name = conv._layer_to_wrap.name if is_wrapped else conv.name
            _logger.warning("Failed to fold %s to %s. [Reason] %s", bn_name, conv_name, str(e))
        else:
            bn_layers.append(bn._layer_to_wrap if is_wrapped else bn)

    for conv, bn in conv_bn_pairs:
        _fold(conv, bn, fold_backward=True)

    for bn, conv in bn_conv_pairs:
        _fold(conv, bn, fold_backward=False)

    return _delete_all_bns_from_model(model, bn_layers)


class _BatchNormFoldingNotSupported(RuntimeError):
    pass


def _fold_to_scale(conv_wrapper: QcQuantizeWrapper, bn_wrapper: QcQuantizeWrapper):
    """
    Fold BatchNorm into the scale and bias of the given layer.

    :param conv_wrapper: QcQuantizeWrapper that wraps the conv or linear layer
    :param bn_wrapper: QcQuantizeWrapper that wraps the Batch Norm layer
    """
    # pylint: disable=protected-access, too-many-statements, too-many-locals
    conv = conv_wrapper._layer_to_wrap
    bn = bn_wrapper._layer_to_wrap

    weight_quantizer = get_wrappers_weight_quantizer(conv_wrapper.param_quantizers)
    bias_quantizer = get_wrappers_bias_quantizer(conv_wrapper.param_quantizers)

    # Checking the QuantScheme, as aimet_tensorflow.keras does not have a LearnedGridTensorQuantizer
    if weight_quantizer.quant_scheme not in [QuantScheme.training_range_learning_with_tf_init,
                                             QuantScheme.training_range_learning_with_tf_enhanced_init]:
        raise _BatchNormFoldingNotSupported(
            "BatchNorm folding to scale supports training_range_learning_with_tf_init or "
            "training_range_learning_with_tf_enhanced_init only. "
            f"got {weight_quantizer.quant_scheme}")

    output_quantizer = conv_wrapper.output_quantizers[0]

    if output_quantizer.is_enabled():
        raise _BatchNormFoldingNotSupported(
            "BatchNorm should belong to the same supergroup with the layer to be folded to.")

    if bias_quantizer:
        if bias_quantizer.is_enabled():
            raise _BatchNormFoldingNotSupported(
                "Can't fold BatchNorm to scale if bias quantizer is enabled.")

    enc_min = weight_quantizer._encoding_min
    enc_max = weight_quantizer._encoding_max

    if not weight_quantizer.is_encoding_valid():
        raise RuntimeError

    with bn_wrapper._quantize_params():
        _fold_to_weight(conv, bn, fold_backward=True)

        gamma = bn.gamma
        sigma = K.sqrt(bn.moving_variance + bn.epsilon)

        for i, c in enumerate(gamma / sigma):
            c = float(c)
            if c >= 0:
                enc_max[i].assign(enc_max[i] * c)
                enc_min[i].assign(enc_min[i] * c)
            else:
                enc_max_before_reassign = enc_max[i]
                enc_max[i].assign(enc_min[i] * c)
                enc_min[i].assign(enc_max_before_reassign * c)

    # Copy the batchnorm's output quantizers to the conv output quantizers
    for conv_output_quantizer, bn_output_quantizer in \
            zip(conv_wrapper.output_quantizers, bn_wrapper.output_quantizers):
        if bn_output_quantizer.encoding is not None:
            conv_output_quantizer._encoding_min.assign(bn_output_quantizer._encoding_min)
            conv_output_quantizer._encoding_max.assign(bn_output_quantizer._encoding_max)
            conv_output_quantizer._is_encoding_valid = True

            tensor_quantizers = conv_output_quantizer._tensor_quantizer \
                if isinstance(conv_output_quantizer._tensor_quantizer, List) \
                else [conv_output_quantizer._tensor_quantizer]
            for tensor_quantizer in tensor_quantizers:
                tensor_quantizer.isEncodingValid = True

        if bn_output_quantizer.is_enabled():
            conv_output_quantizer.enable()
        else:
            conv_output_quantizer.disable()

        bn_output_quantizer.disable()

    if bias_quantizer is None:
        bias_quantizer = ParamPerTensorQuantizer(conv,
                                                 conv.bias.name.split(':')[0],
                                                 weight_quantizer.quant_scheme,
                                                 MAP_PYMO_TO_ROUND_MODE[weight_quantizer.round_mode],
                                                 weight_quantizer.bitwidth,
                                                 weight_quantizer.data_type,
                                                 weight_quantizer.is_symmetric,
                                                 weight_quantizer.use_strict_symmetric,
                                                 weight_quantizer.use_unsigned_symmetric,
                                                 enabled=False)
        tensor_quantizers = bias_quantizer._tensor_quantizer \
            if isinstance(bias_quantizer._tensor_quantizer, List) \
            else [bias_quantizer._tensor_quantizer]
        for tensor_quantizer in tensor_quantizers:
            tensor_quantizer.isEncodingValid = True
        conv_wrapper.param_quantizers.append(bias_quantizer)


def _fold_to_weight(conv_linear: LayerType, bn: BatchNormType, fold_backward: bool):
    """
    Fold BatchNorm into the weight and bias of the given layer.

    :param conv_linear: Conv or linear layer to fold BN into.
    :param bn: BatchNorm to fold.
    :param fold_backward: To fold backwards or not
    """

    is_bias_valid = conv_linear.bias is not None

    bn_params = _get_bn_params(bn)
    weight_tensor = _get_weight_tensor_transpose_reshape(conv_linear)
    bias_tensor = _get_bias_tensor(conv_linear)

    # Updated weight and bias
    bias = libpymo.fold(bn_params, weight_tensor, bias_tensor, is_bias_valid, fold_backward)

    if isinstance(conv_linear, tf.keras.layers.DepthwiseConv2D):
        # Depthwise conv layers in TF have outputs (Noc) set to 1.
        # We sent in the format [Nic, Noc, kh, kw].
        numpy_weight_reshaped = np.reshape(weight_tensor.data, weight_tensor.shape).transpose((2, 3, 0, 1))
    elif isinstance(conv_linear, tf.keras.layers.Dense):
        # o, i - convert to i, o
        numpy_weight_reshaped = np.reshape(
            weight_tensor.data,
            [weight_tensor.shape[0], weight_tensor.shape[1]]).transpose(1, 0)
    elif isinstance(conv_linear, tf.keras.layers.Conv2DTranspose):
        # We sent in the format [Noc, Nic, kh, kw].
        numpy_weight_reshaped = np.reshape(weight_tensor.data, weight_tensor.shape).transpose((2, 3, 0, 1))
    else:
        # Conv2D case; we sent in the format [Noc, Nic, kh, kw].
        numpy_weight_reshaped = np.reshape(weight_tensor.data, weight_tensor.shape).transpose((2, 3, 1, 0))

    # Update the bias tensor, even in the case where there was no existing bias add op in the given conv2D op.
    bias_tensor_shape = [weight_tensor.shape[0]]
    numpy_bias_reshaped = np.reshape(bias, bias_tensor_shape)

    if not is_bias_valid:
        conv_linear.use_bias = True
        conv_linear.bias = conv_linear.add_weight(name=f"{conv_linear.name}/bias",
                                                  shape=(weight_tensor.shape[0],),
                                                  dtype=conv_linear.dtype,
                                                  trainable=True)
    conv_linear.set_weights([numpy_weight_reshaped.data, numpy_bias_reshaped])
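
# For reference, libpymo.fold applies the standard batch norm folding algebra (a sketch of the
# backward-fold case, with sigma = sqrt(var + eps) handed over via bn_params.runningVar):
#
#     w_folded = w * (gamma / sigma)                  # scales each output channel of the weight
#     b_folded = beta + (b - mu) * (gamma / sigma)    # mu is the BN moving mean, b the original bias
#
# In the forward-fold case (BN preceding the conv/linear layer), the scaling is absorbed along the
# layer's input channels instead.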