# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2019-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Optimization code to fold batch-norm layers """
from typing import List, Tuple, Union, Dict, Iterable, Set, Any
import numpy as np
import torch
import torch.nn
from torch.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d
from torch.nn.modules.conv import _ConvTransposeNd
import aimet_common.libpymo as libpymo
from aimet_common.batch_norm_fold import batch_norm_fold, expand_shape_to_4d
from aimet_common.bias_correction import ConvBnPatternHandler, CONV_OP_TYPES, LINEAR_OP_TYPES, BN_OP_TYPES
from aimet_common.graph_pattern_matcher import PatternType
from aimet_common.graph_searcher import GraphSearcher
from aimet_common.utils import AimetLogger
# pylint: disable=unused-import
from aimet_torch.defs import PassThroughOp
from aimet_torch import utils
from aimet_torch.meta.connectedgraph import ConnectedGraph
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.qc_quantize_op import QcQuantizeWrapper
from aimet_torch.tensor_quantizer import LearnedGridTensorQuantizer
# Logger dedicated to the batch-norm folding area.
_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.BatchNormFolding)
# Layer types that a BatchNorm layer can be folded into.
LayerType = Union[
torch.nn.Linear,
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.ConvTranspose2d,
]
# Tuple of the concrete classes listed in LayerType, used for isinstance() checks.
_supported_layers = LayerType.__args__
# BatchNorm layer types that can be folded.
BatchNormType = Union[BatchNorm1d, BatchNorm2d]
# Tuple of the concrete classes listed in BatchNormType, used for isinstance() checks.
_supported_batchnorms = BatchNormType.__args__
# Temporary flag selecting the underlying fold implementation: True uses the Python
# implementation, False uses the C++ (libpymo) one. This flag will be removed in a future release.
USE_PYTHON_IMPL = True
def _delete_bn_from_model(model: torch.nn.Module, bn_layer_list: Iterable[BatchNormType]):
    """
    Remove the given BatchNorm layers from the model.
    Delegates to utils.replace_modules_with_instances_of_new_type, passing torch.nn.Identity
    as the replacement type.
    :param model: Model that contains the BatchNorm layers
    :param bn_layer_list: BatchNorm layers to replace
    """
    utils.replace_modules_with_instances_of_new_type(model, bn_layer_list, torch.nn.Identity)
def _call_mo_batch_norm_fold(weight: torch.Tensor,
                             bias: torch.Tensor,
                             bn: BatchNormType,
                             fold_backward: bool):
    """
    Fold a BatchNorm layer into the given tensors through the C++ (libpymo) folding API.
    Both tensors are updated in place.
    :param weight: Weight or scale tensor to fold BN into.
    :param bias: Bias tensor to fold BN into.
    :param bn: Batch Norm layer
    :param fold_backward: True if BatchNorm comes after Conv/Linear layer
    """
    def _as_flat_numpy(tensor: torch.Tensor) -> np.ndarray:
        # libpymo is fed flat, CPU-resident numpy arrays
        return tensor.detach().cpu().numpy().reshape(-1)

    with torch.no_grad():
        mo_bn = libpymo.BNParams()
        mo_bn.gamma = _as_flat_numpy(bn.weight)
        mo_bn.beta = _as_flat_numpy(bn.bias)
        mo_bn.runningMean = _as_flat_numpy(bn.running_mean)
        # NOTE: the field named runningVar receives sqrt(running_var + eps), not the raw variance
        mo_bn.runningVar = _as_flat_numpy(torch.sqrt(bn.running_var + bn.eps))

        mo_weight = libpymo.TensorParams()
        mo_weight.data = _as_flat_numpy(weight)
        mo_weight.shape = np.array(weight.shape)

        mo_bias = libpymo.TensorParams()
        mo_bias.data = _as_flat_numpy(bias)
        mo_bias.shape = np.array(bias.shape)

        bias_is_valid = True
        shape_4d = expand_shape_to_4d(mo_weight.shape)

        # The fold is run on a 4D view of the weight; the original shape is restored afterwards
        try:
            saved_shape = mo_weight.shape
            mo_weight.shape = shape_4d
            folded_bias = libpymo.fold(mo_bn, mo_weight, mo_bias, bias_is_valid, fold_backward)
        finally:
            mo_weight.shape = saved_shape

        # Write the folded values back into the caller's tensors
        new_bias = torch.tensor(folded_bias, device=bias.device, dtype=bias.dtype)
        bias.copy_(new_bias.reshape_as(bias))

        new_weight = torch.tensor(mo_weight.data, device=weight.device, dtype=weight.dtype)
        weight.copy_(new_weight.reshape_as(weight))
def _call_py_batch_norm_fold(weight: torch.Tensor,
                             bias: torch.Tensor,
                             bn: Union[BatchNorm1d, BatchNorm2d],
                             fold_backward: bool):
    """
    BN fold without calling C++ APIs. Both tensors are updated in place.
    :param weight: conv/linear weight
    :param bias: conv/linear bias
    :param bn: Batch Norm layer
    :param fold_backward: True if BatchNorm comes after Conv/Linear layer
    """
    with torch.no_grad():
        gamma = bn.weight.detach().cpu().numpy()
        beta = bn.bias.detach().cpu().numpy()
        mu = bn.running_mean.detach().cpu().numpy()
        # batch_norm_fold is given sqrt(running_var + eps) rather than the raw variance
        sigma = torch.sqrt(bn.running_var + bn.eps).detach().cpu().numpy()

        _weight = weight.detach().cpu().numpy()
        _bias = bias.detach().cpu().numpy()

        # The fold is computed on a 4D view of the weight
        _4d_shape = expand_shape_to_4d(_weight.shape)
        _weight, _bias = batch_norm_fold(_weight.reshape(_4d_shape), _bias, gamma, beta, mu, sigma, fold_backward)

        # Convert the folded arrays to the destination device/dtype, then copy in place.
        # (Previously `.to(...)` was chained onto the result of `copy_()` and its return value
        # discarded, so it had no effect.)
        folded_bias = torch.from_numpy(_bias).reshape_as(bias).to(device=bias.device, dtype=bias.dtype)
        bias.copy_(folded_bias)

        folded_weight = torch.from_numpy(_weight).reshape_as(weight).to(device=weight.device, dtype=weight.dtype)
        weight.copy_(folded_weight)
class _BatchNormFoldingNotSupported(RuntimeError):
pass
def _fold_to_scale(conv_wrapper: QcQuantizeWrapper, bn_wrapper: QcQuantizeWrapper):
    """
    Fold BatchNorm into the scale and bias of the given layer.
    :param conv_wrapper: QcQuantizeWrapper that wraps conv or linear layer.
    :param bn_wrapper: QcQuantizeWrapper that wraps bn.
    :raises _BatchNormFoldingNotSupported: If the weight quantizer is not a LearnedGridTensorQuantizer,
        the conv output quantizer or bias quantizer is enabled, or the layer is a grouped ConvTransposeNd.
    :raises RuntimeError: If the weight quantizer has no encoding.
    """
    # pylint: disable=protected-access, too-many-locals, too-many-branches, too-many-statements
    conv = conv_wrapper._module_to_wrap
    bn = bn_wrapper._module_to_wrap

    weight_quantizer = conv_wrapper.param_quantizers["weight"]

    if not isinstance(weight_quantizer, LearnedGridTensorQuantizer):
        raise _BatchNormFoldingNotSupported(
            "BatchNorm folding to scale supports LearnedGridTensorQuantizer only; "
            f"got {type(weight_quantizer)}."
        )

    output_quantizer = conv_wrapper.output_quantizers[0]

    if output_quantizer.enabled:
        raise _BatchNormFoldingNotSupported(
            "BatchNorm should belong to the same supergroup with the layer to be folded to."
        )

    if "bias" in conv_wrapper.param_quantizers:
        bias_quantizer = conv_wrapper.param_quantizers["bias"]
        if bias_quantizer.enabled:
            raise _BatchNormFoldingNotSupported(
                "Can't fold BatchNorm to scale if bias quantizer is enabled."
            )

    encodings = weight_quantizer.encoding

    if encodings is None:
        raise RuntimeError("Weight quantizer encoding is None; cannot fold BatchNorm to scale.")

    # Normalize a single encoding to a one-element list so the loop below handles both cases
    if isinstance(encodings, libpymo.TfEncoding):
        encodings = [encodings]

    if isinstance(conv, _ConvTransposeNd) and conv.groups != 1:
        raise _BatchNormFoldingNotSupported(
            "BatchNorm folding to scale is not supported for grouped ConvTransposeNd."
        )

    # Add quantization noise to the BN params (bn weight & bn bias) before folding.
    # NOTE: Quantization of foldable batchnorms is automatically disabled when
    #       initializing quantsim. However, it is still safer to call _quantize_params here
    #       as we can't guarantee this is always the case.
    #       For example, the user can manually enable quantization of batchnorms, etc...
    #       (FYI: _quantize_params takes effect only when the parameter quantizers are enabled)
    with bn_wrapper._quantize_params():
        _fold_to_weight(conv, bn, fold_backward=True)

        gamma = bn.weight
        sigma = torch.sqrt(bn.running_var + bn.eps)

        # Rescale each weight encoding by the same factor (gamma / sigma) paired with it by zip()
        new_encodings = []
        for old_encoding, c in zip(encodings, gamma/sigma):
            new_encoding = libpymo.TfEncoding()
            new_encoding.delta = old_encoding.delta * abs(c)
            if c >= 0:
                new_encoding.max = old_encoding.max * c
                new_encoding.min = old_encoding.min * c
            else:
                # A negative factor flips the range, so min and max swap roles
                new_encoding.max = old_encoding.min * c
                new_encoding.min = old_encoding.max * c
            new_encoding.offset = old_encoding.offset
            new_encoding.bw = old_encoding.bw
            new_encodings.append(new_encoding)

        weight_quantizer.encoding = new_encodings

    # Copy batchnorm's output quantizers to conv output quantizers
    for conv_output_quantizer, bn_output_quantizer in\
            zip(conv_wrapper.output_quantizers, bn_wrapper.output_quantizers):
        conv_output_quantizer.enabled = bn_output_quantizer.enabled

        if bn_output_quantizer.encoding is not None:
            encoding = libpymo.TfEncoding()
            encoding.delta = bn_output_quantizer.encoding.delta
            encoding.max = bn_output_quantizer.encoding.max
            encoding.min = bn_output_quantizer.encoding.min
            encoding.offset = bn_output_quantizer.encoding.offset
            encoding.bw = bn_output_quantizer.encoding.bw
            conv_output_quantizer.encoding = encoding

        bn_output_quantizer.enabled = False

    # Register a disabled bias quantizer when the wrapper has none, configured like the weight quantizer
    if "bias" not in conv_wrapper.param_quantizers:
        bias_quantizer = LearnedGridTensorQuantizer(weight_quantizer.bitwidth,
                                                    weight_quantizer.round_mode,
                                                    weight_quantizer.quant_scheme,
                                                    weight_quantizer.use_symmetric_encodings,
                                                    enabled_by_default=False,
                                                    data_type=weight_quantizer.data_type)
        bias_quantizer._ch_axis = weight_quantizer._ch_axis
        conv_wrapper.param_quantizers["bias"] = bias_quantizer
def _fold_to_weight(conv_linear: LayerType, bn: BatchNormType, fold_backward: bool):
    """
    Fold BatchNorm into the weight and bias of the given layer.
    A zero bias parameter is created on the layer if it has none.
    :param conv_linear: Conv or linear layer to fold BN into.
    :param bn: BatchNorm to fold.
    :param fold_backward: True if BatchNorm comes after Conv/Linear layer
    """
    # Transpose weights to C, N, H, W from N, C, H, W since axis are flipped for transposed conv
    # However depthwise conv layers are always N, 1, H, W whether transposed-conv or not, so no need to transpose
    needs_transpose = isinstance(conv_linear, torch.nn.ConvTranspose2d) and conv_linear.groups == 1
    if needs_transpose:
        conv_linear.weight.data = conv_linear.weight.data.permute(1, 0, 2, 3)

    try:
        if conv_linear.bias is None:
            out_channels = conv_linear.out_features if isinstance(conv_linear, torch.nn.Linear)\
                else conv_linear.out_channels
            bias = torch.zeros(out_channels,
                               device=conv_linear.weight.device,
                               dtype=conv_linear.weight.dtype)
            conv_linear.bias = torch.nn.Parameter(bias)

        if USE_PYTHON_IMPL:
            _call_py_batch_norm_fold(conv_linear.weight, conv_linear.bias, bn, fold_backward=fold_backward)
        else:
            _call_mo_batch_norm_fold(conv_linear.weight, conv_linear.bias, bn, fold_backward=fold_backward)
    finally:
        # Transpose weight back to N, C, H, W for transposed Conv2D, for non-depthwise layers.
        # Done in `finally` so a failed fold does not leave the layer's weight in the permuted layout.
        if needs_transpose:
            conv_linear.weight.data = conv_linear.weight.data.permute(1, 0, 2, 3)
def fold_given_batch_norms(model, layer_pairs):
    """
    Fold a given set of batch_norm layers into conv layers
    :param model: Model
    :param layer_pairs: Pairs of conv and batch_norm layers to use for folding.
        A (bn, layer) pair is folded forward; a (layer, bn) pair is folded backward.
    :return: None
    """
    # pylint: disable=protected-access
    def _unwrap(module: torch.nn.Module) -> torch.nn.Module:
        # Look through a quantization wrapper to the layer it wraps
        if isinstance(module, QcQuantizeWrapper):
            return module._module_to_wrap
        return module

    def is_batchnorm(module: torch.nn.Module) -> bool:
        return isinstance(_unwrap(module), _supported_batchnorms)

    def is_conv_linear(module: torch.nn.Module) -> bool:
        return isinstance(_unwrap(module), _supported_layers)

    conv_bn_pairs = []
    bn_conv_pairs = []

    # The position of the BatchNorm within each pair decides the folding direction
    for first, second in layer_pairs:
        if is_batchnorm(first):
            assert is_conv_linear(second), \
                f"Expected a supported Conv/Linear layer after BatchNorm; got {type(second)}"
            bn_conv_pairs.append((first, second))
        else:
            assert is_conv_linear(first), \
                f"Expected a supported Conv/Linear or BatchNorm layer; got {type(first)}"
            assert is_batchnorm(second), \
                f"Expected a supported BatchNorm layer after Conv/Linear; got {type(second)}"
            conv_bn_pairs.append((first, second))

    _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)
def _fold_given_batch_norms(model,
                            conv_bn_pairs: Iterable[Tuple[torch.nn.Module, torch.nn.Module]],
                            bn_conv_pairs: Iterable[Tuple[torch.nn.Module, torch.nn.Module]]):
    """
    Fold a given set of batch_norm layers into conv layers
    :param model: Model
    :param conv_bn_pairs: List of (conv, bn) pairs to fold
    :param bn_conv_pairs: List of (bn, conv) pairs to fold
    :return: None
    :raises RuntimeError: If a (bn, conv) pair has a QcQuantizeWrapper as its conv
    """
    # pylint: disable=protected-access
    # bn_conv_pairs is traversed twice below; materialize it so a one-shot iterable is not exhausted
    bn_conv_pairs = list(bn_conv_pairs)

    for bn, conv in bn_conv_pairs:
        if isinstance(conv, QcQuantizeWrapper):
            raise RuntimeError(f"Forward folding to scale is not possible. Got {conv}")

    # BatchNorm modules that were folded successfully and are removed from the model afterwards
    bn_modules = []

    def _fold(conv, bn, fold_backward):
        is_wrapped = isinstance(conv, QcQuantizeWrapper) or isinstance(bn, QcQuantizeWrapper)
        try:
            if is_wrapped:
                assert isinstance(conv, QcQuantizeWrapper) and isinstance(bn, QcQuantizeWrapper)
                _fold_to_scale(conv, bn)
            else:
                _fold_to_weight(conv, bn, fold_backward=fold_backward)
        except _BatchNormFoldingNotSupported as e:
            bn_name = utils.get_layer_name(model, bn)
            conv_name = utils.get_layer_name(model, conv)
            _logger.warning(
                "Failed to fold %s to %s. [Reason] %s", bn_name, conv_name, str(e)
            )
        else:
            # Record each folded BN exactly once (it used to be appended twice on the wrapped path)
            bn_modules.append(bn._module_to_wrap if is_wrapped else bn)

    with utils.in_eval_mode(model), torch.no_grad():
        for conv, bn in conv_bn_pairs:
            _fold(conv, bn, fold_backward=True)

        for bn, conv in bn_conv_pairs:
            _fold(conv, bn, fold_backward=False)

        _delete_bn_from_model(model, bn_modules)
def find_all_batch_norms_to_fold(model, input_shapes, dummy_input: Union[torch.Tensor, Tuple] = None):
    """
    Find all possible batch norm layers that can be folded. And returns a list of pairs such that (bn, layer)
    means bn will be forward-folded into layer and (layer, bn) means bn will be backward-folded into layer
    :param model: Model to search
    :param input_shapes: Input shapes to use for the model (can be one or multiple inputs).
        Only used when dummy_input is None.
    :param dummy_input: A dummy input to the model. Can be a Tensor or a Tuple of Tensors
    :return: List of pairs of bn and layers to fold bn into
    """
    if dummy_input is None:
        # The device is only needed to create random inputs, so look it up only on this path
        device = utils.get_device(model)
        dummy_input = utils.create_rand_tensors_given_shapes(input_shapes, device)

    connected_graph = ConnectedGraph(model, dummy_input)
    conv_bn_pairs, bn_conv_pairs, _ = _find_all_batch_norms_to_fold(connected_graph)
    return conv_bn_pairs + bn_conv_pairs
def _find_all_batch_norms_to_fold(connected_graph: ConnectedGraph) -> Tuple[
        List[Tuple[LayerType, BatchNormType]], List[Tuple[BatchNormType, LayerType]], Set]:
    """
    Find all possible batch norm layers that can be folded. And returns a list of pairs such that (bn, layer)
    means bn will be forward-folded into layer and (layer, bn) means bn will be backward-folded into layer
    :param connected_graph: Connected graph associated with the model.
    :return: A list of (layer, bn) pairs and a list of (bn, layer) pairs,
             where `bn` can be folded into to `layer`,
             followed by the set of bn ops picked for folding.
    """
    # Thin alias: the three-element result of the helper is returned unchanged
    return _find_foldable_bn_pair_and_bn_picked_for_folding(connected_graph)
def _find_foldable_bn_pair_and_bn_picked_for_folding(connected_graph: ConnectedGraph) -> Tuple[
        List[Tuple[LayerType, BatchNormType]], List[Tuple[BatchNormType, LayerType]], Set]:
    """
    Collect every BatchNorm that can be folded into a neighbouring conv/linear layer.
    A (layer, bn) pair means bn is backward-folded into layer; a (bn, layer) pair means
    bn is forward-folded into layer.
    :param connected_graph: Connected graph associated with the model.
    :return: A list of (layer, bn) pairs, a list of (bn, layer) pairs,
             and the set of bn ops that were picked for folding.
    """
    bn_info_by_module = find_all_conv_bn_with_activation_in_graph(connected_graph)
    foldable_op_types = CONV_OP_TYPES + LINEAR_OP_TYPES
    candidate_modules = [op.get_module()
                         for op in connected_graph.ordered_ops
                         if op.type in foldable_op_types]

    # BN ops already claimed by a fold, so that no BN is folded twice
    picked_bn_ops = set()

    # First pass: backward folds, which take priority over forward folds
    backward_pairs = []
    for layer in candidate_modules:
        if layer not in bn_info_by_module or not _is_valid_bn_fold(layer, True):
            continue
        bn_op = bn_info_by_module[layer].output_bn
        if bn_op and bn_op not in picked_bn_ops:
            backward_pairs.append((layer, bn_op.get_module()))
            picked_bn_ops.add(bn_op)

    # Second pass: forward folds, restricted to BNs left unclaimed by the first pass
    forward_pairs = []
    for layer in candidate_modules:
        if layer not in bn_info_by_module or not _is_valid_bn_fold(layer, False):
            continue
        bn_op = bn_info_by_module[layer].input_bn
        if bn_op and bn_op not in picked_bn_ops:
            forward_pairs.append((bn_op.get_module(), layer))
            picked_bn_ops.add(bn_op)

    return backward_pairs, forward_pairs, picked_bn_ops
def find_standalone_batchnorm_ops(connected_graph: ConnectedGraph)->set:
    """
    Collect the batchnorm ops of the graph that were not picked for folding.
    :param connected_graph: Connected graph associated with the model.
    :return stand_alone_bn_ops: Set of batchnorm ops that can not be folded.
    """
    # Only the third element (the picked bn ops) of the search result is needed here
    picked_bn_ops = _find_foldable_bn_pair_and_bn_picked_for_folding(connected_graph)[2]

    all_bn_ops = set()
    for op in connected_graph.get_all_ops().values():
        if op.type in BN_OP_TYPES:
            all_bn_ops.add(op)

    return all_bn_ops - picked_bn_ops
def _is_valid_bn_fold(conv: LayerType, fold_backward: bool) -> bool:
"""
Determine if a given layer can successfully absorb a BatchNorm given the layer type and parameters
:param conv: The Conv/Linear layer to fold a BatchNorm into.
:param fold_backward: True if BatchNorm comes after Conv/Linear layer
:return: True if a BatchNorm layer can be folded without causing output error.
"""
valid = True
if not fold_backward:
# Cannot fold BN -> Conv with padding. AIMET does not support forward folding to grouped or DW Conv
if isinstance(conv, (torch.nn.Conv2d, torch.nn.Conv1d, torch.nn.Conv3d)):
valid &= all(item == 0 for item in conv.padding)
valid &= conv.groups == 1
# AIMET does not support forward folding to ConvTranspose
elif isinstance(conv, torch.nn.ConvTranspose2d):
valid = False
else:
# AIMET does not support backwards folding to grouped ConvTranspose
if isinstance(conv, torch.nn.ConvTranspose2d):
valid &= conv.groups in (1, conv.in_channels)
return valid
def fold_all_batch_norms_to_weight(
        model: torch.nn.Module,
        input_shapes: Union[Tuple, List[Tuple]],
        dummy_input: Union[torch.Tensor, Tuple] = None
) -> List[Tuple[LayerType, BatchNormType]]:
    """
    Fold every foldable batch_norm layer of a model into the weight of its neighbouring conv/linear
    layer, then convert the parameters of the batch_norm layers that were left standing alone.
    :param model: Model
    :param input_shapes: Input shapes for the model (can be one or multiple inputs)
    :param dummy_input: A dummy input to the model. Can be a Tensor or a Tuple of Tensors
    :return: A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
    """
    # Operate on the wrapped module when the model is wrapped in DataParallel
    if isinstance(model, torch.nn.DataParallel):
        return fold_all_batch_norms_to_weight(model.module, input_shapes, dummy_input)

    device = utils.get_device(model)
    inp_tensor_list = dummy_input if dummy_input is not None \
        else utils.create_rand_tensors_given_shapes(input_shapes, device)

    graph = ConnectedGraph(model, inp_tensor_list)
    conv_bn_pairs, bn_conv_pairs, bn_to_fold = _find_all_batch_norms_to_fold(graph)
    _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)

    # Convert the standalone BNs which are not folded
    bn_converted = convert_standalone_batchnorms(model, inp_tensor_list, bn_to_fold)
    _logger.debug("Total %d standalone BatchNorms' weights got converted", len(bn_converted))

    # Report every fold as (layer, bn), whichever direction it was folded in
    folded_pairs = list(conv_bn_pairs)
    folded_pairs.extend((conv, bn) for bn, conv in bn_conv_pairs)
    return folded_pairs
def convert_standalone_batchnorms(model: torch.nn.Module,
                                  dummy_input: Union[torch.Tensor, Tuple],
                                  folded_bn: set) -> List[Tuple[Any, Union[BatchNorm1d, BatchNorm2d]]]:
    """
    Convert the weights of all the standalone batchnorms of a model which didn't get folded.
    :param model: torch model for which batch norm folding is being performed
    :param dummy_input: dummy input for the model
    :param folded_bn: BNs which got folded; entries may be BatchNorm modules or
        connected-graph ops exposing `get_module()`
    :return: List of tuple(name, bn_module) whose weights got converted
    """
    # Callers in this file pass connected-graph ops rather than modules, so a direct
    # `module in folded_bn` test would never match. Resolve every entry to its module first.
    folded_modules = set()
    for entry in folded_bn:
        if not isinstance(entry, torch.nn.Module) and hasattr(entry, "get_module"):
            entry = entry.get_module()
        folded_modules.add(entry)

    module_list = utils.get_ordered_list_of_modules(model, dummy_input)
    bn_converted = []
    for name, module in module_list:
        if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)) and module not in folded_modules:
            convert_batchnorm_parameters(model, module)
            _logger.debug("%s weights got converted", name)
            bn_converted.append((name, module))
    return bn_converted
def convert_batchnorm_parameters(model: torch.nn.Module, bn: Union[torch.nn.BatchNorm1d, torch.nn.BatchNorm2d]):
    """
    Rewrite a batchnorm's parameters in place so that it takes the form y = weights*input + bias
    :param model: torch model for which batch norm folding is being performed
    :param bn: BatchNorm module whose weights needs to be converted
    """
    with utils.in_eval_mode(model), torch.no_grad():
        # Fold the running statistics into a single scale and shift:
        #   scale = gamma / sqrt(running_var + eps)
        #   shift = beta - running_mean * scale
        scale = bn.weight * torch.rsqrt(bn.running_var + bn.eps)
        shift = bn.bias - bn.running_mean * scale

        # Update the values
        bn.eps = 0
        bn.track_running_stats = False
        bn.weight.copy_(scale.detach())
        bn.bias.copy_(shift.detach())

        # Neutral statistics (mean 0, variance 1), since the real ones now live in weight and bias
        bn.running_mean = torch.zeros(bn.running_mean.shape,
                                      device=bn.running_mean.device,
                                      dtype=bn.running_mean.dtype)
        bn.running_var = torch.ones(bn.running_var.shape,
                                    device=bn.running_var.device,
                                    dtype=bn.running_var.dtype)
# Alias: `fold_all_batch_norms` refers to the fold-to-weight implementation.
fold_all_batch_norms = fold_all_batch_norms_to_weight
def fold_all_batch_norms_to_scale(
        sim: QuantizationSimModel,
) -> List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]:
    """
    Fold all batch_norm layers in a model into the quantization scale parameter
    of the corresponding conv layers
    :param sim: QuantizationSimModel
    :return: A list of pairs of layers [(Conv/Linear, BN layer that got folded)],
        given as the quantization wrappers of those layers
    """
    # pylint: disable=protected-access
    assert sim.model is not None
    assert sim.connected_graph is not None

    model = sim.model
    connected_graph = sim.connected_graph

    # Map each wrapped module to its quantization wrapper, so the module pairs found in the
    # connected graph can be translated into wrapper pairs
    quant_wrappers = {
        quant_wrapper._module_to_wrap: quant_wrapper
        for _, quant_wrapper in sim.quant_wrappers()
    }
    conv_bn_pairs, bn_conv_pairs, _ = _find_all_batch_norms_to_fold(connected_graph)
    conv_bn_pairs = [
        (quant_wrappers[conv], quant_wrappers[bn]) for conv, bn in conv_bn_pairs
    ]
    bn_conv_pairs = [
        (quant_wrappers[bn], quant_wrappers[conv]) for bn, conv in bn_conv_pairs
    ]

    _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)

    # Report every pair as (layer, bn), whichever direction it was listed in
    return conv_bn_pairs + [(conv, bn) for bn, conv in bn_conv_pairs]
def find_all_conv_bn_with_activation(model: torch.nn.Module, input_shape: Tuple) -> Dict:
    """
    Build a connected graph of the model from random inputs and search it for the
    preceding and next bn layers of each conv/linear layer
    :param model: PyTorch model
    :param input_shape: shape of input to the model
    :return: dictionary of conv/linear layers with associated bn op / activation info
    """
    # Random inputs are created on the model's own device
    random_inputs = utils.create_rand_tensors_given_shapes(input_shape, utils.get_device(model))
    graph = ConnectedGraph(model, random_inputs)
    return find_all_conv_bn_with_activation_in_graph(graph)
def find_all_conv_bn_with_activation_in_graph(connected_graph: ConnectedGraph) -> Dict:
    """
    Search the connected graph for the preceding and next bn layers of each conv/linear layer
    :param connected_graph: ConnectedGraph object.
    :return: dictionary of conv/linear layers with associated bn op / activation info
    """
    # A single handler instance receives every match and accumulates the result
    handler = ConvBnPatternHandler()

    # Conv op types first, then the linear op type; each is matched on both sides of a BatchNormalization
    foldable_op_types = ('Conv1d', 'Conv', 'ConvTranspose', 'Gemm')
    patterns = []
    for op_type in foldable_op_types:
        patterns.extend((
            PatternType(pattern=['BatchNormalization', op_type], action=handler),
            PatternType(pattern=[op_type, 'BatchNormalization'], action=handler),
        ))

    # 3D conv / batchnorm use their own op type names
    patterns.extend((
        PatternType(pattern=['Conv3d', 'BatchNorm3d'], action=handler),
        PatternType(pattern=['BatchNorm3d', 'Conv3d'], action=handler),
    ))

    # Run the search; the handler is invoked as the action for each matched pattern
    searcher = GraphSearcher(connected_graph, patterns)
    searcher.find_all_patterns_in_graph_apply_actions()

    return handler.get_conv_linear_bn_info_dict()