Source code for aimet_torch.v1.nn.modules.custom

# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================

""" Custom modules for functional operations defined under torch and torch.nn.functional packages """

from typing import Callable, Any, Tuple, Union, List, Type

import torchvision
import torch
import torch.nn
import spconv.pytorch as spconv

# pylint: disable=no-self-use

def forward_function_wrapper(functional: Callable) -> Any:
    """
    Wrapper function returning forward method for given functional operation.

    :param functional: torch.nn.functional
    :return: forward method
    """
    def forward(self, *args, **kwargs) -> Any: # pylint: disable=unused-argument
        """
        Forward-pass routine for the functional operation.
        """
        return functional(*args, **kwargs)

    return forward

def create_wrapper_module(class_name: str, functional: Callable) -> Type[torch.nn.Module]:
    """
    Dynamically create wrapper module for a functional operation.

    :param class_name: Name of the class.
    :param functional: Functional operation.
    """
    wrapped_module = type(class_name, (torch.nn.Module,), {'forward': forward_function_wrapper(functional)})
    return wrapped_module


# modules for functional operations under torch package
Subtract = create_wrapper_module('Subtract', torch.sub)
Divide = create_wrapper_module('Divide', torch.div)
FloorDivide = create_wrapper_module('FloorDivide', torch.floor_divide)
MatMul = create_wrapper_module('MatMul', torch.matmul)
Norm = create_wrapper_module('Norm', torch.norm)
Exponential = create_wrapper_module('Exponential', torch.exp)
Erf = create_wrapper_module('Erf', torch.erf)
Sqrt = create_wrapper_module('Sqrt', torch.sqrt)
Maximum = create_wrapper_module('Maximum', torch.maximum)
Max = create_wrapper_module('Max', torch.max) # NOTE: Not elementwise
AMax = create_wrapper_module('AMax', torch.amax)
Minimum = create_wrapper_module('Minimum', torch.minimum)
Min = create_wrapper_module('Min', torch.min) # NOTE: Not elementwise
AMin = create_wrapper_module('AMin', torch.amin)
Where = create_wrapper_module('Where', torch.where)
Greater = create_wrapper_module('Greater', torch.gt)
Less = create_wrapper_module('Less', torch.lt)
GreaterEqual = create_wrapper_module('GreaterEqual', torch.ge)
LessEqual = create_wrapper_module('LessEqual', torch.le)
NotEqual = create_wrapper_module('NotEqual', torch.ne)
Equal = create_wrapper_module('Equal', torch.eq)
Bmm = create_wrapper_module('Bmm', torch.bmm)
CumSum = create_wrapper_module('CumSum', torch.cumsum)
MaskedFill = create_wrapper_module('MaskedFill', torch.Tensor.masked_fill_)
Mean = create_wrapper_module('Mean', torch.mean)
Sum = create_wrapper_module('Sum', torch.sum)
Prod = create_wrapper_module('Prod', torch.prod)
Log = create_wrapper_module('Log', torch.log)
Abs = create_wrapper_module('Abs', torch.abs)
Neg = create_wrapper_module('Neg', torch.neg)
Argmin = create_wrapper_module('Argmin', torch.argmin)
Argmax = create_wrapper_module('Argmax', torch.argmax)
ElementwiseCeil = create_wrapper_module('ElementwiseCeil', torch.ceil)
ElementwiseFloor = create_wrapper_module('ElementwiseFloor', torch.floor)
Sin = create_wrapper_module('Sin', torch.sin)
Cos = create_wrapper_module('Cos', torch.cos)
Asin = create_wrapper_module('Asin', torch.asin)
Atan = create_wrapper_module('Atan', torch.atan)
Round = create_wrapper_module('Round', torch.round)
Gather = create_wrapper_module('Gather', torch.gather)
LogicalOr = create_wrapper_module('LogicalOr', torch.logical_or)
LogicalAnd = create_wrapper_module('LogicalAnd', torch.logical_and)
LogicalNot = create_wrapper_module('LogicalNot', torch.logical_not)
Split = create_wrapper_module('Split', torch.split)
Reshape = create_wrapper_module('Reshape', torch.reshape)
Permute = create_wrapper_module('Permute', torch.permute)
Remainder = create_wrapper_module('Remainder', torch.remainder)
IndexSelect = create_wrapper_module('IndexSelect', torch.index_select)
Fmod = create_wrapper_module('Fmod', torch.fmod)
NonZero = create_wrapper_module('NonZero', torch.nonzero)
TopK = create_wrapper_module('TopK', torch.topk)
Shape = create_wrapper_module('Shape', torch.Tensor.size)
Tile = create_wrapper_module('Tile', torch.tile)
ElementwiseUnarySign = create_wrapper_module('ElementwiseUnarySign', torch.sign)
Baddbmm = create_wrapper_module('Baddbmm', torch.baddbmm)
Addmm = create_wrapper_module('Addmm', torch.addmm)
RSqrt = create_wrapper_module('RSqrt', torch.rsqrt)
Square = create_wrapper_module('Square', torch.square)
Select = create_wrapper_module('Select', torch.select)

# modules for functional operations defined under torch.nn.functional package
Interpolate = create_wrapper_module('Interpolate', torch.nn.functional.interpolate)
MaxPool2d = create_wrapper_module('MaxPool2d', torch.nn.functional.max_pool2d)
AdaptiveAvgPool2d = create_wrapper_module('AdaptiveAvgPool2d', torch.nn.functional.adaptive_avg_pool2d)
AvgPool2d = create_wrapper_module('AvgPool2d', torch.nn.functional.avg_pool2d)
BatchNorm = create_wrapper_module('BatchNorm', torch.nn.functional.batch_norm)
GroupNorm = create_wrapper_module('GroupNorm', torch.nn.functional.group_norm)
Normalize = create_wrapper_module('Normalize', torch.nn.functional.normalize)
Pad = create_wrapper_module('Pad', torch.nn.functional.pad)
GridSample = create_wrapper_module('GridSample', torch.nn.functional.grid_sample)

# following modules are for overloaded operators like + and *,
# which can operate other than torch.Tensor datatype.
class Add(torch.nn.Module):
    """ Add module for a functional add"""
    # pylint:disable=arguments-differ
    def forward(self, x: Any, y: Any) -> Any:
        """
        Forward-pass routine for add op
        """
        if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
            out = torch.add(x, y)
        else:
            out = x + y
        return out

class Multiply(torch.nn.Module):
    """ Multiply module for a functional multiply"""
    # pylint:disable=arguments-differ
    def forward(self, x: Any, y: Any) -> Any:
        """
        Forward-pass routine for multiply op
        """
        if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
            out = torch.mul(x, y)
        else:
            out = x * y
        return out


# modules for functional requiring special handling
class Concat(torch.nn.Module):
    """ Concat module for a functional concat"""
    def __init__(self, axis: int = 0):
        super().__init__()
        self._axis = axis

    # pylint:disable=arguments-differ
    def forward(self, *x) -> torch.Tensor:
        """
        Forward-pass routine for cat op
        """
        return torch.cat(x, dim=self._axis)


class DynamicConv2d(torch.nn.Module):
    """ Conv2d module for a functional conv2d"""
    def __init__(self, stride=1, padding=0, dilation=1, groups=1):
        super().__init__()
        self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups

    def forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
        """
        Forward-pass routine for conv2d op
        """
        return torch.nn.functional.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)


class Pow(torch.nn.Module):
    """ Pow module for a functional pow """
    # pylint:disable=arguments-differ
    def forward(self, x: Any, y: Any) -> Any:
        """
        Forward-pass routine for Pow op
        """
        return x ** y


class CustomSiLU(torch.nn.Module):
    """ SiLU as Sigmoid + mul """
    def __init__(self):
        super().__init__()
        self.sigmoid = torch.nn.Sigmoid()
        self.mul = Multiply()

    def forward(self, x: torch.Tensor) -> Any:
        """
        Forward-pass routine for custom SiLU
        """
        return self.mul(x, self.sigmoid(x))


class StridedSlice(torch.nn.Module):
    """Custom module for a functional slice"""
    def forward(self, *args) -> torch.Tensor:
        """
        Forward-pass routine for StridedSlice op
        """
        tensor, slice_ranges = args
        slice_params = []
        for slice_range in slice_ranges:
            slice_params.append(slice(*slice_range))
        return tensor[slice_params]


class ChannelShuffle(torch.nn.Module):
    """Custom module for a ChannelShuffle op"""
    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, *args) -> torch.Tensor:
        """
        Forward-pass routine for ChannelShuffle op
        """
        tensor = args[0]
        n, c, h, w = tensor.shape
        return tensor.view(n, self.groups, c // self.groups, h, w).transpose(1, 2).contiguous().view(n, -1, h, w)


class Cast(torch.nn.Module):
    """ Cast module for a functional cast"""
    def __init__(self, dtype):
        super().__init__()
        self.dtype = dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward-pass routine for cast op
        """
        return x.type(self.dtype)


class CustomGather(torch.nn.Module):
    """ Custom module for ONNX Gather """
    def forward(self, data: torch.Tensor, indices: torch.Tensor, axis: int = 0) -> torch.Tensor:
        """
        Forward-pass routine for ONNX Gather op
        """
        target_shape = data.shape[:axis] + indices.shape + data.shape[axis + 1:]
        indices = (indices < 0).to(indices.dtype) * data.shape[axis] + indices
        return torch.index_select(data, axis, indices.flatten()).reshape(target_shape)


class DepthToSpaceCRDMode(torch.nn.Module):
    """ Depthtospace op implementation in CRD mode """

    def __init__(self, block_size: List):
        super().__init__()
        self.block_size_h = block_size[0]
        self.block_size_w = block_size[1]

    def forward(self, x: torch.Tensor) -> Any:
        """
        Forward-pass routine for DepthToSpace op in CRD mode
        """
        b, c, h, w = x.shape
        tmp = torch.reshape(x, (b, c // (self.block_size_h * self.block_size_w), self.block_size_h, self.block_size_w, h, w))
        tmp = torch.permute(tmp, (0, 1, 4, 2, 5, 3))
        out = torch.reshape(tmp, (b, c // (self.block_size_h * self.block_size_w), h * self.block_size_h, w * self.block_size_w))
        return out

class DepthToSpaceDCRMode(torch.nn.Module):
    """ Depthtospace op implementation in DCR mode """

    # This class is created because Pytorch as of now doesn't have option
    # to run DCR mode in PixelShuffle op.
    def __init__(self, block_size: int):
        super().__init__()
        self.block_size = block_size

    def forward(self, x: torch.Tensor) -> Any:
        """
        Forward-pass routine for DepthToSpace op in DCR mode
        """
        b, c, h, w = x.shape
        blocksize = self.block_size
        tmp = torch.reshape(x, (b, blocksize, blocksize, c // (blocksize**2), h, w))
        tmp = torch.permute(tmp, (0, 3, 4, 1, 5, 2))
        out = torch.reshape(tmp, (b, c // (blocksize**2), h * blocksize, w * blocksize))
        return out

# pylint: disable=abstract-method, arguments-differ, unused-argument
class CustomSparseConv3d(torch.autograd.Function):
    '''
    Custom Sparse Conv3d autograd function
    '''
    @staticmethod
    def symbolic(g, dense_inputs, weight, bias, all_sp_conv_attrs):
        '''
        Symbolic method (static) for Custom sparse Conv3d
        :param g: ONNX graph object
        :param dense_inputs: Dense inputs
        :param weight: weight value
        :param bias: bias value
        :param all_sp_conv_attrs: spconv attributes
        :return: Added op to the graph object
        '''
        attrs = {}
        for k, v in all_sp_conv_attrs.items():
            if v:
                if isinstance(v, str):
                    attrs[k+"_s"] = v
                else:
                    attrs[k+"_i"] = v
        if bias:
            return g.op("spconv::SparseConvolution", dense_inputs, weight, bias, **attrs)
        return g.op("spconv::SparseConvolution", dense_inputs, weight, **attrs)

    @staticmethod
    def forward(ctx, dense_inputs, weight, bias, all_sp_conv_attrs):
        '''
        forward method (static) for Custom sparse Conv3d
        :param ctx: context object
        :param dense_inputs: Dense inputs
        :param weight: weight value
        :param bias: bias value
        :param all_sp_conv_attrs: spconv attributes
        :return: Dense tensor
        '''
        device = weight.device
        dense_inputs = dense_inputs.to(device)
        sp_conv_attrs = dict()
        ignore = ['ndim', 'output_bound', 'input_spatial_shape', 'activation', 'subm', 'batch_size', 'spatial_shape',
                  'input_shape', 'inverse', 'transposed', 'rulebook', 'output_shape', 'output_spatial_shape',
                  'output_padding']
        for k, v in all_sp_conv_attrs.items():
            if k in ignore:
                continue
            sp_conv_attrs[k] = v
        sp_conv_attrs['bias'] = sp_conv_attrs.get("bias", False)
        conv3d = torch.nn.Conv3d(**sp_conv_attrs)

        with torch.no_grad():
            conv3d.weight.copy_(weight.detach().permute(0, 4, 1, 2, 3))
            if sp_conv_attrs['bias']:
                conv3d.bias.copy_(bias.detach())
        conv3d = conv3d.to(device)

        out = conv3d(dense_inputs)
        return out

class CustomSparseConv3d_WithIndicesFeatures(torch.autograd.Function):
    '''
    Custom Sparse Conv3d (with indices and features as inputs) autograd function
    '''
    @staticmethod
    def symbolic(g, indices, features, weight, bias, all_sp_conv_attrs):
        '''
        Symbolic method (static) for Custom sparse Conv3d (with indices and features as inputs)
        :param g: ONNX graph object
        :param indices: Indices input
        :param features: Features input
        :param weight: weight value
        :param bias: bias value
        :param all_sp_conv_attrs: spconv attributes
        :return: Added op to the graph object
        '''
        remove = ['spatial_shape', 'batch_size']
        attrs = {}
        for k, v in all_sp_conv_attrs.items():
            if k not in remove and v:
                if isinstance(v, str):
                    attrs[k+"_s"] = v
                else:
                    attrs[k+"_i"] = v
        if bias:
            return g.op("spconv::SparseConvolution", indices, features, weight, bias, **attrs)
        return g.op("spconv::SparseConvolution", indices, features, weight, **attrs)

    @staticmethod
    def forward(ctx, indices, features, weight, bias, all_sp_conv_attrs):
        '''
        forward method (static) for Custom sparse Conv3d (with indices and features as inputs)
        :param ctx: context object
        :param indices: Indices input
        :param features: Features input
        :param weight: weight value
        :param bias: bias value
        :param all_sp_conv_attrs: spconv attributes
        :return: Dense tensor
        '''
        device = weight.device
        indices = indices.to(device)
        features = features.to(device)
        sp_conv_attrs = dict()
        ignore = ['ndim', 'output_bound', 'input_spatial_shape', 'activation', 'subm', 'batch_size', 'spatial_shape',
                  'input_shape', 'inverse', 'transposed', 'rulebook', 'output_shape', 'output_spatial_shape',
                  'output_padding']
        for k, v in all_sp_conv_attrs.items():
            if k in ignore:
                continue
            sp_conv_attrs[k] = v
        sp_conv_attrs['bias'] = sp_conv_attrs.get("bias", False)
        conv3d = torch.nn.Conv3d(**sp_conv_attrs)

        with torch.no_grad():
            conv3d.weight.copy_(weight.detach().permute(0, 4, 1, 2, 3))
            if sp_conv_attrs['bias']:
                conv3d.bias.copy_(bias.detach())
        conv3d = conv3d.to(device)

        dense_inputs = features.reshape(all_sp_conv_attrs['batch_size'], features.shape[1],
                                        *all_sp_conv_attrs['spatial_shape'])
        dense_inputs = dense_inputs.to(device)

        out = conv3d(dense_inputs)
        return out

# pylint: disable=too-many-arguments, super-with-arguments
class CustomSparseConv3DLayer(torch.nn.Module):
    '''
    SparseConv3D op implementation
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(CustomSparseConv3DLayer, self).__init__()
        activation = "None" #"ReLU"
        self.sp_conv_3d = spconv.SparseConv3d(in_channels=in_channels, out_channels=out_channels,
                                              kernel_size=kernel_size, bias=bias, stride=stride, padding=padding,
                                              dilation=dilation, groups=1, algo=spconv.ConvAlgo.Native) # doesn't support groups as of now
        self.bias_available = bias
        if not bias:
            with torch.no_grad():
                self.sp_conv_3d.bias = torch.nn.Parameter(torch.zeros(out_channels))
        self.conv_attrs_dict = dict(in_channels=self.sp_conv_3d.in_channels,
                                    out_channels=self.sp_conv_3d.out_channels,
                                    kernel_size=self.sp_conv_3d.kernel_size,
                                    stride=self.sp_conv_3d.stride,
                                    padding=self.sp_conv_3d.padding,
                                    dilation=self.sp_conv_3d.dilation,
                                    subm=int(self.sp_conv_3d.subm),
                                    ndim=self.sp_conv_3d.ndim,
                                    output_bound=20000,
                                    activation=activation,
                                    groups=groups)

    def forward_with_indices_features(self, indices, features):
        '''
        forward with indices and features as inputs
        :param indices: Indices input
        :param features: Features input
        :return: Dense tensor output
        '''
        spatial_shape = [indices[:, 1].max().item()+1, indices[:, 2].max().item()+1,
                         indices[:, 3].max().item()+1]
        batch_size = indices[:, 0].max().item()+1
        if torch.jit.is_tracing():
            self.conv_attrs_dict['spatial_shape'] = spatial_shape
            self.conv_attrs_dict['batch_size'] = batch_size
            self.conv_attrs_dict['input_spatial_shape'] = spatial_shape
            self.conv_attrs_dict['output_bound'] = features.shape[0]
            self.conv_attrs_dict['input_shape'] = features.shape
            self.conv_attrs_dict['rulebook'] = "subm" + str(self.conv_attrs_dict['subm'])
            self.conv_attrs_dict['transposed'] = 0
            self.conv_attrs_dict['inverse'] = 0

            self.conv_attrs_dict = dict(sorted(self.conv_attrs_dict.items(), key=lambda x: (x[0], x[1])))
            return CustomSparseConv3d_WithIndicesFeatures.apply(indices, features, self.sp_conv_3d.weight,
                                                                self.sp_conv_3d.bias, self.conv_attrs_dict)

        sp_tensor = spconv.SparseConvTensor(features=features, indices=indices, spatial_shape=spatial_shape,
                                            batch_size=batch_size)
        saved_bias_zero = self.sp_conv_3d.bias
        if not self.bias_available:
            self.sp_conv_3d.bias = None
        sp_conv_outs = self.sp_conv_3d(sp_tensor)
        dense_outs = sp_conv_outs.dense()
        if not self.bias_available:
            self.sp_conv_3d.bias = saved_bias_zero
        return dense_outs

    def forward_with_dense_input(self, dense_inp):
        """
        Forward-pass routine for SparseConv3D op
        """
        if isinstance(dense_inp, (tuple, list)) and len(dense_inp) == 2:
            return self.forward_with_indices_features(*tuple(dense_inp))

        if isinstance(dense_inp, spconv.SparseConvTensor):
            dense_inp = dense_inp.dense(channels_first=True)

        if torch.jit.is_tracing():
            self.conv_attrs_dict['input_spatial_shape'] = dense_inp.shape[2:]
            self.conv_attrs_dict['spatial_shape'] = dense_inp.shape[2:]
            self.conv_attrs_dict['batch_size'] = dense_inp.shape[0]
            self.conv_attrs_dict['output_bound'] = dense_inp.shape[0] * dense_inp.shape[2] * dense_inp.shape[3] * \
                                                   dense_inp.shape[4]
            self.conv_attrs_dict['input_shape'] = [self.conv_attrs_dict['output_bound'], dense_inp.shape[1]]
            self.conv_attrs_dict['rulebook'] = "subm" + str(self.conv_attrs_dict['subm'])
            self.conv_attrs_dict['transposed'] = 0
            self.conv_attrs_dict['inverse'] = 0

            self.conv_attrs_dict = dict(sorted(self.conv_attrs_dict.items(), key=lambda x: (x[0], x[1])))
            return CustomSparseConv3d.apply(dense_inp, self.sp_conv_3d.weight, self.sp_conv_3d.bias, self.conv_attrs_dict)

        # Dense to Sparse Conversion
        dense_inp = dense_inp.permute(0, 2, 3, 4, 1) # N D H W C
        indices = torch.stack(torch.meshgrid(torch.arange(dense_inp.shape[0]), torch.arange(dense_inp.shape[1]),
                                             torch.arange(dense_inp.shape[2]), torch.arange(dense_inp.shape[3]),
                                             indexing='ij'), dim=-1).reshape(-1, 4).int()
        features = dense_inp.reshape(-1, dense_inp.shape[4])
        spatial_shape = dense_inp.shape[1:-1]
        batch_size = dense_inp.shape[0]
        sp_tensor = spconv.SparseConvTensor(features=features, indices=indices, spatial_shape=spatial_shape,
                                            batch_size=batch_size)

        saved_bias_zero = self.sp_conv_3d.bias
        if not self.bias_available:
            self.sp_conv_3d.bias = None
        sp_conv_outs = self.sp_conv_3d(sp_tensor)
        dense_outs = sp_conv_outs.dense()
        if not self.bias_available:
            self.sp_conv_3d.bias = saved_bias_zero
        return dense_outs

    def forward(self, *args):
        '''
        Forward pass for Custom SparseConv3d layer
        :param args: Either one dense input of format NCDHW or two inputs (indices, features) both in dense form
        :return: Dense tensor
        '''
        if len(args) == 2:
            return self.forward_with_indices_features(*args)
        return self.forward_with_dense_input(*args)

# pylint: disable=useless-super-delegation
[docs]class SparseTensorWrapper(torch.nn.Module): ''' Custom SparsetensorWrapper class for SparseConvTensor ''' def __init__(self): super(SparseTensorWrapper, self).__init__() def forward_with_indices_and_features(self, coords, voxels): ''' forward pass with indices and features as inputs :param coords: Indices input :param voxels: Features input :return: Sparse tensor ''' # dense_inp is expected to be in N C D H W format if torch.jit.is_tracing(): return coords, voxels spatial_shape = [coords[:, 1].max()+1, coords[:, 2].max()+1, coords[:, 3].max()+1] return spconv.SparseConvTensor( features=voxels, indices=coords, spatial_shape=spatial_shape, batch_size=coords[:, 0].max()+1 ) def forward_with_dense_input(self, dense_inp): ''' forward pass with single dense input (NCDHW format) :param dense_inp: Dense input :return: Sparse tensor ''' if isinstance(dense_inp, tuple) and len(dense_inp) == 2: return self.forward_with_indices_and_features(*dense_inp) # dense_inp is expected to be in N C D H W format if torch.jit.is_tracing(): return dense_inp dense_inp = dense_inp.permute(0, 2, 3, 4, 1) # Considering all indices as dense indices = torch.stack(torch.meshgrid(torch.arange(dense_inp.shape[0]), torch.arange(dense_inp.shape[1]), torch.arange(dense_inp.shape[2]), torch.arange(dense_inp.shape[3]), indexing='ij'), dim=-1).reshape(-1, 4).int() features = dense_inp.reshape(-1, dense_inp.shape[4]) spatial_shape = dense_inp.shape[1:-1] return spconv.SparseConvTensor( features=features, indices=indices, spatial_shape=spatial_shape, batch_size=dense_inp.shape[0] ) def forward(self, *args): ''' Forward pass for SparseConvTensor's custom implementation :param args: Either one dense input of format NCDHW or two inputs (indices, features) both in dense form :return: Sparse tensor ''' if len(args) == 2: return self.forward_with_indices_and_features(*args) return self.forward_with_dense_input(*args)
class CustomScatterDense(torch.autograd.Function): ''' Custom Scatter Dense autograd function ''' @staticmethod def symbolic(g, dense_inputs, attrs): ''' Symbolic method (static) for ScatterDense :param g:ONNX graph object :param dense_inputs: Dense inputs :param attrs: ScatterDense attributes :return: Added op to the graph object ''' save_attrs = {} for k, v in attrs.items(): if isinstance(v, str): save_attrs[k+"_s"] = v else: save_attrs[k+"_i"] = v return g.op("spconv::ScatterDense", dense_inputs, **save_attrs) @staticmethod def forward(ctx, dense_inputs, attrs): ''' forward method (static) for ScatterDense :param ctx: context object :param dense_inputs: Dense inputs :param attrs: ScatterDense attributes :return: Dense tensor ''' return dense_inputs
[docs]class ScatterDense(torch.nn.Module): ''' ScatterDense custom implementation ''' def __init__(self): super(ScatterDense, self).__init__() def forward(self, inputs): ''' Forward pass for ScatterDense :param inputs: Sparse Inputs :return: Dense tensor ''' if torch.jit.is_tracing(): attrs = { "format": "xyz", "input_spatial_shape": inputs.shape[2:], "output_shape": inputs.shape } return CustomScatterDense.apply(inputs, attrs) return inputs.dense() if isinstance(inputs, spconv.SparseConvTensor) else inputs
class ScatterND(torch.nn.Module): """ ScatterND op implementation """ def __init__(self, reduction: int = 0): super().__init__() self.reduction = reduction def forward(self, data: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: """ Forward-pass routine for ScatterND op """ output = torch.clone(data) if self.reduction == 1: f = torch.add elif self.reduction == 2: f = torch.mul else: f = None indices = indices.type(torch.int64) idx_list = indices.split(split_size=1, dim=-1) if f: output[idx_list] = f(output[idx_list], updates.reshape(output[idx_list].shape)) else: output[idx_list] = updates.reshape(output[idx_list].shape) return output class RoiAlign(torch.nn.Module): """ Custom module for ONNX RoiAlign """ def __init__(self, output_size: Union[int, Tuple[int, int]], spatial_scale: float, sampling_ratio: int): super().__init__() self.output_size = output_size self.spatial_scale = spatial_scale self.sampling_ratio = sampling_ratio def forward(self, inp: torch.Tensor, roi: torch.Tensor, batch_indices: torch.Tensor) -> torch.Tensor: """ Forward-pass routine for RoiAlign """ roi = torch.cat((torch.reshape(batch_indices, (batch_indices.shape[0], 1)), roi), dim=1) return torchvision.ops.roi_align(inp, roi, self.output_size, self.spatial_scale, self.sampling_ratio) class NonMaxSuppression(torch.nn.Module): """ Implementation of NMS Op in the form of nn.Module """ def __init__(self, iou_threshold: float, score_threshold: float, max_output_boxes_per_class: int): super().__init__() self.iou_threshold = iou_threshold self.score_threshold = score_threshold self.max_output_boxes_per_class = max_output_boxes_per_class @staticmethod def _modify_y1x1y2x2_to_x1y1x2y2(boxes): return boxes[:, torch.tensor([1, 0, 3, 2])] def forward(self, *args) -> torch.Tensor: """ Forward-pass routine for NMS op """ batches_boxes = args[0] batch_scores = args[1] res = [] for index, (boxes, scores) in enumerate(zip(batches_boxes, batch_scores)): for class_index, classes_score in enumerate(scores): nms_output = self.perform_nms_per_class(boxes, classes_score) res_per_class = [] for val in nms_output: res_per_class.append([index, class_index, val.detach()]) res_per_class = res_per_class[:self.max_output_boxes_per_class] res.extend(res_per_class) res = torch.tensor(res, dtype=torch.int64, device=args[0].device) out = torch.zeros(batch_scores.shape[0] * batch_scores.shape[1] * self.max_output_boxes_per_class, 3, dtype=torch.int64, device=args[0].device) indices = torch.arange(0, len(res) * 3, dtype=torch.int64, device=args[0].device) out.put_(indices, res) return out def perform_nms_per_class(self, boxes: torch.Tensor, classes_score: torch.Tensor) -> torch.Tensor: """ Performs NMS per class :param boxes: boxes on which NMS should be performed :param classes_score: corresponding class scores for the boxes :return: returns box indices filtered out by NMS """ filtered_score_ind = (classes_score > self.score_threshold).nonzero()[:, 0] filtered_boxes = boxes[filtered_score_ind] filtered_classes_score = classes_score[filtered_score_ind] res_ = torchvision.ops.nms(self._modify_y1x1y2x2_to_x1y1x2y2(filtered_boxes), filtered_classes_score, self.iou_threshold) return filtered_score_ind[res_] class GatherNd(torch.nn.Module): """ GatherNd op implementation""" # This class is created because Pytorch as of now doesn't have support for this OP def __init__(self, batch_dim: int): super().__init__() self.batch_dims = batch_dim def forward(self, data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ Forward-pass routine for GatherNd op """ if self.batch_dims == 0: return self._gather_nd(data, indices) data_rank = len(data.shape) assert indices.shape[-1] <= data_rank batch_dims_shape = [] batch_dims_size = 1 for i in range(self.batch_dims): batch_dims_shape.append(indices.shape[i]) batch_dims_size *= indices.shape[i] output_shape = ( batch_dims_shape + list(indices.shape)[self.batch_dims:-1] if (indices.shape[-1] == data_rank - self.batch_dims) else batch_dims_shape + list(indices.shape)[self.batch_dims:-1] + list(data.shape)[self.batch_dims + indices.shape[-1]:]) if torch.jit.is_tracing(): return torch.zeros(*output_shape, device=data.device) output_data_buffer = [] reshaped_indices = indices.reshape(batch_dims_size, -1, indices.shape[-1]) reshaped_data = data.reshape((batch_dims_size,) + data.shape[self.batch_dims:]) for batch_dim in range(reshaped_indices.shape[0]): for outer_dim in range(reshaped_indices.shape[1]): gather_index = tuple(reshaped_indices[batch_dim][outer_dim]) output_data_buffer.append(reshaped_data[(batch_dim, *gather_index)]) if output_data_buffer[0].dim() == 0: return torch.tensor(output_data_buffer, device=data.device).reshape(output_shape) return torch.cat(output_data_buffer).reshape(output_shape) @staticmethod def _gather_nd(data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ GatherNd operation for batch_dim=0 case :param data: Tensor to gather values :param indices: Index tensor to be used to gather values :return: Tensor after GatherNd operation """ data_rank, m = len(data.shape), indices.shape[-1] assert ( m <= data_rank ), f"m: {m} should be less than or equal to data_rank: {data_rank}" total_samples = indices.shape[:-1].numel() output_shape = indices.shape[:-1] + data.shape[m:] reshaped_indices = torch.split( tensor=indices.reshape(total_samples, m).transpose(0, 1), split_size_or_sections=1, ) return data[reshaped_indices].reshape(output_shape).contiguous() class ScatterElements(torch.nn.Module): """ ScatterElements op implementation """ def __init__(self, dim: int, reduce: str = None): super().__init__() self.dim = dim self.reduce = reduce def forward(self, x: Union[torch.Tensor, list], index: Union[torch.Tensor, list], src: Union[torch.Tensor, list]): """ Forward-pass routine for ScatterElements op """ if isinstance(index, list): index = torch.tensor(index, dtype=torch.int64) if isinstance(src, list): src = torch.tensor(src) if isinstance(x, list): x = torch.tensor(x, dtype=src.dtype) if self.reduce: if isinstance(src, torch.Tensor): return x.scatter_reduce_(self.dim, index, src, self.reduce) # If src is a single float value return x.scatter_(self.dim, index, src, reduce=self.reduce) return x.scatter_(self.dim, index, src) class OneHot(torch.nn.Module): """ Custom module for ONNX OneHot """ def __init__(self, num_classes: int, off_value: Union[int, float], on_value: Union[int, float]): super().__init__() self.num_classes = num_classes self.off_value = off_value self.on_value = on_value def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ Forward-pass routine for OneHot """ out = torch.nn.functional.one_hot(inputs, self.num_classes) if self.off_value != 0 or self.on_value != 1: out = out * (self.on_value - self.off_value) + self.off_value return out class Expand(torch.nn.Module): """Custom module for a Expand op""" def forward(self, tensor: torch.Tensor, *args) -> torch.Tensor: """ Forward-pass routine for Expand op """ return tensor.expand(*args) class DynamicLinear(torch.nn.Module): """Custom module for Dynamic Linear / FullyConnected Op""" # pylint:disable=no-self-use def forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor: """ Forward-pass routine for Dynamic Linear Op """ return torch.nn.functional.linear(x, weight, bias) # TODO: Can be removed once AIMET supports torch >= 2.4 class RmsNorm(torch.nn.Module): """Custom module for RmsNorm""" def __init__(self, input_shape: list, axes: list, epsilon: float): super().__init__() self.epsilon = epsilon self.axes = axes normalized_shape = tuple(input_shape[i] for i in axes) self.weight = torch.nn.Parameter(torch.ones(normalized_shape)) self.bias = torch.nn.Parameter(torch.zeros(normalized_shape)) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for RmsNorm """ input_dtype = x.dtype x = x.to(dtype=torch.float32, copy=True) squared_mean = torch.mean(x * x, dim=self.axes, keepdim=True) rms = torch.sqrt(squared_mean + self.epsilon) res = (torch.div(x, rms) * self.weight + self.bias).to(dtype=input_dtype) return res