# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2019, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" holds common code for bias correction """
from aimet_common.defs import ActivationType
from aimet_common.utils import AimetLogger
from aimet_common.connected_graph.operation import Op
logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
CONV_OP_TYPES = ['Conv1d', 'Conv2D', 'DepthwiseConv2dNative', 'Conv', 'ConvTranspose', 'Conv3d']
LINEAR_OP_TYPES = ['Dense', 'Gemm', 'MatMul']
BN_OP_TYPES = ['FusedBatchNormV3', 'FusedBatchNorm', 'BatchNormalization', 'BatchNorm3d']
[docs]class ConvBnInfoType:
"""
Type for hoding convs with bn info and activation types
Activation types supported are Relu and Relu6
"""
def __init__(self,
input_bn=None,
output_bn=None,
in_activation_type: ActivationType = ActivationType.no_activation,
out_activation_type: ActivationType = ActivationType.no_activation):
"""
:param input_bn: Reference to Input BatchNorm to layer
:param output_bn: Reference to Output BatchNorm to layer
:param in_activation_type: Type of Activation
:param out_activation_type: Type of Activation
"""
self.input_bn = input_bn
self.output_bn = output_bn
self.in_activation_type = in_activation_type
self.out_activation_type = out_activation_type
class ConvBnPatternHandler:
"""
common handler for matched patterns for bias correction and batchnorm fold.
"""
def __init__(self):
self.conv_linears_with_bn_dict = {}
def get_conv_linear_bn_info_dict(self):
"""
returns the dictionary created
:return: dictionary of convs/linears with bn and activation info
"""
return self.conv_linears_with_bn_dict
def __call__(self, *args, **kwargs):
"""
custom pattern match handler that keeps a dictionary of convs/linears with bn and activation info.
"""
_, op_subset = args
bn_activation_info = ConvBnInfoType()
activation_type = ActivationType.no_activation
conv_op = None
bn_op = None
for op in op_subset:
if op.type in CONV_OP_TYPES + LINEAR_OP_TYPES:
conv_op = op
op_key = get_op_dict_key(conv_op)
if op_key in self.conv_linears_with_bn_dict.keys():
bn_activation_info = self.conv_linears_with_bn_dict[op_key]
elif op.type in BN_OP_TYPES:
bn_op = op
elif op.type in ['Relu6', 'Clip']:
activation_type = ActivationType.relu6
elif op.type in ['Relu']:
activation_type = ActivationType.relu
if len(op_subset) >= 2:
if op_subset[0].type in BN_OP_TYPES:
bn_activation_info.input_bn = bn_op
bn_activation_info.in_activation_type = activation_type
# we do not match linear layers with preceding bn for bias correction
elif op_subset[0].type in CONV_OP_TYPES + LINEAR_OP_TYPES:
bn_activation_info.output_bn = bn_op
bn_activation_info.out_activation_type = activation_type
# in tf linear layer has two ops together [flatten/reshape -- dense] , check for len 3
elif len(op_subset) >= 3 and op_subset[1].type in ['Dense']:
bn_activation_info.output_bn = bn_op
bn_activation_info.out_activation_type = activation_type
op_key = get_op_dict_key(conv_op)
self.conv_linears_with_bn_dict[op_key] = bn_activation_info
def get_op_dict_key(op: Op):
"""
Returns the object to be used as a key in the conv/linear BN dict.
For torch and tensorflow models, returns op.get_module(). For onnx models, returns the original op.
:param op: connected graph layer to be used as a dictionary key
:return: object (op or op.get_module()) to be used as a key in the conv/linear BN dict
"""
module = op.get_module()
# ONNX NodeProto objects are not hashable, return the original Op object instead
if module.__hash__ is None:
return op
return module