# /usr/bin/env python3.5
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2019, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Optimization code to fold batch-norm layers """
import contextlib
import math
from typing import List, Tuple, Union, Dict, Iterable, Set, Any
import numpy as np
import torch
import torch.nn
from torch.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d
from torch.nn.modules.conv import _ConvTransposeNd
import aimet_common.libpymo as libpymo
from aimet_common.bias_correction import ConvBnPatternHandler, CONV_OP_TYPES, LINEAR_OP_TYPES, BN_OP_TYPES
from aimet_common.graph_pattern_matcher import PatternType
from aimet_common.graph_searcher import GraphSearcher
from aimet_common.utils import AimetLogger
# pylint: disable=unused-import
from aimet_torch.defs import PassThroughOp
from aimet_torch import utils
from aimet_torch.meta.connectedgraph import ConnectedGraph
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.qc_quantize_op import QcQuantizeWrapper
from aimet_torch.tensor_quantizer import LearnedGridTensorQuantizer
# Module-level logger for the batch-norm folding area
_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.BatchNormFolding)
# Layer types that a BatchNorm can be folded into
LayerType = Union[
torch.nn.Linear,
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.ConvTranspose2d,
]
# Members of LayerType as a tuple, used as the second argument of isinstance()
_supported_layers = LayerType.__args__
# BatchNorm types that can be folded
BatchNormType = Union[BatchNorm1d, BatchNorm2d]
# Members of BatchNormType as a tuple, used as the second argument of isinstance()
_supported_batchnorms = BatchNormType.__args__
def _delete_bn_from_model(model: torch.nn.Module, bn_layer_list: Iterable[BatchNormType]):
    """
    Replace the given BatchNorm layers of the model.
    Delegates to utils.replace_modules_with_instances_of_new_type with torch.nn.Identity as the new type.
    :param model: Model that contains the BatchNorm layers
    :param bn_layer_list: BatchNorm modules to be replaced
    """
    replacement_type = torch.nn.Identity
    utils.replace_modules_with_instances_of_new_type(model, bn_layer_list, replacement_type)
@contextlib.contextmanager
def _expand_shape_to_4d(weight_tensor: libpymo.TensorParams):
""" Expand the shape of the weight into 4d. """
dims = len(weight_tensor.shape)
if dims > 5:
raise RuntimeError
if dims == 4:
yield weight_tensor
else:
orig_shape = weight_tensor.shape
if dims < 4:
# If we have less dimensions, we add 1s to make 4 dimensions
_4d_shape = np.append(orig_shape, [1 for _ in range(4-dims)]).astype(int)
else:
# If we have more dimensions, we concatenate all the dimensions beyond 3 into one dimension
_4d_shape = np.array(orig_shape[:3] + [math.prod(orig_shape[3:])])
try:
weight_tensor.shape = _4d_shape
yield weight_tensor
finally:
weight_tensor.shape = orig_shape
def _call_mo_batch_norm_fold(weight: torch.Tensor,
                             bias: torch.Tensor,
                             bn: BatchNormType,
                             fold_backward: bool):
    """
    Calls C++ batch norm folding API and writes the folded values back into ``weight`` and ``bias`` in place.

    :param weight: Weight or scale tensor to fold BN into.
    :param bias: Bias tensor to fold BN into.
    :param bn: Batch Norm layer
    :param fold_backward: True if BatchNorm comes after Conv/Linear layer
    """
    def _flat_numpy(tensor: torch.Tensor):
        """ Detach the tensor, move it to CPU and return it as a 1-d numpy array """
        return tensor.detach().cpu().numpy().reshape(-1)

    with torch.no_grad():
        bn_params = libpymo.BNParams()
        bn_params.gamma = _flat_numpy(bn.weight)
        bn_params.beta = _flat_numpy(bn.bias)
        bn_params.runningMean = _flat_numpy(bn.running_mean)
        # The field named runningVar receives sqrt(running_var + eps), not the raw variance
        bn_params.runningVar = _flat_numpy(torch.sqrt(bn.running_var + bn.eps))

        weight_tensor = libpymo.TensorParams()
        weight_tensor.data = _flat_numpy(weight)
        weight_tensor.shape = np.array(weight.shape)

        bias_tensor = libpymo.TensorParams()
        bias_tensor.data = _flat_numpy(bias)
        bias_tensor.shape = np.array(bias.shape)

        is_bias_valid = True

        # The weight shape is presented as 4d only for the duration of the fold call
        with _expand_shape_to_4d(weight_tensor):
            folded_bias = libpymo.fold(bn_params, weight_tensor, bias_tensor, is_bias_valid, fold_backward)

        new_bias = torch.tensor(folded_bias, device=bias.device, dtype=bias.dtype)
        bias.copy_(new_bias.reshape_as(bias))

        new_weight = torch.tensor(weight_tensor.data, device=weight.device, dtype=weight.dtype)
        weight.copy_(new_weight.reshape_as(weight))
class _BatchNormFoldingNotSupported(RuntimeError):
pass
def _fold_to_scale(conv_wrapper: QcQuantizeWrapper, bn_wrapper: QcQuantizeWrapper):
    """
    Fold BatchNorm into the scale and bias of the given layer.

    The BatchNorm is folded into the wrapped layer's weight and bias, the weight quantizer's encodings
    are rescaled by gamma / sqrt(running_var + eps), and the state of the BatchNorm's output quantizers
    is copied onto the layer's output quantizers before the BatchNorm's output quantizers are disabled.

    :param conv_wrapper: QcQuantizeWrapper that wraps conv or linear layer.
    :param bn_wrapper: QcQuantizeWrapper that wraps bn.
    :raises _BatchNormFoldingNotSupported: If the weight quantizer is not a LearnedGridTensorQuantizer,
        the layer's output or bias quantizer is enabled, or the layer is a grouped ConvTransposeNd.
    :raises RuntimeError: If the weight quantizer has no encoding.
    """
    # pylint: disable=protected-access, too-many-locals, too-many-branches, bad-whitespace, too-many-statements
    conv = conv_wrapper._module_to_wrap
    bn = bn_wrapper._module_to_wrap

    # --- Preconditions (checked in this order; the first failing one decides the error) ---
    weight_quantizer = conv_wrapper.param_quantizers["weight"]
    if not isinstance(weight_quantizer, LearnedGridTensorQuantizer):
        raise _BatchNormFoldingNotSupported(
            "BatchNorm folding to scale supports LearnedGridTensorQuantizer only; "
            f"got {type(weight_quantizer)}."
        )

    if conv_wrapper.output_quantizers[0].enabled:
        raise _BatchNormFoldingNotSupported(
            "BatchNorm should belong to the same supergroup with the layer to be folded to."
        )

    if "bias" in conv_wrapper.param_quantizers and conv_wrapper.param_quantizers["bias"].enabled:
        raise _BatchNormFoldingNotSupported(
            "Can't fold BatchNorm to scale if bias quantizer is enabled."
        )

    encodings = weight_quantizer.encoding
    if encodings is None:
        raise RuntimeError

    # A single encoding is treated as a one-element list
    if isinstance(encodings, libpymo.TfEncoding):
        encodings = [encodings]

    if isinstance(conv, _ConvTransposeNd) and conv.groups != 1:
        raise _BatchNormFoldingNotSupported(
            "BatchNorm folding to scale is not supported for grouped ConvTransposeNd."
        )

    # Add quantization noise to the BN params (bn weight & bn bias) before folding.
    # NOTE: Quantization of foldable batchnorms is automatically disabled when
    #       initializing quantsim. However, it is still safer to call _quantize_params here
    #       as we can't guarantee this is always the case.
    #       For example, the user can manually enable quantization of batchnorms, etc...
    #       (FYI: _quantize_params takes effect only when the parameter quantizers are enabled)
    with bn_wrapper._quantize_params():
        _fold_to_weight(conv, bn, fold_backward=True)

        scale_factors = bn.weight / torch.sqrt(bn.running_var + bn.eps)

        rescaled_encodings = []
        for old_encoding, c in zip(encodings, scale_factors):
            # A negative factor flips the range, so min and max trade places
            if c >= 0:
                lower, upper = old_encoding.min, old_encoding.max
            else:
                lower, upper = old_encoding.max, old_encoding.min

            rescaled = libpymo.TfEncoding()
            rescaled.delta = old_encoding.delta * abs(c)
            rescaled.max = upper * c
            rescaled.min = lower * c
            rescaled.offset = old_encoding.offset
            rescaled.bw = old_encoding.bw
            rescaled_encodings.append(rescaled)

        weight_quantizer.encoding = rescaled_encodings

    # Copy batchnorm's output quantizers to conv output quantizers
    quantizer_pairs = zip(conv_wrapper.output_quantizers, bn_wrapper.output_quantizers)
    for conv_output_quantizer, bn_output_quantizer in quantizer_pairs:
        conv_output_quantizer.enabled = bn_output_quantizer.enabled

        bn_encoding = bn_output_quantizer.encoding
        if bn_encoding is not None:
            copied = libpymo.TfEncoding()
            for field in ("delta", "max", "min", "offset", "bw"):
                setattr(copied, field, getattr(bn_encoding, field))
            conv_output_quantizer.encoding = copied

        bn_output_quantizer.enabled = False

    # The layer now has a bias parameter; give it a (disabled) bias quantizer if it has none
    if "bias" not in conv_wrapper.param_quantizers:
        new_bias_quantizer = LearnedGridTensorQuantizer(weight_quantizer.bitwidth,
                                                        weight_quantizer.round_mode,
                                                        weight_quantizer.quant_scheme,
                                                        weight_quantizer.use_symmetric_encodings,
                                                        enabled_by_default=False,
                                                        data_type=weight_quantizer.data_type)
        new_bias_quantizer._ch_axis = weight_quantizer._ch_axis
        conv_wrapper.param_quantizers["bias"] = new_bias_quantizer
def _fold_to_weight(conv_linear: LayerType, bn: BatchNormType, fold_backward: bool):
    """
    Fold BatchNorm into the weight and bias of the given layer.
    A zero bias parameter is created on the layer first if it has none.

    :param conv_linear: Conv or linear layer to fold BN into.
    :param bn: BatchNorm to fold.
    :param fold_backward: True if BatchNorm comes after Conv/Linear layer
    """
    # Transpose weights to C, N, H, W from N, C, H, W since axis are flipped for transposed conv
    # However depthwise conv layers are always N, 1, H, W whether transposed-conv or not, so no need to transpose
    swap_axes = isinstance(conv_linear, torch.nn.ConvTranspose2d) and conv_linear.groups == 1

    if swap_axes:
        conv_linear.weight.data = conv_linear.weight.data.permute(1, 0, 2, 3)

    if conv_linear.bias is None:
        if isinstance(conv_linear, torch.nn.Linear):
            out_channels = conv_linear.out_features
        else:
            out_channels = conv_linear.out_channels
        zero_bias = torch.zeros(out_channels,
                                device=conv_linear.weight.device,
                                dtype=conv_linear.weight.dtype)
        conv_linear.bias = torch.nn.Parameter(zero_bias)

    _call_mo_batch_norm_fold(conv_linear.weight, conv_linear.bias, bn, fold_backward=fold_backward)

    # Transpose weight back to N, C, H, W for transposed Conv2D, for non-depthwise layers
    if swap_axes:
        conv_linear.weight.data = conv_linear.weight.data.permute(1, 0, 2, 3)
def fold_given_batch_norms(model, layer_pairs):
    """
    Fold a given set of batch_norm layers into conv layers

    Each pair may be given either as (bn, layer), meaning the bn is forward-folded into the layer,
    or as (layer, bn), meaning the bn is backward-folded into the layer. Layers wrapped in a
    QcQuantizeWrapper are classified by the module they wrap.

    :param model: Model
    :param layer_pairs: Pairs of conv and batch_norm layers to use for folding
    :return: None
    """
    # pylint: disable=protected-access
    conv_bn_pairs = []
    bn_conv_pairs = []

    def is_batchnorm(module: torch.nn.Module) -> bool:
        if isinstance(module, QcQuantizeWrapper):
            module = module._module_to_wrap
        return isinstance(module, _supported_batchnorms)

    def is_conv_linear(module: torch.nn.Module) -> bool:
        if isinstance(module, QcQuantizeWrapper):
            module = module._module_to_wrap
        return isinstance(module, _supported_layers)

    for x, y in layer_pairs:
        if is_batchnorm(x):
            # (bn, layer): forward fold
            assert is_conv_linear(y)
            bn = x
            conv = y
            bn_conv_pairs.append((bn, conv))
        else:
            # (layer, bn): backward fold
            assert is_conv_linear(x)
            assert is_batchnorm(y)
            conv = x
            bn = y
            conv_bn_pairs.append((conv, bn))

    _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)
def _fold_given_batch_norms(model,
                            conv_bn_pairs: Iterable[Tuple[torch.nn.Module, torch.nn.Module]],
                            bn_conv_pairs: Iterable[Tuple[torch.nn.Module, torch.nn.Module]]):
    """
    Fold a given set of batch_norm layers into conv layers

    Pairs made of QcQuantizeWrapper objects are folded to scale; plain modules are folded to weight.
    A pair whose fold raises _BatchNormFoldingNotSupported is logged as a warning and skipped.
    The BatchNorms that were folded successfully are then removed from the model.

    :param model: Model
    :param conv_bn_pairs: List of (conv, bn) pairs to fold
    :param bn_conv_pairs: List of (bn, conv) pairs to fold
    :return: None
    :raises RuntimeError: If a (bn, conv) pair contains a conv wrapped in QcQuantizeWrapper.
    """
    # pylint: disable=protected-access
    # bn_conv_pairs is traversed twice (validation, then folding), so a one-shot iterable must be materialized
    bn_conv_pairs = list(bn_conv_pairs)

    for bn, conv in bn_conv_pairs:
        if isinstance(conv, QcQuantizeWrapper):
            raise RuntimeError(f"Forward folding to scale is not possible. Got {conv}")

    bn_modules = []

    def _fold(conv, bn, fold_backward):
        is_wrapped = isinstance(conv, QcQuantizeWrapper) or isinstance(bn, QcQuantizeWrapper)
        try:
            if is_wrapped:
                assert isinstance(conv, QcQuantizeWrapper) and isinstance(bn, QcQuantizeWrapper)
                _fold_to_scale(conv, bn)
            else:
                _fold_to_weight(conv, bn, fold_backward=fold_backward)
        except _BatchNormFoldingNotSupported as e:
            bn_name = utils.get_layer_name(model, bn)
            conv_name = utils.get_layer_name(model, conv)
            _logger.warning(
                "Failed to fold %s to %s. [Reason] %s", bn_name, conv_name, str(e)
            )
        else:
            # Record each successfully folded BatchNorm exactly once
            bn_modules.append(bn._module_to_wrap if is_wrapped else bn)

    with utils.in_eval_mode(model), torch.no_grad():
        for conv, bn in conv_bn_pairs:
            _fold(conv, bn, fold_backward=True)

        for bn, conv in bn_conv_pairs:
            _fold(conv, bn, fold_backward=False)

        _delete_bn_from_model(model, bn_modules)
def find_all_batch_norms_to_fold(model, input_shapes, dummy_input: Union[torch.Tensor, Tuple] = None):
    """
    Find all possible batch norm layers that can be folded. And returns a list of pairs such that (bn, layer)
    means bn will be forward-folded into layer and (layer, bn) means bn will be backward-folded into layer

    :param model: Model to search
    :param input_shapes: Input shapes to use for the model (can be one or multiple inputs).
        Only used when dummy_input is None.
    :param dummy_input: A dummy input to the model. Can be a Tensor or a Tuple of Tensors
    :return: List of pairs of bn and layers to fold bn into
    """
    if dummy_input is None:
        # The model's device is only needed to create random inputs from the given shapes
        device = utils.get_device(model)
        graph_input = utils.create_rand_tensors_given_shapes(input_shapes, device)
    else:
        graph_input = dummy_input

    connected_graph = ConnectedGraph(model, graph_input)
    conv_bn_pairs, bn_conv_pairs, _ = _find_all_batch_norms_to_fold(connected_graph)
    return conv_bn_pairs + bn_conv_pairs
def _find_all_batch_norms_to_fold(connected_graph: ConnectedGraph) -> Tuple[
        List[Tuple[LayerType, BatchNormType]], List[Tuple[BatchNormType, LayerType]], Set]:
    """
    Find all possible batch norm layers that can be folded. And returns a list of pairs such that (bn, layer)
    means bn will be forward-folded into layer and (layer, bn) means bn will be backward-folded into layer

    :param connected_graph: Connected graph associated with the model.
    :return: A list of (layer, bn) pairs and a list of (bn, layer) pairs,
             where `bn` can be folded into to `layer`.
             A set of bn ops which were picked for folding.
    """
    conv_bn_pairs, bn_conv_pairs, bn_to_fold = _find_foldable_bn_pair_and_bn_picked_for_folding(connected_graph)
    return conv_bn_pairs, bn_conv_pairs, bn_to_fold
def _find_foldable_bn_pair_and_bn_picked_for_folding(connected_graph: ConnectedGraph) -> Tuple[
        List[Tuple[LayerType, BatchNormType]], List[Tuple[BatchNormType, LayerType]], Set]:
    """
    Search the graph for batch norms that can be folded into a neighbouring conv/linear layer.
    A (layer, bn) pair means bn will be backward-folded into layer; a (bn, layer) pair means bn will be
    forward-folded into layer. Each bn op is picked at most once.

    :param connected_graph: Connected graph associated with the model.
    :return: A list of (layer, bn) pairs and a list of (bn, layer) pairs,
             where `bn` can be folded into to `layer`.
             A set of bn ops which can be folded in to immediate convs.
    """
    bn_info_by_module = find_all_conv_bn_with_activation_in_graph(connected_graph)

    foldable_op_types = CONV_OP_TYPES + LINEAR_OP_TYPES
    ordered_conv_fc_modules = []
    for op in connected_graph.ordered_ops:
        if op.type in foldable_op_types:
            ordered_conv_fc_modules.append(op.get_module())

    # BN ops already claimed by a fold; shared by both passes below
    bn_picked_for_folding = set()

    # Backward fold is given priority over Forward fold, so it gets a complete pass of its own first
    conv_bn_pairs = []
    for module in ordered_conv_fc_modules:
        if module not in bn_info_by_module or not _is_valid_bn_fold(module, True):
            continue
        output_bn = bn_info_by_module[module].output_bn
        if output_bn and output_bn not in bn_picked_for_folding:
            conv_bn_pairs.append((module, output_bn.get_module()))
            bn_picked_for_folding.add(output_bn)

    bn_conv_pairs = []
    for module in ordered_conv_fc_modules:
        if module not in bn_info_by_module or not _is_valid_bn_fold(module, False):
            continue
        input_bn = bn_info_by_module[module].input_bn
        if input_bn and input_bn not in bn_picked_for_folding:
            bn_conv_pairs.append((input_bn.get_module(), module))
            bn_picked_for_folding.add(input_bn)

    return conv_bn_pairs, bn_conv_pairs, bn_picked_for_folding
def find_standalone_batchnorm_ops(connected_graph: ConnectedGraph)->set:
    """
    Collect the batchnorm ops of the graph that were not picked for folding.

    :param connected_graph: Connected graph associated with the model.
    :return stand_alone_bn_ops: Set of batchnorm ops can not be folded.
    """
    picked_bn_ops = _find_foldable_bn_pair_and_bn_picked_for_folding(connected_graph)[2]

    all_bn_ops = set()
    for op in connected_graph.get_all_ops().values():
        if op.type in BN_OP_TYPES:
            all_bn_ops.add(op)

    return all_bn_ops - picked_bn_ops
def _is_valid_bn_fold(conv: LayerType, fold_backward: bool) -> bool:
"""
Determine if a given layer can successfully absorb a BatchNorm given the layer type and parameters
:param conv: The Conv/Linear layer to fold a BatchNorm into.
:param fold_backward: True if BatchNorm comes after Conv/Linear layer
:return: True if a BatchNorm layer can be folded without causing output error.
"""
valid = True
if not fold_backward:
# Cannot fold BN -> Conv with padding. AIMET does not support forward folding to grouped or DW Conv
if isinstance(conv, (torch.nn.Conv2d, torch.nn.Conv1d, torch.nn.Conv3d)):
valid &= all(item == 0 for item in conv.padding)
valid &= conv.groups == 1
# AIMET does not support forward folding to ConvTranspose
elif isinstance(conv, torch.nn.ConvTranspose2d):
valid = False
else:
# AIMET does not support backwards folding to grouped ConvTranspose
if isinstance(conv, torch.nn.ConvTranspose2d):
valid &= conv.groups in (1, conv.in_channels)
return valid
def fold_all_batch_norms_to_weight(
        model: torch.nn.Module,
        input_shapes: Union[Tuple, List[Tuple]],
        dummy_input: Union[torch.Tensor, Tuple] = None
) -> List[Tuple[LayerType, BatchNormType]]:
    """
    Fold all batch_norm layers in a model into the weight of the corresponding conv layers.
    BatchNorms that are not folded have their parameters converted afterwards.

    :param model: Model
    :param input_shapes: Input shapes for the model (can be one or multiple inputs)
    :param dummy_input: A dummy input to the model. Can be a Tensor or a Tuple of Tensors
    :return: A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
    """
    # A DataParallel container is unwrapped and the wrapped module is processed instead
    if isinstance(model, torch.nn.DataParallel):
        return fold_all_batch_norms_to_weight(model.module, input_shapes, dummy_input)

    device = utils.get_device(model)
    if dummy_input is not None:
        inp_tensor_list = dummy_input
    else:
        inp_tensor_list = utils.create_rand_tensors_given_shapes(input_shapes, device)

    connected_graph = ConnectedGraph(model, inp_tensor_list)
    conv_bn_pairs, bn_conv_pairs, bn_to_fold = _find_all_batch_norms_to_fold(connected_graph)

    _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)

    # Convert the standalone BNs which are not folded
    bn_converted = convert_standalone_batchnorms(model, inp_tensor_list, bn_to_fold)
    _logger.info("%d BatchNorms' weights got converted", len(bn_converted))

    # Report every pair as (layer, bn), whatever the fold direction was
    folded_pairs = list(conv_bn_pairs)
    for bn, conv in bn_conv_pairs:
        folded_pairs.append((conv, bn))
    return folded_pairs
def convert_standalone_batchnorms(model: torch.nn.Module,
                                  dummy_input: Union[torch.Tensor, Tuple],
                                  folded_bn: set) -> List[Tuple[Any, BatchNorm2d]]:
    """
    Convert the weights of all the standalone batchnorms of a model which didn't get folded.

    :param model: torch model for which batch norm folding is being performed
    :param dummy_input: dummy input for the model
    :param folded_bn: set of BNs which got folded. Entries may be BatchNorm modules, or objects exposing
        ``get_module()`` (such as the bn ops collected by the fold search in this file), which are
        resolved to their modules before the comparison.
    :return: List of tuple(name, bn_module) whose weights got converted
    """
    # Modules are compared against modules: an entry that only wraps a module would never match otherwise
    folded_modules = {bn.get_module() if hasattr(bn, "get_module") else bn for bn in folded_bn}

    module_list = utils.get_ordered_list_of_modules(model, dummy_input)
    bn_converted = []
    for name, module in module_list:
        if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)) and module not in folded_modules:
            convert_batchnorm_parameters(model, module)
            _logger.debug("%s weights got converted", name)
            bn_converted.append((name, module))
    return bn_converted
def convert_batchnorm_parameters(model: torch.nn.Module, bn: Union[torch.nn.BatchNorm1d, torch.nn.BatchNorm2d]):
    """
    Rewrite a batchnorm's parameters in place so that it takes the form y = weights*input + bias.
    The running statistics are absorbed into weight and bias; afterwards running_mean is all zeros,
    running_var is all ones, eps is 0 and track_running_stats is False.

    :param model: torch model for which batch norm folding is being performed
    :param bn: BatchNorm module whose weights needs to be converted
    """
    with utils.in_eval_mode(model), torch.no_grad():
        # Compute the new scale and shift from the current statistics before anything is overwritten
        inv_sigma = torch.rsqrt(bn.running_var + bn.eps)
        scale = bn.weight * inv_sigma
        shift = bn.bias - bn.running_mean * scale

        # Update the values
        bn.eps = 0
        bn.track_running_stats = False
        bn.weight.copy_(scale.clone().detach())
        bn.bias.copy_(shift.clone().detach())

        mean_buffer = bn.running_mean
        var_buffer = bn.running_var
        bn.running_mean = torch.zeros(mean_buffer.shape, device=mean_buffer.device, dtype=mean_buffer.dtype)
        bn.running_var = torch.ones(var_buffer.shape, device=var_buffer.device, dtype=var_buffer.dtype)
# Alias: fold_all_batch_norms refers to the same function object as fold_all_batch_norms_to_weight
fold_all_batch_norms = fold_all_batch_norms_to_weight
def fold_all_batch_norms_to_scale(
        sim: QuantizationSimModel,
) -> List[Tuple[QcQuantizeWrapper, QcQuantizeWrapper]]:
    """
    Fold all batch_norm layers in a model into the quantization scale parameter
    of the corresponding conv layers

    :param sim: QuantizationSimModel
    :return: A list of pairs of layers [(Conv/Linear, BN layer that got folded)]
    """
    # pylint: disable=protected-access
    assert sim.model is not None
    assert sim.connected_graph is not None

    model = sim.model
    connected_graph = sim.connected_graph

    # Lookup from each wrapped module to the quant wrapper around it
    quant_wrappers = {
        quant_wrapper._module_to_wrap: quant_wrapper
        for _, quant_wrapper in sim.quant_wrappers()
    }

    # The graph search yields plain modules; folding to scale operates on their wrappers
    conv_bn_pairs, bn_conv_pairs, _ = _find_all_batch_norms_to_fold(connected_graph)
    conv_bn_pairs = [
        (quant_wrappers[conv], quant_wrappers[bn]) for conv, bn in conv_bn_pairs
    ]
    bn_conv_pairs = [
        (quant_wrappers[bn], quant_wrappers[conv]) for bn, conv in bn_conv_pairs
    ]

    _fold_given_batch_norms(model, conv_bn_pairs, bn_conv_pairs)

    return conv_bn_pairs + [(conv, bn) for bn, conv in bn_conv_pairs]
def find_all_conv_bn_with_activation(model: torch.nn.Module, input_shape: Tuple) -> Dict:
    """
    Build a connected graph of the model from random inputs of the given shape, then use the searcher
    to find preceding and next bn layers for a conv/linear layer

    :param model: PyTorch model
    :param input_shape: shape of input to the model
    :return: dictionary of conv/linear layers with associated bn op / activation info
    """
    random_inputs = utils.create_rand_tensors_given_shapes(input_shape, utils.get_device(model))
    graph = ConnectedGraph(model, random_inputs)
    return find_all_conv_bn_with_activation_in_graph(graph)
def find_all_conv_bn_with_activation_in_graph(connected_graph: ConnectedGraph) -> Dict:
    """
    Uses searcher to find preceding and next bn layers for a conv/linear layer

    :param connected_graph: ConnectedGraph object.
    :return: dictionary of conv/linear layers with associated bn op / activation info
    """
    # A single handler collects the matches of every pattern
    layer_select_handler = ConvBnPatternHandler()

    # Two-op patterns to match, in registration order: for each conv/linear op type the
    # (bn, op) pattern followed by the (op, bn) pattern, then the 3d pair
    op_sequences = []
    for op_type in ('Conv1d', 'Conv', 'ConvTranspose', 'Gemm'):
        op_sequences.append(['BatchNormalization', op_type])
        op_sequences.append([op_type, 'BatchNormalization'])
    op_sequences.append(['Conv3d', 'BatchNorm3d'])
    op_sequences.append(['BatchNorm3d', 'Conv3d'])

    patterns_with_callbacks = [
        PatternType(pattern=op_sequence, action=layer_select_handler) for op_sequence in op_sequences
    ]

    # create graph searcher instance with connected graph and patterns to search
    graph_searcher = GraphSearcher(connected_graph, patterns_with_callbacks)

    # get all conv/linear and bn info
    graph_searcher.find_all_patterns_in_graph_apply_actions()
    return layer_select_handler.get_conv_linear_bn_info_dict()