# Source code for aimet_torch.cross_layer_equalization

# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2019-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================

"""Cross Layer Equalization

Some terminology for this code.
CLS set: Set of layers (2 or 3) that can be used for cross-layer scaling
Layer groups: Groups of layers that are immediately connected and can be decomposed further into CLS sets
"""
# pylint: disable=too-many-lines

from typing import Tuple, List, Union, Dict
import warnings
import numpy as np
import torch

from aimet_common.utils import AimetLogger, _red
from aimet_common.cross_layer_equalization import (
    ClsLayerType,
    ClsSetInfo,
    ClsImpl,
    CrossLayerScaling as CLS,
    HbfImpl,
)
from aimet_torch import utils
from aimet_torch.meta.connectedgraph import ConnectedGraph
from aimet_torch.batch_norm_fold import fold_all_batch_norms
from aimet_torch.utils import (
    get_device,
    get_ordered_list_of_modules,
    create_rand_tensors_given_shapes,
    place_model,
)

# Module logger for the cross-layer-equalization log area.
logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.CrosslayerEqualization)

# A CLS set: two layers (conv + conv, or depthwise conv + conv) or three layers (conv + depthwise conv + conv),
# as produced by GraphSearchUtils.convert_layer_group_to_cls_sets. Although the alias only names Conv2d, any
# type in ``cls_supported_layers`` (Conv1d/ConvTranspose1d/ConvTranspose2d too) can appear in a set.
ClsSet = Union[
    Tuple[torch.nn.Conv2d, torch.nn.Conv2d],
    Tuple[torch.nn.Conv2d, torch.nn.Conv2d, torch.nn.Conv2d],
]

# Union of the layer types CLS can operate on (typing counterpart of ``cls_supported_layers`` below).
ClsSupportedLayer = Union[
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d
]

# Scaling factor(s) for one CLS set: a single array for a two-layer set, or a tuple of arrays
# (S_12, S_23) for a three-layer depthwise set.
ScaleFactor = Union[np.ndarray, Tuple[np.ndarray]]

# Layer classes that CLS can scale; used in isinstance checks throughout this module.
cls_supported_layers = (
    torch.nn.Conv2d,
    torch.nn.ConvTranspose2d,
    torch.nn.Conv1d,
    torch.nn.ConvTranspose1d,
)
# Activation modules a layer group may pass through between scaled layers.
cls_supported_activations = (torch.nn.ReLU, torch.nn.PReLU)


def get_ordered_list_of_conv_modules(
    model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tuple]
) -> List:
    """
    Collect the CLS-supported layers of ``model`` in the order the graph executes them.

    :param model: model
    :param dummy_input: Dummy input to the model. Used to parse model graph.
    :return: List of ``[name, module]`` pairs (in graph order) for every module that is an
        instance of one of ``cls_supported_layers``
    """
    conv_modules = []
    for module_name, candidate in get_ordered_list_of_modules(model, dummy_input):
        # Anything that CLS cannot scale (activations, BN, linear, ...) is skipped.
        if not isinstance(candidate, cls_supported_layers):
            continue
        conv_modules.append([module_name, candidate])
    return conv_modules


class GraphSearchUtils:
    """
    Code to search a model graph to find nodes to use for cross-layer-scaling and high-bias-fold
    """

    def __init__(
        self,
        model: torch.nn.Module,
        input_shapes: Union[Tuple, List[Tuple]],
        dummy_input: Union[torch.Tensor, List[torch.Tensor]] = None,
    ):
        """
        Build the connected graph of ``model`` and record its CLS-supported layers in graph order.

        :param model: PyTorch model.
        :param input_shapes: Input shape for the model (can be one or multiple inputs). Only used to
         create random inputs when ``dummy_input`` is None.
        :param dummy_input: Dummy input to the model. Used to parse model graph. dummy_input is expected to be placed
         on the same device as model.
        """
        if dummy_input is None:
            inp_tensor_list = tuple(
                utils.create_rand_tensors_given_shapes(input_shapes, get_device(model))
            )
        else:
            inp_tensor_list = dummy_input
        self._connected_graph = ConnectedGraph(model, inp_tensor_list)
        self._ordered_module_list = get_ordered_list_of_conv_modules(
            model, inp_tensor_list
        )

    @staticmethod
    def find_downstream_layer_groups_to_scale(
        op, layer_groups, current_group=None, visited_nodes=None
    ):
        """
        Recursive function to find cls layer groups downstream from a given op

        A group is a run of CLS-supported conv layers connected directly or through supported
        activations. The group is closed (and recorded if it has more than one layer) when an
        op that is neither a supported layer nor a supported activation is reached, or at a leaf.

        :param op: Starting op to search from
        :param layer_groups: Running list of layer groups (appended to in place)
        :param current_group: Running current layer group
        :param visited_nodes: Running list of visited nodes (to short-circuit recursion)
        :return: None
        """

        if not visited_nodes:
            visited_nodes = []
        # An empty group is replaced by a fresh list, so the callee never mutates a caller's
        # empty group; a non-empty group is shared and extended along the current path.
        if not current_group:
            current_group = []

        if op in visited_nodes:
            return
        visited_nodes.append(op)

        # If current node is a CLS-supported conv layer, add it to the current group
        if op.model_module and isinstance(
            op.model_module.get_module(), cls_supported_layers
        ):
            current_group.append(op.model_module.get_module())

        # Terminating condition for current group: op has no module, or its module is neither a
        # supported layer nor a supported activation
        if not op.model_module or not isinstance(
            op.model_module.get_module(),
            cls_supported_layers + cls_supported_activations,
        ):
            if (len(current_group) > 1) and (current_group not in layer_groups):
                layer_groups.append(current_group)
            current_group = []

        for consumer in op.output_ops:
            GraphSearchUtils.find_downstream_layer_groups_to_scale(
                consumer, layer_groups, current_group, visited_nodes
            )

        # Reached a leaf.. See if the current group has something to grab
        if (len(current_group) > 1) and (current_group not in layer_groups):
            layer_groups.append(current_group)

    @staticmethod
    def convert_layer_group_to_cls_sets(layer_group):
        """
        Helper function to convert a layer group to a list of cls sets

        Note: ``layer_group`` is consumed (layers are popped from it) by this call.

        :param layer_group: Given layer group to generate cls sets
        :return: List of cls sets

        Supported layer combinations for CLS are:
        1. Conv + Conv
        2. DepthwiseConv + Conv
        3. Conv + DepthwiseConv + Conv

        Can be rewritten as,
        Conv
            -> Conv
            -> DepthwiseConv
                -> Conv
        DepthwiseConv
            -> Conv

        If a combination is partially supported, the cls_set is completely omitted and restarted from the next
        supported layer
        For example: Consider Conv + DepthwiseConv + Depthwise(unsupported)
        - Since Depthwise(unsupported) is the last layer encountered, we need to omit all the three layers and restart
        the cls sets from the next supported layer.

        """

        # pylint: disable=too-many-branches
        def convert_to_cls_layer_type(
            layer: ClsSupportedLayer,
        ) -> Tuple[ClsLayerType, ClsSupportedLayer]:
            """
            Given the layer, check if its supported in CLS
            :param layer: layer to check
            :return: Tuple of ClsLayerType and the layer
            """
            if layer.groups == 1:
                layer_type = ClsLayerType.Conv
            elif (
                layer.groups == layer.in_channels
                and layer.in_channels == layer.out_channels
            ):
                # depthwiseConv layer with depth multiplier = 1
                layer_type = ClsLayerType.DepthwiseConv
            else:
                layer_type = ClsLayerType.Unsupported
            return layer_type, layer

        def get_next_layer() -> Union[
            Tuple[ClsLayerType, Union[ClsSupportedLayer, None]]
        ]:
            """
            :return: Tuple of ClsLayerType and the next layer in layer_group
                (``(Unsupported, None)`` once the group is exhausted)
            """
            if not layer_group:
                return ClsLayerType.Unsupported, None
            layer = layer_group.pop(0)
            return convert_to_cls_layer_type(layer)

        def log_unsupported(layer) -> None:
            """Log an unsupported layer; ``None`` means the group ran out, so nothing is logged."""
            if layer is not None:
                logger.info("Layer %s is not supported. Ignoring for cls", layer)

        cls_sets = []

        first_layer_to_scale = (ClsLayerType.Unsupported, None)
        while layer_group:
            # Skip ahead to the next supported layer, logging each unsupported one once.
            while layer_group and first_layer_to_scale[0] is ClsLayerType.Unsupported:
                first_layer_to_scale = get_next_layer()
                if first_layer_to_scale[0] is ClsLayerType.Unsupported:
                    log_unsupported(first_layer_to_scale[1])

            second_layer_to_scale = get_next_layer()
            if first_layer_to_scale[0] == ClsLayerType.Conv:
                if second_layer_to_scale[0] == ClsLayerType.Conv:
                    cls_sets.append((first_layer_to_scale[1], second_layer_to_scale[1]))
                    first_layer_to_scale = second_layer_to_scale
                elif second_layer_to_scale[0] == ClsLayerType.DepthwiseConv:
                    if layer_group:
                        # do not pop third layer yet, determine its type and then pop it
                        third_layer_to_scale = convert_to_cls_layer_type(layer_group[0])
                        if third_layer_to_scale[0] == ClsLayerType.Conv:
                            cls_sets.append(
                                (
                                    first_layer_to_scale[1],
                                    second_layer_to_scale[1],
                                    third_layer_to_scale[1],
                                )
                            )
                            # adding third_layer_to_scale for the next round of CLS set determination
                            first_layer_to_scale = get_next_layer()
                        else:
                            # unsupported combination encountered
                            first_layer_to_scale = second_layer_to_scale
                else:
                    # Only logs real layers; the group may simply have run out after a Conv.
                    log_unsupported(second_layer_to_scale[1])
                    first_layer_to_scale = (ClsLayerType.Unsupported, None)
            elif first_layer_to_scale[0] == ClsLayerType.DepthwiseConv:
                if second_layer_to_scale[0] == ClsLayerType.Conv:
                    cls_sets.append((first_layer_to_scale[1], second_layer_to_scale[1]))
                elif second_layer_to_scale[0] == ClsLayerType.Unsupported:
                    # Otherwise this layer would be silently discarded by the skip-ahead loop.
                    log_unsupported(second_layer_to_scale[1])
                first_layer_to_scale = second_layer_to_scale
            else:
                # Reached only when the skip-ahead loop exhausted the group without finding a
                # supported layer; that layer has already been logged there.
                first_layer_to_scale = second_layer_to_scale

        return cls_sets

    def find_layer_groups_to_scale(self) -> List[List[torch.nn.Conv2d]]:
        """
        :return: List of groups of layers, ordered by the position of each group's first layer in
            the model. Each group can be independently equalized
        """

        # Find the input node(s) in the graph
        input_nodes = []
        for op in self._connected_graph.get_all_ops().values():
            if op.inputs and op.inputs[0].is_model_input:
                input_nodes.append(op)

        layer_groups = []
        for op in input_nodes:
            self.find_downstream_layer_groups_to_scale(op, layer_groups)

        # Sort the layer groups in order of occurrence in the model
        ordered_layer_groups = []
        for _, module in self._ordered_module_list:
            for layer_group in layer_groups:
                if layer_group[0] is module:
                    ordered_layer_groups.append(layer_group)

        return ordered_layer_groups

    @staticmethod
    def does_module_have_relu_activation(
        connected_graph: ConnectedGraph, module: torch.nn.Module
    ) -> bool:
        """
        Finds if a given module has a ReLU activation

        :param connected_graph: Reference to ConnectedGraph instance
        :param module: PyTorch module to find activation for
        :return: True if the single consumer of ``module`` is a ReLU/PReLU module; False otherwise
            (including when the module is not found or its consumer has no backing module)
        """

        for op in connected_graph.get_all_ops().values():
            if op.model_module and op.model_module.get_module() is module:
                assert len(op.output_ops) == 1
                consumer = op.output_ops[0]
                # Ops may lack a model_module (also handled during group search); such an op
                # cannot be a ReLU module, and dereferencing it would raise AttributeError.
                if not consumer.model_module:
                    return False
                return isinstance(
                    consumer.model_module.get_module(), cls_supported_activations
                )

        return False

    def is_relu_activation_present_in_cls_sets(self, cls_sets: List[ClsSet]):
        """
        Check, for each CLS set, whether a ReLU/PReLU follows each layer except the last one.

        :param cls_sets: CLS sets to find relu activations in
        :return: One entry per CLS set: a bool for a two-layer set, or a tuple of two bools
            (after layer 1, after layer 2) for a three-layer set
        """

        is_relu_activation_in_cls_sets = []
        for cls_set in cls_sets:
            # We need to check activation functions for all layers but the last one in the set
            # Because we are only interested in checking activation functions between the layers we will scale
            cls_set = cls_set[:-1]

            is_relu_activation_in_cls_set = ()
            for module in cls_set:
                is_relu_activation_in_cls_set += (
                    self.does_module_have_relu_activation(
                        self._connected_graph, module
                    ),
                )

            if len(is_relu_activation_in_cls_set) == 1:
                is_relu_activation_in_cls_set = is_relu_activation_in_cls_set[0]

            is_relu_activation_in_cls_sets.append(is_relu_activation_in_cls_set)

        return is_relu_activation_in_cls_sets


class CrossLayerScaling(CLS):
    """
    Code to apply the cross-layer-scaling technique to a model
    """

    @staticmethod
    def _run_cls_impl_on_cpu(cls_set: ClsSet, scale_fn):
        """
        Validate ``cls_set``, then run ``scale_fn(cls_set)`` with every CUDA-resident layer temporarily moved to CPU.

        Each moved layer is returned to the exact device it came from (e.g. ``cuda:1`` stays ``cuda:1``), even if
        ``scale_fn`` raises. Layers that were not on a CUDA device are never moved.

        :param cls_set: Layers to scale
        :param scale_fn: Callable that scales the layers in place and returns the scaling factor(s)
        :return: Whatever ``scale_fn`` returns
        :raises ValueError: If a layer is not one of ``cls_supported_layers`` (checked before any layer is moved)
        """
        for module in cls_set:
            if not isinstance(module, cls_supported_layers):
                raise ValueError(
                    f"Only Conv or Transposed Conv layers are supported for cross layer equalization."
                    f" Layer class {str(module.__class__)} is not supported."
                )

        moved_modules = []
        try:
            for module in cls_set:
                if module.weight.is_cuda:
                    # Remember the specific device so multi-GPU placements are restored faithfully.
                    moved_modules.append((module, module.weight.device))
                    module.cpu()
            return scale_fn(cls_set)
        finally:
            for module, device in moved_modules:
                module.to(device=device)

    def scale_cls_set_with_conv_layers(self, cls_set: ClsSet) -> np.ndarray:
        """
        API to invoke equalize layer params (update for weights and bias is in place)

        :param cls_set: Consecutive Conv layers Tuple whose weights and biases need to be equalized
        :return: Scaling factor S_12 for each conv layer pair: numpy array
        :raises ValueError: If a layer in ``cls_set`` is not a supported Conv/Transposed Conv layer
        """
        return self._run_cls_impl_on_cpu(
            cls_set, PythonClsImpl().scale_cls_set_with_conv_layers
        )

    def scale_cls_set_with_depthwise_layers(
        self, cls_set: ClsSet
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        API to invoke equalize layer params for depth wise separable layers(update for weights and bias is in place)

        :param cls_set: Consecutive Conv layers whose weights and biases need to be equalized.
                        Second Conv layer is a depth-wise conv and third conv layer is point-wise conv
        :return: Scaling factors S_12 and S_23 : numpy arrays
        :raises ValueError: If a layer in ``cls_set`` is not a supported Conv/Transposed Conv layer
        """
        return self._run_cls_impl_on_cpu(
            cls_set, PythonClsImpl().scale_cls_set_with_depthwise_layers
        )

    @staticmethod
    def scale_model(
        model: torch.nn.Module,
        input_shapes: Union[Tuple, List[Tuple]] = None,
        dummy_input: Union[torch.Tensor, List[torch.Tensor]] = None,
    ) -> List[ClsSetInfo]:
        """
        Uses cross-layer scaling to scale all applicable layers in the given model

        :param model: Model to scale
        :param input_shapes: Input shape for the model (can be one or multiple inputs)
        :param dummy_input: A dummy input to the model. Can be a Tensor or a Tuple of Tensors. dummy_input will be
         placed on CPU if not already.
        :return: CLS information for each CLS set
        :raises ValueError: If both ``input_shapes`` and ``dummy_input`` are None
        """
        if isinstance(model, torch.nn.DataParallel):
            return CrossLayerScaling().scale_model(
                model.module, input_shapes, dummy_input=dummy_input
            )

        # The use of input_shapes will be removed in the future release. It is maintained now for backward compatibility.
        if input_shapes and dummy_input is None:
            dummy_input = create_rand_tensors_given_shapes(
                input_shapes, torch.device("cpu")
            )
        if input_shapes is None and dummy_input is None:
            raise ValueError("Both input_shapes and dummy_input can't be None")

        # Place model and dummy input on the cpu.
        with place_model(model, torch.device("cpu")):
            dummy_input = utils.change_tensor_device_placement(
                dummy_input, device=torch.device("cpu")
            )

            # Find layer groups
            graph_search = GraphSearchUtils(
                model, input_shapes, dummy_input=dummy_input
            )
            layer_groups = graph_search.find_layer_groups_to_scale()

            # Find cls sets from the layer groups (note: this consumes each layer group)
            cls_sets = []
            for layer_group in layer_groups:
                cls_sets += GraphSearchUtils.convert_layer_group_to_cls_sets(layer_group)

            # Scale the CLS sets
            scale_factors = CrossLayerScaling().scale_cls_sets(cls_sets)

            # Find if there were relu activations between layers of each cls set
            is_relu_activation_in_cls_sets = (
                graph_search.is_relu_activation_present_in_cls_sets(cls_sets)
            )

            # Convert to a list of cls-set-info elements
            cls_set_info_list = CrossLayerScaling().create_cls_set_info_list(
                cls_sets, scale_factors, is_relu_activation_in_cls_sets
            )
        return cls_set_info_list


class PythonClsImpl(ClsImpl):
    """
    This class implements the CLS algorithm using Python version while following the base Implementation interface.
    """

    def scale_cls_set_with_depthwise_layers(
        self, cls_set
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        API to invoke equalize layer params for depth wise separable layers(update for weights and bias is in place)

        :param cls_set: Consecutive Conv layers whose weights and biases need to be equalized.
                        Second Conv layer is a depth-wise conv and third conv layer is point-wise conv
        :return: Scaling factors S_12 and S_23 : numpy arrays
        """
        weight_0 = self._prepare_params(cls_set[0])
        assert cls_set[1].groups > 1
        weight_1 = self._prepare_params(cls_set[1])
        weight_2 = self._prepare_params(cls_set[2])
        weight_0 = weight_0.numpy()
        weight_1 = weight_1.numpy()
        weight_2 = weight_2.numpy()

        bias_0 = None
        if cls_set[0].bias is not None:
            bias_0 = cls_set[0].bias.detach().cpu().numpy()
        bias_1 = None
        if cls_set[1].bias is not None:
            bias_1 = cls_set[1].bias.detach().cpu().numpy()

        # compute scaling factors and folded parameters.
        s_12, s_23 = self.compute_scaling_params_for_depthwise_conv(
            weight_0, weight_1, weight_2
        )
        _weight_0, _weight_1, _weight_2, _bias_0, _bias_1 = (
            self.fold_scaling_params_for_depthwise_conv(
                weight_0, weight_1, weight_2, bias_0, bias_1, s_12, s_23
            )
        )

        with torch.no_grad():
            self._restore_params(cls_set[0], torch.from_numpy(_weight_0))
            self._restore_params(cls_set[1], torch.from_numpy(_weight_1))
            self._restore_params(cls_set[2], torch.from_numpy(_weight_2))

            if cls_set[0].bias is not None:
                self._restore_bias(cls_set[0], _bias_0)
            if cls_set[1].bias is not None:
                self._restore_bias(cls_set[1], _bias_1)
        return s_12, s_23

    def scale_cls_set_with_conv_layers(self, cls_set) -> np.ndarray:
        """
        API to invoke equalize layer params for regular conv layers (update for weights and bias is in place)

        :param cls_set: Consecutive Conv layers Tuple whose weights and biases need to be equalized
        :return: Scaling factor S_12 for each conv layer pair: numpy array
        """
        weight_0 = self._prepare_params(cls_set[0])
        weight_1 = self._prepare_params(cls_set[1])
        weight_0 = weight_0.numpy()
        weight_1 = weight_1.numpy()

        bias_0 = None
        if cls_set[0].bias is not None:
            bias_0 = cls_set[0].bias.detach().cpu().numpy()

        # compute scaling factors and folded parameters.
        scale_factor = self.compute_scaling_params_for_conv(weight_0, weight_1)
        _weight_0, _weight_1, _bias_0 = self.fold_scaling_params_for_conv(
            weight_0, weight_1, bias_0, scale_factor
        )

        with torch.no_grad():
            self._restore_params(cls_set[0], torch.from_numpy(_weight_0))
            self._restore_params(cls_set[1], torch.from_numpy(_weight_1))
            if cls_set[0].bias is not None:
                self._restore_bias(cls_set[0], _bias_0)
        return scale_factor

    @staticmethod
    def _transpose_tensor(
        module: torch.nn.Module, tensor: torch.Tensor
    ) -> torch.Tensor:
        """
        During preparation:
        For TransposeConv2d, Transpose tensor in the common format [Noc, Nin, Kh, Kw].
        For TransposeConv1d, Transpose tensor in common format [Noc, Nin, K].

        During restoration:
        For TransposeConv2d, Transpose tensor in the original format [Nin, Noc, Kh, Kw].
        For TransposeConv1d, Transpose tensor in back in original format [Nin, Noc, K].

        Grouped transposed convs (groups != 1) and regular convs are returned unchanged.

        :param module: Module.
        :param tensor: Input tensor.
        :return: Output tensor.
        """
        if isinstance(module, torch.nn.ConvTranspose2d) and module.groups == 1:
            tensor = tensor.permute(1, 0, 2, 3).contiguous()

        if isinstance(module, torch.nn.ConvTranspose1d) and module.groups == 1:
            tensor = tensor.permute(1, 0, 2).contiguous()
        return tensor

    @staticmethod
    def _make_4d_tensor(module: torch.nn.Module, tensor: torch.Tensor) -> torch.Tensor:
        """
        Return 4 dimensional tensor by adding a dimension on the end if the module is a 1d conv.

        :param module: Module.
        :param tensor: Input tensor.
        :return: Output tensor.
        """
        if isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
            assert len(tensor.shape) == 3, "Module should have 3d weight tensor."
            tensor = torch.unsqueeze(tensor, dim=-1)
        return tensor

    def _prepare_params(self, module: torch.nn.Module) -> torch.Tensor:
        """
        Prepare weight parameters for CLS: detach, move to CPU, bring to [Noc, Nin, ...] layout and make 4d.

        :param module: PyTorch module.
        :return: Prepared weight.
        """
        weight = module.weight.detach().cpu()
        weight = self._transpose_tensor(module, weight)
        weight = self._make_4d_tensor(module, weight)
        return weight

    def _restore_params(self, module: torch.nn.Module, tensor: torch.Tensor):
        """
        Restore the weight parameters in place, undoing the layout changes made by ``_prepare_params``.

        Must be called under ``torch.no_grad()`` since the weight is updated in place.
        ``Tensor.copy_`` casts to the destination's dtype and device, so no explicit ``.to()`` is needed.

        :param module: PyTorch module.
        :param tensor: updated parameters.
        """
        if isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
            tensor = torch.squeeze(tensor, dim=-1)

        restored_weight = self._transpose_tensor(module, tensor)
        module.weight.copy_(restored_weight.reshape_as(module.weight))

    @staticmethod
    def _restore_bias(module: torch.nn.Module, bias: np.ndarray):
        """
        Copy ``bias`` into ``module.bias`` in place.

        Must be called under ``torch.no_grad()``. ``Tensor.copy_`` casts to the destination's dtype and device.

        :param module: PyTorch module whose bias is not None.
        :param bias: Updated bias values.
        """
        module.bias.copy_(torch.from_numpy(bias).reshape_as(module.bias))


class HighBiasFold:
    """
    Apply the high-bias-fold technique to layer pairs that were equalized with cross-layer scaling
    """

    ActivationIsReluForFirstModule = bool
    ScaleForFirstModule = np.ndarray

    @classmethod
    def bias_fold(
        cls,
        cls_set_info_list: List[ClsSetInfo],
        bn_layers: Dict[
            Union[torch.nn.Conv2d, torch.nn.ConvTranspose2d], torch.nn.BatchNorm2d
        ],
    ):
        """
        Fold high bias values of each scaled layer into the next layer's bias (greater than 3 * sigma,
        as stated by the original documentation of this method)

        A pair is skipped unless both of its layers have a bias and its first layer has a folded
        BatchNorm entry in ``bn_layers``.

        :param cls_set_info_list: List of info elements for each cls set
        :param bn_layers: Key: Conv/Linear layer Value: Corresponding folded BN layer
        :return: None
        """
        if not bn_layers:
            logger.info(
                "High Bias folding is not supported for models without BatchNorm Layers"
            )
            return

        def _can_fold(pair_info) -> bool:
            # Both biases are rewritten, and the BN of the first layer supplies beta/gamma.
            return (
                pair_info.layer1.bias is not None
                and pair_info.layer2.bias is not None
                and pair_info.layer1 in bn_layers
            )

        all_pairs = (
            pair_info
            for set_info in cls_set_info_list
            for pair_info in set_info.cls_pair_info_list
        )
        for pair_info in all_pairs:
            if _can_fold(pair_info):
                PythonHbfImpl().bias_fold(pair_info, bn_layers)


class PythonHbfImpl(HbfImpl):
    """
    This class implements the HBF algorithm using python version while following the base Implementation interface.
    """

    def bias_fold(self, cls_pair_info, bn_layers):
        """
        Bias fold implementation using python version.

        Updates ``layer1.bias`` and ``layer2.bias`` of ``cls_pair_info`` in place.

        :param cls_pair_info: Layer pairs that were scaled using CLS and related information.
        :param bn_layers: Dictionary with Key being Conv/Linear layer and value being corresponding folded BN layer.
        """
        weight = cls_pair_info.layer2.weight.detach().cpu()
        if isinstance(
            cls_pair_info.layer2, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)
        ):
            # Make 1d-conv weights 4d so the permute below and the bias absorption see a uniform rank.
            weight = torch.unsqueeze(weight, dim=-1)
        # Transpose weights to C, N, H, W from N, C, H, W since axis are flipped for transposed conv
        if (
            isinstance(
                cls_pair_info.layer2,
                (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d),
            )
            and cls_pair_info.layer2.groups == 1
        ):
            weight = weight.permute(1, 0, 2, 3)
        weight = weight.numpy()

        activation_is_relu = cls_pair_info.relu_activation_between_layers

        # The folded BN parameters describe the layer before scaling; undo the CLS scale factor.
        folded_bn = bn_layers[cls_pair_info.layer1]
        beta = folded_bn.bias.detach().cpu().numpy() / cls_pair_info.scale_factor
        gamma = folded_bn.weight.detach().cpu().numpy() / cls_pair_info.scale_factor

        bias_prev_layer = cls_pair_info.layer1.bias.detach().cpu().numpy()
        bias_curr_layer = cls_pair_info.layer2.bias.detach().cpu().numpy()

        # Absorb high biases
        _bias_prev_layer, _bias_curr_layer = self._absorb_bias(
            activation_is_relu, beta, gamma, weight, bias_curr_layer, bias_prev_layer
        )

        # Tensor.copy_ casts to the destination parameter's dtype and device, so no explicit .to() is needed.
        with torch.no_grad():
            cls_pair_info.layer1.bias.copy_(
                torch.from_numpy(_bias_prev_layer).reshape_as(cls_pair_info.layer1.bias)
            )
            cls_pair_info.layer2.bias.copy_(
                torch.from_numpy(_bias_curr_layer).reshape_as(cls_pair_info.layer2.bias)
            )


def _warn_relu6(model: torch.nn.Module):
    """
    Emit a warning if ``model`` contains any ``torch.nn.ReLU6`` module, since CLS will not scale
    conv-relu6-conv combinations.

    The warning uses ``stacklevel=3`` so it is attributed to the caller of the public API that
    invoked this helper.

    :param model: Model to inspect.
    """
    if not any(isinstance(module, torch.nn.ReLU6) for module in model.modules()):
        return

    # Every sentence is its own list element; adjacent string literals without a comma would be
    # concatenated with no separating space and garble the message.
    msg = " ".join(
        [
            "Cross Layer Scaling (CLS) technique works for combination of conv-conv or conv-relu-conv layers.",
            "Specifically, CLS does not work for combination of conv-relu6-conv layers.",
            "For aimet-torch<2.4, AIMET was force changing relu6 layers to relu in the model.",
            "Since aimet-torch==2.4, AIMET will no longer do this.",
            "As a result, combination of conv-relu6-conv layers won't be scaled.",
            "User can modify their model to change relu6 to relu before invoking AIMET,",
            "if this transformation does not impact floating point accuracy of the model.",
        ]
    )
    warnings.warn(_red(msg), DeprecationWarning, stacklevel=3)


def equalize_model(
    model: torch.nn.Module,
    input_shapes: Union[Tuple, List[Tuple]] = None,
    dummy_input: Union[torch.Tensor, Tuple] = None,
):
    """
    High-level API to perform Cross-Layer Equalization (CLE) on the given model. The model is equalized in place.

    BatchNorm layers are folded first, then cross-layer scaling and high-bias folding are applied to the
    folded model.

    :param model: Model to equalize
    :param input_shapes: Shape of the input (can be a tuple or a list of tuples if multiple inputs)
    :param dummy_input: A dummy input to the model. Can be a Tensor or a Tuple of Tensors. dummy_input will be
        placed on CPU if not already.
    :raises ValueError: If both ``input_shapes`` and ``dummy_input`` are None
    """
    _warn_relu6(model)

    if isinstance(model, torch.nn.DataParallel):
        # Unwrap DataParallel and equalize the underlying module instead.
        equalize_model(model.module, input_shapes, dummy_input)
        return

    # The use of input_shapes will be removed in the future release. It is maintained now for backward compatibility.
    if input_shapes and dummy_input is None:
        dummy_input = create_rand_tensors_given_shapes(input_shapes, torch.device("cpu"))
    if input_shapes is None and dummy_input is None:
        raise ValueError("Both input_shapes and dummy_input can't be None")

    cpu_device = torch.device("cpu")
    # Place model and dummy input on the cpu.
    with place_model(model, cpu_device):
        dummy_input = utils.change_tensor_device_placement(dummy_input, device=cpu_device)

        # fold batchnorm layers and perform CLE on the folded model.
        folded_pairs = fold_all_batch_norms(model, input_shapes, dummy_input=dummy_input)
        equalize_bn_folded_model(model, input_shapes, folded_pairs, dummy_input=dummy_input)
def equalize_bn_folded_model(
    model: torch.nn.Module,
    input_shapes: Union[Tuple, List[Tuple]],
    folded_pairs: List[Tuple[torch.nn.Module, torch.nn.BatchNorm2d]],
    dummy_input: Union[torch.Tensor, Tuple] = None,
):
    """
    Perform Cross-Layer Scaling (CLS) and High Bias Folding (HBF) on a batchnorm-folded model.
    The model is equalized in place.

    :param model: Batchnorm-folded model to equalize
    :param input_shapes: Shape of the input (can be a tuple or a list of tuples if multiple inputs)
    :param folded_pairs: List of pairs of folded layers, each ``(layer, bn)`` where ``bn`` was folded into ``layer``
    :param dummy_input: A dummy input to the model. Can be a Tensor or a Tuple of Tensors. dummy_input will be
        placed on CPU if not already.
    """
    _warn_relu6(model)

    if isinstance(model, torch.nn.DataParallel):
        # Unwrap DataParallel and equalize the underlying module instead.
        equalize_bn_folded_model(
            model.module, input_shapes, folded_pairs, dummy_input=dummy_input
        )
        return

    # Map each layer to the BatchNorm that was folded into it (later pairs win on duplicate keys).
    bn_dict = {pair[0]: pair[1] for pair in folded_pairs}

    with place_model(model, torch.device("cpu")):
        # perform cross-layer scaling on applicable layer sets
        cls_set_info_list = CrossLayerScaling().scale_model(
            model, input_shapes, dummy_input=dummy_input
        )
        # high-bias fold
        HighBiasFold.bias_fold(cls_set_info_list, bn_dict)