# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2022, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
""" ONNX code to fold batch-norm layers """

from typing import Dict, List, Tuple
import contextlib
import numpy as np
import onnx
from onnx import numpy_helper
from onnxruntime.quantization.onnx_quantizer import ONNXModel
from packaging import version

from aimet_common.bias_correction import ConvBnPatternHandler
from aimet_common.graph_pattern_matcher import PatternType
from aimet_common.graph_searcher import GraphSearcher
from aimet_common.connected_graph.connectedgraph_utils import get_ordered_ops
from aimet_common import _libpymo as libpymo
from aimet_common.utils import AimetLogger

from aimet_onnx.meta.connectedgraph import ConnectedGraph
from aimet_onnx.meta.connectedgraph import WEIGHT_INDEX, BIAS_INDEX, RUNNING_MEAN_INDEX, RUNNING_VAR_INDEX
from aimet_onnx.meta.operations import Op
from aimet_onnx.utils import get_node_attribute, remove_node, transpose_tensor, ParamUtils, retrieve_constant_input

# pylint: disable=no-name-in-module, ungrouped-imports
if version.parse(onnx.__version__) >= version.parse("1.14.0"):
    from onnx import NodeProto, TensorProto, ModelProto
else:
    from onnx.onnx_pb import NodeProto, TensorProto, ModelProto

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.BatchNormFolding)

ConvType = ['Conv', 'ConvTranspose']
LinearType = ['Gemm', 'MatMul']
BatchNormType = ['BatchNormalization']


class BNLayer:
    """ Captures the beta and gamma parameters of a BatchNorm layer, to be used later during high-bias absorption """

    def __init__(self):
        self.bn_layer = None
        self.gamma = None
        self.beta = None
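# For reference, the standard batch-norm folding identities this module relies on
# (a sketch of the usual derivation; the actual arithmetic is performed by libpymo.fold):
#
#   BN(x) = gamma * (x - mu) / sqrt(var + eps) + beta,   with sigma = sqrt(var + eps)
#
# Backward fold, Conv(W, b) followed by BN, per output channel k:
#   W'[k] = W[k] * gamma[k] / sigma[k]
#   b'[k] = (b[k] - mu[k]) * gamma[k] / sigma[k] + beta[k]
#
# Forward fold, BN followed by Conv(W, b), per input channel c:
#   W'[:, c] = W[:, c] * gamma[c] / sigma[c]
#   b'       = b + W @ (beta - mu * gamma / sigma)   (summed over input channels and kernel positions)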
def _find_conv_bn_pairs(connected_graph: ConnectedGraph) -> Dict:
    """
    Uses the graph searcher to find the preceding and following BN layers for each conv/linear layer.

    :param connected_graph: ConnectedGraph object.
    :return: dictionary of conv/linear Op with associated bn op / activation info
    """
    # Initialize all patterns to be matched and the associated callback function
    patterns_with_callbacks = []
    layer_select_handler = ConvBnPatternHandler()
    preceding_linear_op_types = ['Flatten', 'Reshape']

    # Linear layer combinations
    for linear_op in LinearType:
        for preceding_linear_op_type in preceding_linear_op_types:
            # BN -> Flatten/Reshape -> Linear
            patterns_with_callbacks.append(PatternType(pattern=['BatchNormalization', preceding_linear_op_type, linear_op],
                                                       action=layer_select_handler))

    for op_type in ConvType + LinearType:
        patterns_with_callbacks.append(PatternType(pattern=['BatchNormalization', op_type],
                                                   action=layer_select_handler))
        patterns_with_callbacks.append(PatternType(pattern=[op_type, 'BatchNormalization'],
                                                   action=layer_select_handler))

    # Create a graph searcher instance with the connected graph and the patterns to search
    graph_searcher = GraphSearcher(connected_graph, patterns_with_callbacks)

    # Get all conv/linear and bn info
    graph_searcher.find_all_patterns_in_graph_apply_actions()
    convs_bn_activation_dict = layer_select_handler.get_conv_linear_bn_info_dict()

    return convs_bn_activation_dict


def find_all_batch_norms_to_fold(connected_graph: ConnectedGraph,
                                 ) -> Tuple[List[Tuple[NodeProto, NodeProto]],
                                            List[Tuple[NodeProto, NodeProto]]]:
    """
    Find all possible batch norm layers that can be folded. Returns a list of pairs such that (bn, layer)
    means bn will be forward-folded into layer and (layer, bn) means bn will be backward-folded into layer

    :param connected_graph: connected graph model to search
    :return: A list of (layer, bn) pairs and a list of (bn, layer) pairs, where `bn` can be folded into `layer`.
    """
    conv_linear_bn_activation_info_dict = _find_conv_bn_pairs(connected_graph)
    model = connected_graph.model

    # To mark BNs already picked for backward folding
    bn_picked_for_folding = set()

    ordered_conv_fc_nodes = get_ordered_conv_linears(connected_graph)

    conv_bn_pairs = []
    # Backward fold is given priority over forward fold
    for node in ordered_conv_fc_nodes:
        # Filter out combinations that are not supported
        if node in conv_linear_bn_activation_info_dict:
            bn_info = conv_linear_bn_activation_info_dict[node]
            if bn_info.output_bn and bn_info.output_bn not in bn_picked_for_folding:
                if is_valid_bn_fold(node.get_module(), model, True):
                    conv_bn_pairs.append((node.get_module(), bn_info.output_bn.get_module()))
                    bn_picked_for_folding.add(bn_info.output_bn)
                else:
                    logger.info('...... invalid combination to fold %s',
                                [node.name, bn_info.output_bn.name])

    bn_conv_pairs = []
    for node in ordered_conv_fc_nodes:
        # Filter out combinations that are not supported
        if node in conv_linear_bn_activation_info_dict:
            bn_info = conv_linear_bn_activation_info_dict[node]
            if bn_info.input_bn and bn_info.input_bn not in bn_picked_for_folding:
                if is_valid_bn_fold(node.get_module(), model, False):
                    bn_conv_pairs.append((bn_info.input_bn.get_module(), node.get_module()))
                    bn_picked_for_folding.add(bn_info.input_bn)
                else:
                    logger.info('...... invalid combination to fold %s',
                                [bn_info.input_bn.name, node.name])

    return conv_bn_pairs, bn_conv_pairs
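# Illustrative example (hypothetical graph, not derived from any specific model):
# for Conv -> BN -> Gemm, the BN is first claimed for backward folding into the Conv,
# so find_all_batch_norms_to_fold returns conv_bn_pairs == [(conv, bn)] and
# bn_conv_pairs == []; the bn_picked_for_folding set prevents the same BN from also
# being forward-folded into the Gemm.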
def get_ordered_conv_linears(conn_graph: ConnectedGraph) -> List[Op]:
    """
    Helper to select a list of candidate layers for BatchNorm folding

    :param conn_graph: connected graph to search
    :return: List of conv/linear layers
    """
    # Get the ordered operations list from the connected graph
    list_of_ordered_ops = get_ordered_ops(conn_graph.starting_ops)

    # Look for conv/linear layers
    ordered_convs = []
    for op in list_of_ordered_ops:
        if op.type in ConvType + LinearType:
            ordered_convs.append(op)
    return ordered_convs


def is_valid_bn_fold(conv_linear: NodeProto, model: ModelProto, fold_backward: bool) -> bool:
    """
    Determine if a given layer can successfully absorb a BatchNorm given the layer type and parameters

    :param conv_linear: The Conv/Linear layer to fold a BatchNorm into.
    :param model: The model to which the Conv/Linear layer belongs.
    :param fold_backward: True if BatchNorm comes after the Conv/Linear layer
    :return: True if a BatchNorm layer can be folded without causing output error.
    """
    valid = True
    if conv_linear.op_type in LinearType:
        # Check if this is actually a fully connected layer or a dynamic matmul
        w = retrieve_constant_input(conv_linear, model, WEIGHT_INDEX)[0]
        if w is None:
            valid = False
    if not fold_backward:
        # Cannot fold BN -> Conv with padding. AIMET does not support forward folding to grouped or DW Conv
        if conv_linear.op_type == 'Conv':
            valid &= all(item == 0 for item in get_node_attribute(conv_linear, "pads"))
            valid &= get_node_attribute(conv_linear, "group") == 1
        # AIMET does not support forward folding to ConvTranspose
        elif conv_linear.op_type == 'ConvTranspose':
            valid = False
    else:
        # AIMET does not support backward folding to grouped ConvTranspose
        if conv_linear.op_type == 'ConvTranspose':
            valid &= get_node_attribute(conv_linear, "group") in (1, get_input_output_channels(conv_linear, model)[0])
    return valid
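# Summary of the validity rules above (restating the checks, not adding new ones):
#   fold_backward=True : any Conv; ConvTranspose only if group == 1 or depthwise;
#                        Gemm/MatMul only if the weight is a constant initializer.
#   fold_backward=False: Conv only with zero padding and group == 1; never ConvTranspose;
#                        Gemm/MatMul only if the weight is a constant initializer.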
def fold_all_batch_norms_to_weight(model: ModelProto) -> Tuple[List, List]:
    """
    Fold all possible batch_norm layers in a model into the weight of the corresponding conv layers

    :param model: onnx Model to perform BN fold on
    :return: Two lists of (Conv/Linear, BNLayer) pairs: the backward-folded pairs followed by the forward-folded pairs
    """
    if isinstance(model, ONNXModel):
        model = model.model
    connected_graph = ConnectedGraph(model)
    model = connected_graph.model
    conv_bn_pairs, bn_conv_pairs = find_all_batch_norms_to_fold(connected_graph)
    conv_bns = []
    bn_convs = []
    for conv, bn in conv_bn_pairs:
        bn_layer = _fold_to_weight(model, conv, bn, True)
        conv_bns.append((conv, bn_layer))
        remove_node(bn, model.graph)

    for bn, conv in bn_conv_pairs:
        bn_layer = _fold_to_weight(model, conv, bn, False)
        bn_convs.append((conv, bn_layer))
        remove_node(bn, model.graph)

    _update_standalone_batchnorm_ops(model)

    return conv_bns, bn_convs
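# Usage sketch (hypothetical file names; assumes this module is importable as
# aimet_onnx.batch_norm_fold):
#
#     import onnx
#     from aimet_onnx.batch_norm_fold import fold_all_batch_norms_to_weight
#
#     model = onnx.load("resnet18.onnx")
#     conv_bns, bn_convs = fold_all_batch_norms_to_weight(model)
#     print(f"Backward-folded {len(conv_bns)} and forward-folded {len(bn_convs)} BN layers")
#     onnx.save(model, "resnet18_bn_folded.onnx")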
def _fold_to_weight(model: ModelProto, conv_linear: NodeProto, bn: NodeProto, fold_backward: bool):
    """
    Fold BatchNorm into the weight and bias of the given layer.

    :param model: onnx model to which the conv/bn pair belongs
    :param conv_linear: Conv or linear layer to fold BN into.
    :param bn: BatchNorm to fold.
    :param fold_backward: True if the BatchNorm comes after the Conv
    """
    # Must convert MatMul layers to Gemm to allow a bias
    if conv_linear.op_type == "MatMul":
        _matmul_to_gemm(conv_linear, model)

    weight = ParamUtils.get_param(model, conv_linear, WEIGHT_INDEX)
    bias = ParamUtils.get_param(model, conv_linear, BIAS_INDEX)
    groups = get_node_attribute(conv_linear, "group")

    # If the layer doesn't have a bias, create a zero bias initializer, add it to the model, then retrieve it
    if not bias:
        bias_data = np.zeros(get_input_output_channels(conv_linear, model)[1])
        bias_name = conv_linear.name + ".bias"
        bias = numpy_helper.from_array(bias_data.astype(np.float32), name=bias_name)
        model.graph.initializer.append(bias)
        conv_linear.input.append(bias_name)
        bias = ParamUtils.get_param(model, conv_linear, BIAS_INDEX)

    # Transpose weights to C, N, H, W from N, C, H, W since the axes are flipped for transposed conv.
    # However, depthwise conv layers are always N, 1, H, W whether transposed-conv or not, so no need to transpose
    if conv_linear.op_type == "ConvTranspose" and groups == 1:
        weight = transpose_tensor(weight, (1, 0, 2, 3))
    # Gemm layers may or may not need their weights transposed, depending on the value of the transB attribute
    elif conv_linear.op_type in LinearType and not get_node_attribute(conv_linear, "transB"):
        weight = transpose_tensor(weight, (1, 0))

    channels = weight.dims[0] if fold_backward else weight.dims[1]
    bn_param = get_bn_params(model, bn, channels)
    bn_layer = copy_bn_params_to_bn_layer(bn, bn_param)
    _call_mo_batch_norm_fold(weight, bias, bn_param, fold_backward=fold_backward)

    # Transpose weight back to its original configuration
    if conv_linear.op_type == "ConvTranspose" and groups == 1:
        weight = transpose_tensor(weight, (1, 0, 2, 3))
    elif conv_linear.op_type in LinearType and not get_node_attribute(conv_linear, "transB"):
        weight = transpose_tensor(weight, (1, 0))

    weight_param = ParamUtils.get_param(model, conv_linear, WEIGHT_INDEX)
    weight_param.raw_data = weight.raw_data
    return bn_layer


def _matmul_to_gemm(node: NodeProto, model: ModelProto):
    """
    Convert a MatMul node to Gemm and initialize the bias to zeros

    :param node: MatMul node to convert to Gemm
    :param model: model to which the node belongs
    """
    assert node.op_type == "MatMul"

    weight, transposed = retrieve_constant_input(node, model, WEIGHT_INDEX)
    if transposed:
        node.input[WEIGHT_INDEX] = weight.name
        model.graph.initializer.remove(weight)
        weight = transpose_tensor(weight, (1, 0))
        model.graph.initializer.append(weight)

    node.op_type = "Gemm"
    node.name = node.name.replace("MatMul", "Gemm")

    # Create a zero bias vector for the Gemm operation
    bias_name = node.name + ".bias"
    bias_data = np.zeros(weight.dims[1])
    bias = numpy_helper.from_array(bias_data.astype(np.float32), name=bias_name)
    model.graph.initializer.append(bias)
    node.input.append(bias_name)
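# Equivalence note (a sketch): for an input x of shape (B, K) and a constant weight W
# of shape (K, N), MatMul(x, W) == Gemm(x, W, b) with b = zeros(N), because Gemm
# defaults to alpha=1.0, beta=1.0, transA=0, transB=0 and computes alpha * x @ W + beta * b.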
def _call_mo_batch_norm_fold(weight: TensorProto,
                             bias: TensorProto,
                             bn_params: libpymo.BNParams,
                             fold_backward: bool):
    """
    Calls the C++ batch norm folding API.

    :param weight: Weight or scale tensor to fold BN into.
    :param bias: Bias tensor to fold BN into.
    :param bn_params: Batch norm parameters consumed by the folding API
    :param fold_backward: True if BatchNorm comes after the Conv/Linear layer
    """
    weight_tensor = libpymo.TensorParams()
    weight_tensor.data = numpy_helper.to_array(weight).reshape(-1)
    weight_tensor.shape = np.array(weight.dims)

    bias_tensor = libpymo.TensorParams()
    bias_tensor.data = numpy_helper.to_array(bias).reshape(-1)
    bias_tensor.shape = np.array(bias.dims)
    is_bias_valid = True

    with _expand_shape_to_4d(weight_tensor):
        _bias = libpymo.fold(bn_params, weight_tensor, bias_tensor, is_bias_valid, fold_backward)

    bias.raw_data = np.asarray(_bias, dtype=np.float32).tobytes()
    weight.raw_data = np.asarray(weight_tensor.data, dtype=np.float32).tobytes()


def get_bn_params(model: ModelProto, bn: NodeProto, channels: int) -> libpymo.BNParams:
    """
    Returns the populated libpymo.BNParams object for the given BatchNormalization layer,
    with parameters repeated if necessary.

    :param model: model to which the bn layer belongs
    :param bn: BatchNormalization layer to retrieve the parameters from
    :param channels: The effective number of channels the BatchNorm layer operates on (needed for Gemm layers)
    :return: libpymo.BNParams object for the input BatchNorm layer
    """
    bn_params = libpymo.BNParams()
    gamma = numpy_helper.to_array(ParamUtils.get_param(model, bn, WEIGHT_INDEX)).reshape(-1)
    # In the case of BatchNorm2d -> Flatten -> Gemm, the BN parameters must be resized to the Gemm input feature length
    resize = channels // len(gamma)
    bn_params.gamma = np.repeat(gamma, resize)
    bn_params.beta = np.repeat(numpy_helper.to_array(ParamUtils.get_param(model, bn, BIAS_INDEX)).reshape(-1), resize)
    bn_params.runningMean = np.repeat(
        numpy_helper.to_array(ParamUtils.get_param(model, bn, RUNNING_MEAN_INDEX)).reshape(-1), resize)
    runningVar = numpy_helper.to_array(ParamUtils.get_param(model, bn, RUNNING_VAR_INDEX))

    epsilon = get_node_attribute(bn, "epsilon")
    if epsilon is None:
        epsilon = 1e-5  # Default onnx epsilon value
    sigma = np.sqrt(runningVar + epsilon)
    bn_params.runningVar = np.repeat(sigma.reshape(-1), resize)

    return bn_params


def copy_bn_params_to_bn_layer(bn: NodeProto, bn_params: libpymo.BNParams) -> BNLayer:
    """
    Copies BN params to a BNLayer object which can be used later by high-bias absorption

    :param bn: BN layer
    :param bn_params: libpymo.BNParams object for the BatchNorm layer
    :return: BNLayer object
    """
    bn_layer = BNLayer()
    bn_layer.bn_layer = bn
    bn_layer.gamma = bn_params.gamma
    bn_layer.beta = bn_params.beta
    return bn_layer


@contextlib.contextmanager
def _expand_shape_to_4d(weight_tensor: libpymo.TensorParams):
    """ Expand the shape of the weight into 4d. """
    dims = len(weight_tensor.shape)

    if dims > 4:
        raise RuntimeError("Unable to handle weight tensors with more than 4 dimensions")

    if dims == 4:
        yield weight_tensor
    else:
        orig_shape = weight_tensor.shape
        _4d_shape = np.append(orig_shape, [1 for _ in range(4 - dims)]).astype(int)
        try:
            weight_tensor.shape = _4d_shape
            yield weight_tensor
        finally:
            weight_tensor.shape = orig_shape
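# Example (illustrative): a Gemm weight with shape (512, 256) enters libpymo.fold as
# (512, 256, 1, 1) inside the `with` block above, since the folding API expects a
# 4-d, conv-style weight layout; the original 2-d shape is restored on exit.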
def get_input_output_channels(node: NodeProto, model: ModelProto) -> Tuple[int, int]:
    """
    Find the input and output channels of a given layer.

    :param node: The node to find the input/output channels of
    :param model: The onnx model to which the layer belongs
    :return: Tuple of (num channels in, num channels out)
    """
    weight = ParamUtils.get_param(model, node, WEIGHT_INDEX)
    groups = get_node_attribute(node, "group")

    # If the group attribute does not exist on the node, the default is 1
    if not groups:
        groups = 1

    if node.op_type == "Conv":
        num_in_channels = weight.dims[1] * groups
        num_out_channels = weight.dims[0]
    elif node.op_type == "ConvTranspose":
        num_in_channels = weight.dims[0]
        num_out_channels = weight.dims[1] * groups
    elif node.op_type == "Gemm":
        transB = get_node_attribute(node, "transB")
        if transB == 1:
            num_out_channels = weight.dims[0]
            num_in_channels = weight.dims[1]
        else:
            num_out_channels = weight.dims[1]
            num_in_channels = weight.dims[0]
    else:
        num_out_channels = None
        num_in_channels = None

    return num_in_channels, num_out_channels


# pylint: disable=too-many-locals
def _update_standalone_batchnorm_ops(model: ModelProto):
    """
    Update the weight and bias of any standalone (unfolded) batchnorm ops in the model.

    :param model: onnx Model for which batchnorm parameters are to be updated.
    """
    for node in model.graph.node:
        if node.op_type in BatchNormType:

            # Get parameter names and initializers
            weight_name, bias_name, running_mean_name, running_var_name = node.input[1:]
            init_w, init_b, init_rm, init_rv = [ParamUtils.get_param(model, node, idx) for idx in range(1, 5)]
            attr = [item for item in node.attribute if item.name == "epsilon"]
            if not attr:
                attr = onnx.helper.make_attribute("epsilon", 1e-5)  # Default epsilon value
                node.attribute.append(attr)
            else:
                attr = attr[0]
            epsilon = attr.f
            tensor_w = numpy_helper.to_array(init_w)
            tensor_b = numpy_helper.to_array(init_b)
            tensor_rm = numpy_helper.to_array(init_rm)
            tensor_rv = numpy_helper.to_array(init_rv)

            # Fold the running statistics and epsilon into the weight and bias
            inv_sigma = np.reciprocal(np.sqrt(tensor_rv + epsilon))
            tensor_w = tensor_w * inv_sigma
            tensor_b = tensor_b - tensor_rm * tensor_w
            tensor_rm = np.zeros(tensor_w.shape, tensor_w.dtype)
            tensor_rv = np.ones(tensor_w.shape, tensor_w.dtype)
            attr.f = 0.

            init_w_ = numpy_helper.from_array(tensor_w, weight_name)
            init_b_ = numpy_helper.from_array(tensor_b, bias_name)
            init_rm_ = numpy_helper.from_array(tensor_rm, running_mean_name)
            init_rv_ = numpy_helper.from_array(tensor_rv, running_var_name)

            # Update the initializers in place
            init_w.CopyFrom(init_w_)
            init_b.CopyFrom(init_b_)
            init_rm.CopyFrom(init_rm_)
            init_rv.CopyFrom(init_rv_)
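# Worked example for _update_standalone_batchnorm_ops (hand-checked numbers, a sketch):
# with gamma=2, beta=1, mean=4, var=0.25, eps=0 a standalone BN computes
#   y = 2 * (x - 4) / 0.5 + 1 = 4x - 15.
# After the rewrite: inv_sigma = 1/sqrt(0.25) = 2, so gamma' = 4, beta' = 1 - 4*4 = -15,
# mean' = 0, var' = 1, eps' = 0, and the op still computes
#   y = 4 * (x - 0) / sqrt(1) + (-15) = 4x - 15.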