# Source code for aimet_torch.model_preparer

# /usr/bin/env python3.5
# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2021, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================

""" Implementation to automatically prepare pytorch models for AIMET features """

import copy
from re import split
from typing import Any, Optional, Dict, Union, List, Callable
import torch
import torch.fx
from aimet_common.utils import AimetLogger
from aimet_torch.utils import in_eval_mode
from aimet_torch.utils import replace_modules_of_type1_with_type2
import aimet_torch.elementwise_ops as elementwise_ops

# Module-level logger for this file, obtained from AimetLogger's ModelPreparer log area.
logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.ModelPreparer)

# Map of torch.nn.functional activation callables to their equivalent torch.nn module type.
# Used by get_module_for_activation_fn() when preparing nn.Transformer encoder/decoder layers.
functional_op_to_module_map = {
    torch.nn.functional.relu: torch.nn.ReLU,
    torch.nn.functional.gelu: torch.nn.GELU
}

# In this functional --> module map, corresponding module is of type torch.nn and stateful.
# Keys must match the functional name as it appears in torch.fx node names (e.g. F.hardswish
# produces nodes named 'hardswish', 'hardswish_1', ...), otherwise the lookup never matches.
functional_with_stateful_api = {
    'relu'          : torch.nn.ReLU,
    'relu6'         : torch.nn.ReLU6,
    'hardtanh'      : torch.nn.Hardtanh,
    'hardswish'     : torch.nn.Hardswish,
    'elu'           : torch.nn.ELU,
    'selu'          : torch.nn.SELU,
    'celu'          : torch.nn.CELU,
    'leaky_relu'    : torch.nn.LeakyReLU,
    'prelu'         : torch.nn.PReLU,
    'rrelu'         : torch.nn.RReLU,
    'glu'           : torch.nn.GLU,
    'gelu'          : torch.nn.GELU,
    'logsigmoid'    : torch.nn.LogSigmoid,
    'hardshrink'    : torch.nn.Hardshrink,
    'tanhshrink'    : torch.nn.Tanhshrink,
    'softsign'      : torch.nn.Softsign,
    'softplus'      : torch.nn.Softplus,
    'softmin'       : torch.nn.Softmin,
    'softmax'       : torch.nn.Softmax,
    'softshrink'    : torch.nn.Softshrink,
    'log_softmax'   : torch.nn.LogSoftmax,
    'tanh'          : torch.nn.Tanh,
    'sigmoid'       : torch.nn.Sigmoid,
    'hardsigmoid'   : torch.nn.Hardsigmoid,
    'silu'          : torch.nn.SiLU,
}


# Functionals whose node/module creation needs dedicated handler functions
# (see special_handler_functions).
functional_with_special_handling = dict(
    cat=elementwise_ops.Concat,
    conv2d=torch.nn.Conv2d,
)

# Functional --> module map whose custom torch.nn modules take every input positionally (stateless API).
functional_with_stateless_api = dict(
    add=elementwise_ops.Add,
    subtract=elementwise_ops.Subtract,
    sub=elementwise_ops.Subtract,
    mul=elementwise_ops.Multiply,
    div=elementwise_ops.Divide,
    truediv=elementwise_ops.Divide,
    floordiv=elementwise_ops.FloorDivide,
    matmul=elementwise_ops.MatMul,
    exp=elementwise_ops.Exponential,
    interpolate=elementwise_ops.Interpolate,
    max_pool2d=elementwise_ops.MaxPool2d,
    max_pool2d_with_indices=elementwise_ops.MaxPool2d,
    adaptive_avg_pool2d=elementwise_ops.AdaptiveAvgPool2d,
    avg_pool2d=elementwise_ops.AvgPool2d,
    norm=elementwise_ops.Norm,
    batch_norm=elementwise_ops.BatchNorm,
    group_norm=elementwise_ops.GroupNorm,
    mean=elementwise_ops.Mean,
    pow=elementwise_ops.Pow,
    where=elementwise_ops.Where,
    addmm=elementwise_ops.Addmm,
    bmm=elementwise_ops.Bmm,
    baddbmm=elementwise_ops.Baddbmm,
    cumsum=elementwise_ops.CumSum,
    masked_fill=elementwise_ops.MaskedFill,
)


def conv2d_create_node(traced_model: torch.fx.GraphModule, module_name: str, node: torch.fx.node) \
        -> torch.fx.node:
    """
    Create the node to be inserted in the graph model.

    If the module registered under ``module_name`` is an ``elementwise_ops.DynamicConv2d``, the input,
    weight and (when present) bias are all passed as positional args. Otherwise only the input tensor
    is passed, because the module already holds its own weight/bias.

    :param traced_model: Symbolically traced model
    :param module_name: Qualified module name in symbolic_traced_model hierarchy corresponding to new node
    :param node: Current node in the graph after which new node will be inserted
    :return: torch.fx.node to be inserted in the graph
    """
    n_args = len(node.args)
    # input tensors must be passed as args, not kwargs for QcQuantizeWrapper
    input_tensor = []
    # input and weight is guaranteed to exist, but bias can be None
    # Since None cannot be passed as args in QcQuantizeWrapper, do not add it to input_tensor
    for index, key in ((0, 'input'), (1, 'weight'), (2, 'bias')):
        value = None
        if n_args > index:
            value = node.args[index]
        elif key in node.kwargs:
            # Keys must match F.conv2d's parameter names exactly, or keyword-passed tensors are dropped
            value = node.kwargs[key]

        if value is None:
            break
        input_tensor.append(value)

    with traced_model.graph.inserting_after(node):
        if isinstance(getattr(traced_model, module_name), elementwise_ops.DynamicConv2d):
            new_node = traced_model.graph.call_module(module_name, args=tuple(input_tensor))
        else:
            new_node = traced_model.graph.call_module(module_name, args=(input_tensor[0],))
        return new_node


def conv2d_create_module(node: torch.fx.node) -> torch.nn.Module:
    """
    Create the replacement module for a functional conv2d node.

    If the weight or bias is produced by another node (i.e. it is not a ``get_attr`` node), an
    ``elementwise_ops.DynamicConv2d`` is created and those tensors are supplied when it is called.
    Otherwise a ``torch.nn.Conv2d`` is built and given the tensors referenced by the ``get_attr`` nodes.

    :param node: Current node in the graph after which new node will be inserted
    :return: New module.
    """
    # Get weight and bias from argument
    params = merge_args_and_kwargs(node, {1: 'weight', 2: 'bias'})

    # Convert F.Conv2D arguments to nn.Conv2D arguments
    kwargs = merge_args_and_kwargs(node, {3: 'stride', 4: 'padding', 5: 'dilation', 6: 'groups'})

    # If weight or bias is from activation of another layer, use dynamic_conv2d
    if any(param.op != 'get_attr' for param in params.values()):
        return elementwise_ops.DynamicConv2d(**kwargs)

    params = {key: get_node_attr(param_node) for key, param_node in params.items()}

    # F.conv2d weight layout is (out_channels, in_channels / groups, kernel_h, kernel_w)
    out_channels, in_channels, kernel_h, kernel_w = params['weight'].shape
    bias = 'bias' in params

    # For Depthwise Conv, multiply in_channels by number of groups
    # if groups is not passed as arg, use its default value 1
    kwargs['in_channels'] = in_channels * kwargs.get('groups', 1)
    kwargs['out_channels'] = out_channels
    # Keep both spatial dims so non-square kernels are reconstructed correctly
    kwargs['kernel_size'] = (kernel_h, kernel_w)
    kwargs['bias'] = bias

    module = torch.nn.Conv2d(**kwargs)
    # Replace nn.Conv2D params using F.Conv2D arguments
    module.weight = params['weight']
    if bias:
        module.bias = params['bias']
    return module


def merge_args_and_kwargs(node: torch.fx.node, arguments_to_fetch: Dict) -> Dict:
    """
    Gather selected arguments of a node into one keyword dictionary.

    For each ``index -> name`` entry, the positional argument at ``index`` is used when the node has
    that many positional args; otherwise ``node.kwargs[name]`` is used if present. Entries whose value
    ends up missing or ``None`` are left out of the result.

    :param node: node to fetch args and kwargs from
    :param arguments_to_fetch: dictionary containing arguments' indices in args and keys in kwargs
    :return: single merged kwargs
    """
    positional = node.args
    merged = {}
    for position, name in arguments_to_fetch.items():
        candidate = positional[position] if position < len(positional) else node.kwargs.get(name)
        if candidate is not None:
            merged[name] = candidate
    return merged


def get_node_attr(node: torch.fx.node):
    """
    Fetch the object a ``get_attr`` node refers to, by walking its dotted target from the node's
    owning module.

    Codes modified from https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern

    :param node: node to fetch data from; must have ``op == 'get_attr'``
    :return: value returned from node
    :raises RuntimeError: if some atom of the dotted target does not exist on the owning module
    """
    assert node.op == 'get_attr'

    target_atoms = node.target.split('.')
    attr_itr = node.graph.owning_module
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Include the missing atom itself so the message names the full unresolved path
            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr


def concat_create_node(traced_model: torch.fx.GraphModule, module_name: str, node: torch.fx.node) \
        -> torch.fx.node:
    """
    Create the node to be inserted in the graph model.

    The new ``call_module`` node receives every tensor of the original ``torch.cat`` call as a separate
    positional argument.

    :param traced_model: Symbolically traced model
    :param module_name: Qualified module name in symbolic_traced_model hierarchy corresponding to new node
    :param node: Current node in the graph after which new node will be inserted
    :return: torch.fx.node to be inserted in the graph
    """
    # torch.cat(tensors, dim=0): the tensor sequence may arrive positionally or as the 'tensors' keyword
    tensors = node.args[0] if node.args else node.kwargs['tensors']

    with traced_model.graph.inserting_after(node):
        # call_module only accepts tuple as args but the sequence can be a list. Convert it into a tuple
        # If it is already a tuple, tuple() will do nothing
        new_node = traced_model.graph.call_module(module_name, args=tuple(tensors))
    return new_node

def concat_create_module(node: torch.fx.node) -> torch.nn.Module:
    """
    Build the elementwise_ops.Concat module that replaces a torch.cat node.

    With exactly one positional arg and no ``dim`` keyword, ``Concat()`` is constructed without an
    axis argument and the node's keyword arguments are copied onto it as attributes. Otherwise the axis
    is read from the second positional arg (or the ``dim`` keyword), passed to ``Concat`` and also set
    as the module's ``axis`` attribute.

    :param node: Current node in the graph after which new node will be inserted
    :return: New module.
    """
    positional_count = len(node.args)
    relies_on_default_dim = positional_count == 1 and 'dim' not in node.kwargs

    if relies_on_default_dim:
        attributes = node.kwargs
        module = elementwise_ops.Concat()
    else:
        if positional_count > 1:
            axis = node.args[1]
        else:
            # NOTE(review): raises KeyError if torch.cat had no positional args and no 'dim' keyword
            axis = node.kwargs['dim']
        module = elementwise_ops.Concat(axis)
        attributes = {'axis': axis}

    for attr_name, attr_value in attributes.items():
        setattr(module, attr_name, attr_value)
    return module

# Per-functional handlers used instead of the generic path: 'node_fn' builds the replacement graph
# node and 'module_fn' builds the replacement module.
special_handler_functions = dict(
    cat=dict(node_fn=concat_create_node, module_fn=concat_create_module),
    conv2d=dict(node_fn=conv2d_create_node, module_fn=conv2d_create_module),
)

def prepare_model(model: torch.nn.Module,
                  modules_to_exclude: Optional[List[torch.nn.Module]] = None,
                  module_classes_to_exclude: Optional[List[Callable]] = None,
                  concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
    """
    Prepare and modify the pytorch model for AIMET features using torch.FX symbolic tracing API.

    1. Replace torch.nn.functional by module of type torch.nn.Module
    2. Create new independent torch.nn.Module instances for reused/duplicate module

    :param model: pytorch Model to be modified.
    :param modules_to_exclude: List of modules to exclude when tracing.
    :param module_classes_to_exclude: List of module classes to exclude when tracing.
    :param concrete_args: Allows you to partially specialize your function, whether it's to remove control flow or
     data structures. If the model has control flow, torch.fx won't be able to trace the model. Check
     torch.fx.symbolic_trace API in detail.
    :return: Modified pytorch Model
    """
    # Tracing happens inside in_eval_mode (aimet_torch.utils); presumably so the eval-mode graph is
    # captured - confirm against that helper.
    with in_eval_mode(model):
        traced_model = _trace_model(model, modules_to_exclude, module_classes_to_exclude, concrete_args)

    # Prepare model and perform checks to make sure the graph is well-formed.
    _prepare_traced_model(traced_model)
    return traced_model
def _trace_model(model: torch.nn.Module, modules_to_exclude: Optional[List[torch.nn.Module]],
                 module_classes_to_exclude: Optional[List[Callable]], concrete_args: Optional[Dict[str, Any]]):
    """
    Symbolically trace the model with a torch.fx Tracer whose is_leaf_module() additionally treats the
    given module instances and module classes as leaves (they are called, not traced through).

    :param model: pytorch Model to be modified.
    :param modules_to_exclude: List of modules to exclude when tracing.
    :param module_classes_to_exclude: List of module classes to exclude when tracing.
    :param concrete_args: Concrete arguments that should not be treated as Proxies.
    :return: Traced model.
    """
    class Tracer(torch.fx.Tracer):
        """
        Override is_leaf_module() method of parent class.
        """
        def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
            return (
                modules_to_exclude and m in modules_to_exclude
                or module_classes_to_exclude and type(m) in module_classes_to_exclude  # pylint: disable=unidiomatic-typecheck
                or super(Tracer, self).is_leaf_module(m, module_qualified_name)
            )

    # Symbolic tracing frontend - captures the semantics of the module
    tracer = Tracer()
    graph = tracer.trace(model, concrete_args=concrete_args)
    symbolic_traced_model = torch.fx.GraphModule(tracer.root, graph)
    return symbolic_traced_model


def _prepare_traced_model(traced_model: torch.fx.GraphModule):
    """
    Helper for prepare_model(). This prepares the given traced_model in-place.

    Functional nodes found in the lookup tables are replaced by calls to newly created modules. Every
    repeated call_module target after its first use gets its own deep-copied module. The graph is then
    linted/recompiled, and torch.nn.SiLU modules are replaced with elementwise_ops.CustomSiLU.

    :param traced_model: Symbolically traced model.
    """
    unique_nodes = set()

    # Modify the symbolically traced model by iterating over all the nodes.
    # NOTE: nodes are inserted and erased while iterating graph.nodes; this relies on torch.fx's node
    # list iterator tolerating erasure of the node currently being visited.
    for node in traced_model.graph.nodes:

        # Create new module for functional nodes
        if node.op in ['call_function', 'call_method']:
            functional_name = _find_functional_name_for_node(node.name)
            if functional_name:
                # Instantiate new module for functional node
                new_module = _create_module_for_functional_node(node, functional_name)
                new_module_name = 'module_' + node.name
                setattr(traced_model, new_module_name, new_module)
                # Create the node for new module in the graph
                _create_node_for_new_module(traced_model, node, new_module_name, functional_name)
                logger.info("Functional : Adding new module for node: {%s} ", node.name)

        # Create new module for reused/duplicate nodes
        elif node.target in unique_nodes:
            if node.op == 'call_module':
                # Instantiate new module for reused node
                new_module = _create_module_for_reused_node(node, traced_model)
                new_module_name = 'module_' + node.name
                setattr(traced_model, new_module_name, new_module)
                # Create the node for new module in the graph
                _create_node_for_new_module(traced_model, node, new_module_name)
                logger.info("Reused/Duplicate : Adding new module for node: {%s} ", node.name)

        else:
            unique_nodes.add(node.target)

    _verify_traced_model(traced_model)

    # Replace SiLU with CustomSiLU
    replace_modules_of_type1_with_type2(traced_model, torch.nn.SiLU, elementwise_ops.CustomSiLU)


def _verify_traced_model(traced_model: torch.fx.GraphModule):
    """
    Does some checks to make sure the graph is well-formed and recompile the forward() method of
    symbolic_traced model from its graph

    :param traced_model: Symbolically traced model
    """
    traced_model.graph.lint()
    traced_model.recompile()


def _create_node_for_new_module(traced_model: torch.fx.GraphModule, node: torch.fx.node,
                                module_name: str, functional_name: Optional[str] = None):
    """
    Insert 'call module' node into graph and replace all the uses of 'node' with newly added node and erase
    the old node from graph

    :param traced_model: Symbolically traced model
    :param node: Current node in the graph after which new node will be inserted
    :param module_name: Qualified module name in symbolic_traced_model hierarchy corresponding to new node
    :param functional_name: Original functional name; falsy for reused/duplicate module nodes
    :return: None
    """
    with traced_model.graph.inserting_after(node):
        if not functional_name:
            # Reused/duplicate module: call the copy exactly like the original, keyword arguments included
            new_node = traced_model.graph.call_module(module_name, args=node.args, kwargs=node.kwargs)
        elif functional_name in functional_with_special_handling:
            new_node = special_handler_functions[functional_name]['node_fn'](traced_model, module_name, node)
        elif functional_name in functional_with_stateless_api:
            # Stateless modules receive every input positionally
            merged_args = _merge_args_and_kwargs(node)
            new_node = traced_model.graph.call_module(module_name, args=tuple(merged_args))
        else:
            # Stateful API: node.kwargs were set as module attributes by _create_module_for_functional_node
            new_node = traced_model.graph.call_module(module_name, args=node.args)
        node.replace_all_uses_with(new_node)
        traced_model.graph.erase_node(node)


def _find_functional_name_for_node(node_name: str) -> Union[str, None]:
    """
    For given node name, find corresponding functional name from combined lookup

    :param node_name: torch.fx Node name
    :return: corresponding functional name if found, else None
    """
    combined_lookup = {**functional_with_stateful_api, **functional_with_special_handling,
                       **functional_with_stateless_api}

    # Functional operations with similar names are differentiated using "_count" suffix
    # when symbolically traced. For example, two add operations will have name 'add' and 'add_1'.
    # Split given node name at every '_' that is followed by a digit (pattern r'_\d').
    strings = split(pattern=r'_\d', string=node_name)
    for string in strings:
        if string in combined_lookup:
            return string

    logger.debug("Couldn't find functional: %s in the lookup. If functional op isn't math invariant,"
                 " add an entry in the lookup.", node_name)
    return None


def _create_module_for_functional_node(node: torch.fx.node, functional_name: str) -> torch.nn.Module:
    """
    For given node and functional name, create torch.nn.Module with same parameters as functional node parameters

    :param node: torch.fx Node
    :param functional_name: Functional name for given node
    :return: New module
    :raises ValueError: if functional_name is in none of the lookup tables
    """
    # Instantiate new module from lookup
    if functional_name in functional_with_stateful_api:
        module = functional_with_stateful_api[functional_name]()
        # Set the parameters for module from node.kwargs
        for key, value in node.kwargs.items():
            setattr(module, key, value)
    elif functional_name in functional_with_special_handling:
        module = special_handler_functions[functional_name]['module_fn'](node)
    elif functional_name in functional_with_stateless_api:
        module = functional_with_stateless_api[functional_name]()
    else:
        raise ValueError("Unsupported module: {}".format(functional_name))
    return module


def _create_module_for_reused_node(node: torch.fx.node, symbolic_traced_model: torch.fx.GraphModule) ->\
        torch.nn.Module:
    """
    For given reused/Duplicate node in symbolically traced model, create new module with same parameters as
    original module

    :param node: Reused/Duplicate torch.fx Node
    :param symbolic_traced_model: Symbolically traced model
    :return: New module (a deep copy of the module the node calls)
    """
    # Get the original module and return newly deep copied module
    module = _get_module_for_dotted_name(symbolic_traced_model, node.target)
    new_module = copy.deepcopy(module)
    return new_module


def _get_module_for_dotted_name(module: torch.fx.GraphModule, dotted_name: str) -> torch.nn.Module:
    """
    For given dotted name, find the module

    :param module: module to start the search from
    :param dotted_name: dotted name of module
    :return: module
    """
    if '.' in dotted_name:
        module_name, _, remainder = dotted_name.partition('.')
        return _get_module_for_dotted_name(module._modules[module_name], remainder)  # pylint: disable=protected-access

    return getattr(module, dotted_name)


def get_module_for_activation_fn(act_fn: Callable):
    """
    returns module instance for functional type handled within PT transformers for activation functions

    :param act_fn: activation function implemented as a functional.
    :return: module equivalent for the activation function, or None (after logging an error) if unsupported.
    """
    if act_fn not in functional_op_to_module_map:
        logger.error("Unsupported activation function {%s}", act_fn)
        return None
    module = functional_op_to_module_map[act_fn]()
    return module


def prepare_pt_transformer_for_quantsim(transformer_model: torch.nn.Module):
    """
    Replaces functionals with modules for activation function, updates model in-place

    :param transformer_model: model with PyTorch nn.Transformer layer
    :return: updated model with modules for activation function.
    """
    for module in transformer_model.modules():
        # encoder layer or decoder layer type is the leaf level node to be updated within nn.transformer layer
        if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)) \
                and not isinstance(module.activation, torch.nn.Module):
            activation_module = get_module_for_activation_fn(module.activation)
            # Keep the original functional when it is unsupported instead of overwriting it with None
            if activation_module is not None:
                module.activation = activation_module


def _merge_args_and_kwargs(node: torch.fx.node) -> List:
    """
    Merge node's args and kwargs

    :param node: Torch FX node in the graph whose args and kwargs to be merged.
    :return: List of merged args: positional args first, then kwarg values in insertion order.
    """
    return [*node.args, *node.kwargs.values()]