# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Custom modules for functional operations defined under torch and torch.nn.functional packages """
from typing import Callable, Any, Tuple, Union, List
import torchvision
import torch
import torch.nn
import spconv.pytorch as spconv
# pylint: disable=no-self-use
def forward_function_wrapper(functional: Callable) -> Any:
"""
Wrapper function returning forward method for given functional operation.
:param functional: torch.nn.functional
:return: forward method
"""
def forward(self, *args, **kwargs) -> Any: # pylint: disable=unused-argument
"""
Forward-pass routine for the functional operation.
"""
return functional(*args, **kwargs)
return forward
def create_wrapper_module(class_name: str, functional: Callable) -> Callable:
"""
Dynamically create wrapper module for a functional operation.
:param class_name: Name of the class.
:param functional: Functional operation.
:return: Module.
"""
wrapped_module = type(class_name, (torch.nn.Module,), {'forward': forward_function_wrapper(functional)})
return wrapped_module
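# Illustrative sketch (added for clarity, not part of the original module): create_wrapper_module
# produces a torch.nn.Module subclass whose forward simply delegates to the given functional.
# 'Exp' below is a hypothetical example name:
#   Exp = create_wrapper_module('Exp', torch.exp)
#   y = Exp()(torch.zeros(3))   # same result as torch.exp(torch.zeros(3)) -> tensor([1., 1., 1.])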
# modules for functional operations under torch package
Subtract = create_wrapper_module('Subtract', torch.sub)
Divide = create_wrapper_module('Divide', torch.div)
FloorDivide = create_wrapper_module('FloorDivide', torch.floor_divide)
MatMul = create_wrapper_module('MatMul', torch.matmul)
Norm = create_wrapper_module('Norm', torch.norm)
Exponential = create_wrapper_module('Exponential', torch.exp)
Erf = create_wrapper_module('Erf', torch.erf)
Sqrt = create_wrapper_module('Sqrt', torch.sqrt)
Maximum = create_wrapper_module('Maximum', torch.maximum)
Max = create_wrapper_module('Max', torch.max) # NOTE: Not elementwise
AMax = create_wrapper_module('AMax', torch.amax)
Minimum = create_wrapper_module('Minimum', torch.minimum)
Min = create_wrapper_module('Min', torch.min) # NOTE: Not elementwise
AMin = create_wrapper_module('AMin', torch.amin)
Where = create_wrapper_module('Where', torch.where)
Greater = create_wrapper_module('Greater', torch.gt)
Less = create_wrapper_module('Less', torch.lt)
GreaterEqual = create_wrapper_module('GreaterEqual', torch.ge)
LessEqual = create_wrapper_module('LessEqual', torch.le)
NotEqual = create_wrapper_module('NotEqual', torch.ne)
Equal = create_wrapper_module('Equal', torch.eq)
Bmm = create_wrapper_module('Bmm', torch.bmm)
CumSum = create_wrapper_module('CumSum', torch.cumsum)
MaskedFill = create_wrapper_module('MaskedFill', torch.Tensor.masked_fill_)
Mean = create_wrapper_module('Mean', torch.mean)
Sum = create_wrapper_module('Sum', torch.sum)
Prod = create_wrapper_module('Prod', torch.prod)
Log = create_wrapper_module('Log', torch.log)
Abs = create_wrapper_module('Abs', torch.abs)
Neg = create_wrapper_module('Neg', torch.neg)
Argmin = create_wrapper_module('Argmin', torch.argmin)
Argmax = create_wrapper_module('Argmax', torch.argmax)
ElementwiseCeil = create_wrapper_module('ElementwiseCeil', torch.ceil)
ElementwiseFloor = create_wrapper_module('ElementwiseFloor', torch.floor)
Sin = create_wrapper_module('Sin', torch.sin)
Cos = create_wrapper_module('Cos', torch.cos)
Asin = create_wrapper_module('Asin', torch.asin)
Atan = create_wrapper_module('Atan', torch.atan)
Round = create_wrapper_module('Round', torch.round)
Gather = create_wrapper_module('Gather', torch.gather)
LogicalOr = create_wrapper_module('LogicalOr', torch.logical_or)
LogicalAnd = create_wrapper_module('LogicalAnd', torch.logical_and)
LogicalNot = create_wrapper_module('LogicalNot', torch.logical_not)
Split = create_wrapper_module('Split', torch.split)
Reshape = create_wrapper_module('Reshape', torch.reshape)
Permute = create_wrapper_module('Permute', torch.permute)
Remainder = create_wrapper_module('Remainder', torch.remainder)
IndexSelect = create_wrapper_module('IndexSelect', torch.index_select)
Fmod = create_wrapper_module('Fmod', torch.fmod)
NonZero = create_wrapper_module('NonZero', torch.nonzero)
TopK = create_wrapper_module('TopK', torch.topk)
Shape = create_wrapper_module('Shape', torch.Tensor.size)
Tile = create_wrapper_module('Tile', torch.tile)
ElementwiseUnarySign = create_wrapper_module('ElementwiseUnarySign', torch.sign)
Baddbmm = create_wrapper_module('Baddbmm', torch.baddbmm)
Addmm = create_wrapper_module('Addmm', torch.addmm)
RSqrt = create_wrapper_module('RSqrt', torch.rsqrt)
Square = create_wrapper_module('Square', torch.square)
Select = create_wrapper_module('Select', torch.select)
# modules for functional operations defined under torch.nn.functional package
Interpolate = create_wrapper_module('Interpolate', torch.nn.functional.interpolate)
MaxPool2d = create_wrapper_module('MaxPool2d', torch.nn.functional.max_pool2d)
AdaptiveAvgPool2d = create_wrapper_module('AdaptiveAvgPool2d', torch.nn.functional.adaptive_avg_pool2d)
AvgPool2d = create_wrapper_module('AvgPool2d', torch.nn.functional.avg_pool2d)
BatchNorm = create_wrapper_module('BatchNorm', torch.nn.functional.batch_norm)
GroupNorm = create_wrapper_module('GroupNorm', torch.nn.functional.group_norm)
Normalize = create_wrapper_module('Normalize', torch.nn.functional.normalize)
Pad = create_wrapper_module('Pad', torch.nn.functional.pad)
GridSample = create_wrapper_module('GridSample', torch.nn.functional.grid_sample)
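# Illustrative usage of the generated wrappers (a minimal sketch; shapes are arbitrary):
#   matmul = MatMul()
#   out = matmul(torch.randn(2, 3), torch.randn(3, 4))              # behaves like torch.matmul; out.shape == (2, 4)
#   pooled = MaxPool2d()(torch.randn(1, 8, 16, 16), kernel_size=2)  # behaves like F.max_pool2d; pooled.shape == (1, 8, 8, 8)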
# The following modules wrap overloaded operators such as + and *,
# which may operate on datatypes other than torch.Tensor.
class Add(torch.nn.Module):
""" Add module for a functional add"""
# pylint:disable=arguments-differ
def forward(self, x: Any, y: Any) -> Any:
"""
Forward-pass routine for add op
"""
if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
out = torch.add(x, y)
else:
out = x + y
return out
class Multiply(torch.nn.Module):
""" Multiply module for a functional multiply"""
# pylint:disable=arguments-differ
def forward(self, x: Any, y: Any) -> Any:
"""
Forward-pass routine for multiply op
"""
if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
out = torch.mul(x, y)
else:
out = x * y
return out
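# Illustrative sketch (not part of the original module): Add and Multiply fall back to the plain
# Python operators when neither operand is a torch.Tensor, so they also work on ints/floats:
#   Add()(torch.ones(2), 3.0)   # tensor([4., 4.]) via torch.add
#   Multiply()(2, 5)            # 10 via the * operator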
# modules for functional requiring special handling
class Concat(torch.nn.Module):
""" Concat module for a functional concat"""
def __init__(self, axis: int = 0):
super().__init__()
self._axis = axis
# pylint:disable=arguments-differ
def forward(self, *x) -> torch.Tensor:
"""
Forward-pass routine for cat op
"""
return torch.cat(x, dim=self._axis)
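# Illustrative usage (a minimal sketch): Concat fixes the axis at construction time and accepts a
# variable number of tensors at call time:
#   cat = Concat(axis=1)
#   out = cat(torch.zeros(2, 3), torch.zeros(2, 5))   # equivalent to torch.cat(..., dim=1); out.shape == (2, 8)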
class DynamicConv2d(torch.nn.Module):
""" Conv2d module for a functional conv2d"""
def __init__(self, stride=1, padding=0, dilation=1, groups=1):
super().__init__()
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
def forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
"""
Forward-pass routine for conv2d op
"""
return torch.nn.functional.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
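# Illustrative usage (a minimal sketch): DynamicConv2d takes the weight (and optional bias) as
# runtime inputs instead of registered parameters:
#   conv = DynamicConv2d(stride=1, padding=1)
#   out = conv(torch.randn(1, 3, 8, 8), torch.randn(16, 3, 3, 3))   # out.shape == (1, 16, 8, 8)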
class Pow(torch.nn.Module):
""" Pow module for a functional pow """
# pylint:disable=arguments-differ
def forward(self, x: Any, y: Any) -> Any:
"""
Forward-pass routine for Pow op
"""
return x ** y
class CustomSiLU(torch.nn.Module):
""" SiLU as Sigmoid + mul """
def __init__(self):
super().__init__()
self.sigmoid = torch.nn.Sigmoid()
self.mul = Multiply()
def forward(self, x: torch.Tensor) -> Any:
"""
Forward-pass routine for custom SiLU
"""
return self.mul(x, self.sigmoid(x))
class StridedSlice(torch.nn.Module):
"""Custom module for a functional slice"""
def forward(self, *args) -> torch.Tensor:
"""
Forward-pass routine for StridedSlice op
"""
tensor, slice_ranges = args
slice_params = []
for slice_range in slice_ranges:
slice_params.append(slice(*slice_range))
return tensor[slice_params]
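# Illustrative usage (a minimal sketch): StridedSlice expects the tensor followed by a list of
# (start, stop, step) ranges, one per leading dimension to be sliced:
#   out = StridedSlice()(torch.arange(16).reshape(4, 4), [(0, 2, 1), (1, 4, 2)])
#   # intended to behave like tensor[0:2:1, 1:4:2]; out.shape == (2, 2)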
class ChannelShuffle(torch.nn.Module):
"""Custom module for a ChannelShuffle op"""
def __init__(self, groups):
super().__init__()
self.groups = groups
def forward(self, *args) -> torch.Tensor:
"""
Forward-pass routine for ChannelShuffle op
"""
tensor = args[0]
n, c, h, w = tensor.shape
return tensor.view(n, self.groups, c // self.groups, h, w).transpose(1, 2).contiguous().view(n, -1, h, w)
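# Illustrative usage (a minimal sketch): ChannelShuffle interleaves channels across groups,
# matching torch.nn.ChannelShuffle semantics:
#   shuffle = ChannelShuffle(groups=2)
#   out = shuffle(torch.randn(1, 4, 8, 8))   # channel order (0, 1, 2, 3) -> (0, 2, 1, 3); shape unchanged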
class Cast(torch.nn.Module):
""" Cast module for a functional cast"""
def __init__(self, dtype):
super().__init__()
self.dtype = dtype
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward-pass routine for cast op
"""
return x.type(self.dtype)
class CustomGather(torch.nn.Module):
""" Custom module for ONNX Gather """
def forward(self, data: torch.Tensor, indices: torch.Tensor, axis: int = 0) -> torch.Tensor:
"""
Forward-pass routine for ONNX Gather op
"""
target_shape = data.shape[:axis] + indices.shape + data.shape[axis + 1:]
indices = (indices < 0).to(indices.dtype) * data.shape[axis] + indices
return torch.index_select(data, axis, indices.flatten()).reshape(target_shape)
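# Illustrative usage (a minimal sketch): CustomGather follows ONNX Gather semantics, including
# support for negative indices along the chosen axis:
#   data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
#   out = CustomGather()(data, torch.tensor([0, -1]), axis=0)   # rows 0 and 2 -> tensor([[1., 2.], [5., 6.]])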
class DepthToSpaceCRDMode(torch.nn.Module):
""" Depthtospace op implementation in CRD mode """
def __init__(self, block_size: List):
super().__init__()
self.block_size_h = block_size[0]
self.block_size_w = block_size[1]
def forward(self, x: torch.Tensor) -> Any:
"""
Forward-pass routine for DepthToSpace op in CRD mode
"""
b, c, h, w = x.shape
tmp = torch.reshape(x, (b, c // (self.block_size_h * self.block_size_w), self.block_size_h, self.block_size_w, h, w))
tmp = torch.permute(tmp, (0, 1, 4, 2, 5, 3))
out = torch.reshape(tmp, (b, c // (self.block_size_h * self.block_size_w), h * self.block_size_h, w * self.block_size_w))
return out
class DepthToSpaceDCRMode(torch.nn.Module):
""" Depthtospace op implementation in DCR mode """
# This class is created because Pytorch as of now doesn't have option
# to run DCR mode in PixelShuffle op.
def __init__(self, block_size: int):
super().__init__()
self.block_size = block_size
def forward(self, x: torch.Tensor) -> Any:
"""
Forward-pass routine for DepthToSpace op in DCR mode
"""
b, c, h, w = x.shape
blocksize = self.block_size
tmp = torch.reshape(x, (b, blocksize, blocksize, c // (blocksize**2), h, w))
tmp = torch.permute(tmp, (0, 3, 4, 1, 5, 2))
out = torch.reshape(tmp, (b, c // (blocksize**2), h * blocksize, w * blocksize))
return out
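# Illustrative usage (a minimal sketch): both DepthToSpace variants rearrange channel data into
# spatial blocks; they differ only in how the channel dimension is split (CRD vs. DCR order):
#   x = torch.randn(1, 8, 2, 3)
#   DepthToSpaceCRDMode(block_size=[2, 2])(x).shape   # torch.Size([1, 2, 4, 6])
#   DepthToSpaceDCRMode(block_size=2)(x).shape        # torch.Size([1, 2, 4, 6])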
# pylint: disable=abstract-method, arguments-differ, unused-argument
class CustomSparseConv3d(torch.autograd.Function):
'''
Custom Sparse Conv3d autograd function
'''
@staticmethod
def symbolic(g, dense_inputs, weight, bias, all_sp_conv_attrs):
'''
Symbolic method (static) for Custom sparse Conv3d
:param g: ONNX graph object
:param dense_inputs: Dense inputs
:param weight: weight value
:param bias: bias value
:param all_sp_conv_attrs: spconv attributes
:return: Added op to the graph object
'''
attrs = {}
for k, v in all_sp_conv_attrs.items():
if v:
if isinstance(v, str):
attrs[k+"_s"] = v
else:
attrs[k+"_i"] = v
if bias:
return g.op("spconv::SparseConvolution", dense_inputs, weight, bias, **attrs)
return g.op("spconv::SparseConvolution", dense_inputs, weight, **attrs)
@staticmethod
def forward(ctx, dense_inputs, weight, bias, all_sp_conv_attrs):
'''
forward method (static) for Custom sparse Conv3d
:param ctx: context object
:param dense_inputs: Dense inputs
:param weight: weight value
:param bias: bias value
:param all_sp_conv_attrs: spconv attributes
:return: Dense tensor
'''
device = weight.device
dense_inputs = dense_inputs.to(device)
sp_conv_attrs = dict()
ignore = ['ndim', 'output_bound', 'input_spatial_shape', 'activation', 'subm', 'batch_size', 'spatial_shape',
'input_shape', 'inverse', 'transposed', 'rulebook', 'output_shape', 'output_spatial_shape',
'output_padding']
for k, v in all_sp_conv_attrs.items():
if k in ignore:
continue
sp_conv_attrs[k] = v
sp_conv_attrs['bias'] = sp_conv_attrs.get("bias", False)
conv3d = torch.nn.Conv3d(**sp_conv_attrs)
with torch.no_grad():
conv3d.weight.copy_(weight.detach().permute(0, 4, 1, 2, 3))
if sp_conv_attrs['bias']:
conv3d.bias.copy_(bias.detach())
conv3d = conv3d.to(device)
out = conv3d(dense_inputs)
return out
class CustomSparseConv3d_WithIndicesFeatures(torch.autograd.Function):
'''
Custom Sparse Conv3d (with indices and features as inputs) autograd function
'''
@staticmethod
def symbolic(g, indices, features, weight, bias, all_sp_conv_attrs):
'''
Symbolic method (static) for Custom sparse Conv3d (with indices and features as inputs)
:param g: ONNX graph object
:param indices: Indices input
:param features: Features input
:param weight: weight value
:param bias: bias value
:param all_sp_conv_attrs: spconv attributes
:return: Added op to the graph object
'''
remove = ['spatial_shape', 'batch_size']
attrs = {}
for k, v in all_sp_conv_attrs.items():
if k not in remove and v:
if isinstance(v, str):
attrs[k+"_s"] = v
else:
attrs[k+"_i"] = v
if bias:
return g.op("spconv::SparseConvolution", indices, features, weight, bias, **attrs)
return g.op("spconv::SparseConvolution", indices, features, weight, **attrs)
@staticmethod
def forward(ctx, indices, features, weight, bias, all_sp_conv_attrs):
'''
forward method (static) for Custom sparse Conv3d (with indices and features as inputs)
:param ctx: context object
:param indices: Indices input
:param features: Features input
:param weight: weight value
:param bias: bias value
:param all_sp_conv_attrs: spconv attributes
:return: Dense tensor
'''
device = weight.device
indices = indices.to(device)
features = features.to(device)
sp_conv_attrs = dict()
ignore = ['ndim', 'output_bound', 'input_spatial_shape', 'activation', 'subm', 'batch_size', 'spatial_shape',
'input_shape', 'inverse', 'transposed', 'rulebook', 'output_shape', 'output_spatial_shape',
'output_padding']
for k, v in all_sp_conv_attrs.items():
if k in ignore:
continue
sp_conv_attrs[k] = v
sp_conv_attrs['bias'] = sp_conv_attrs.get("bias", False)
conv3d = torch.nn.Conv3d(**sp_conv_attrs)
with torch.no_grad():
conv3d.weight.copy_(weight.detach().permute(0, 4, 1, 2, 3))
if sp_conv_attrs['bias']:
conv3d.bias.copy_(bias.detach())
conv3d = conv3d.to(device)
dense_inputs = features.reshape(all_sp_conv_attrs['batch_size'], features.shape[1],
*all_sp_conv_attrs['spatial_shape'])
dense_inputs = dense_inputs.to(device)
out = conv3d(dense_inputs)
return out
# pylint: disable=too-many-arguments, super-with-arguments
class CustomSparseConv3DLayer(torch.nn.Module):
'''
SparseConv3D op implementation
'''
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
super(CustomSparseConv3DLayer, self).__init__()
activation = "None" #"ReLU"
self.sp_conv_3d = spconv.SparseConv3d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, bias=bias, stride=stride, padding=padding,
dilation=dilation, groups=1, algo=spconv.ConvAlgo.Native) # doesn't support groups as of now
self.bias_available = bias
if not bias:
with torch.no_grad():
self.sp_conv_3d.bias = torch.nn.Parameter(torch.zeros(out_channels))
self.conv_attrs_dict = dict(in_channels=self.sp_conv_3d.in_channels,
out_channels=self.sp_conv_3d.out_channels,
kernel_size=self.sp_conv_3d.kernel_size,
stride=self.sp_conv_3d.stride,
padding=self.sp_conv_3d.padding,
dilation=self.sp_conv_3d.dilation,
subm=int(self.sp_conv_3d.subm),
ndim=self.sp_conv_3d.ndim,
output_bound=20000,
activation=activation,
groups=groups)
def forward_with_indices_features(self, indices, features):
'''
forward with indices and features as inputs
:param indices: Indices input
:param features: Features input
:return: Dense tensor output
'''
spatial_shape = [indices[:, 1].max().item()+1, indices[:, 2].max().item()+1,
indices[:, 3].max().item()+1]
batch_size = indices[:, 0].max().item()+1
if torch.jit.is_tracing():
self.conv_attrs_dict['spatial_shape'] = spatial_shape
self.conv_attrs_dict['batch_size'] = batch_size
self.conv_attrs_dict['input_spatial_shape'] = spatial_shape
self.conv_attrs_dict['output_bound'] = features.shape[0]
self.conv_attrs_dict['input_shape'] = features.shape
self.conv_attrs_dict['rulebook'] = "subm" + str(self.conv_attrs_dict['subm'])
self.conv_attrs_dict['transposed'] = 0
self.conv_attrs_dict['inverse'] = 0
self.conv_attrs_dict = dict(sorted(self.conv_attrs_dict.items(), key=lambda x: (x[0], x[1])))
return CustomSparseConv3d_WithIndicesFeatures.apply(indices, features, self.sp_conv_3d.weight,
self.sp_conv_3d.bias, self.conv_attrs_dict)
sp_tensor = spconv.SparseConvTensor(features=features, indices=indices, spatial_shape=spatial_shape,
batch_size=batch_size)
saved_bias_zero = self.sp_conv_3d.bias
if not self.bias_available:
self.sp_conv_3d.bias = None
sp_conv_outs = self.sp_conv_3d(sp_tensor)
dense_outs = sp_conv_outs.dense()
if not self.bias_available:
self.sp_conv_3d.bias = saved_bias_zero
return dense_outs
def forward_with_dense_input(self, dense_inp):
"""
Forward-pass routine for SparseConv3D op
"""
if isinstance(dense_inp, (tuple, list)) and len(dense_inp) == 2:
return self.forward_with_indices_features(*tuple(dense_inp))
if isinstance(dense_inp, spconv.SparseConvTensor):
dense_inp = dense_inp.dense(channels_first=True)
if torch.jit.is_tracing():
self.conv_attrs_dict['input_spatial_shape'] = dense_inp.shape[2:]
self.conv_attrs_dict['spatial_shape'] = dense_inp.shape[2:]
self.conv_attrs_dict['batch_size'] = dense_inp.shape[0]
self.conv_attrs_dict['output_bound'] = dense_inp.shape[0] * dense_inp.shape[2] * dense_inp.shape[3] * \
dense_inp.shape[4]
self.conv_attrs_dict['input_shape'] = [self.conv_attrs_dict['output_bound'], dense_inp.shape[1]]
self.conv_attrs_dict['rulebook'] = "subm" + str(self.conv_attrs_dict['subm'])
self.conv_attrs_dict['transposed'] = 0
self.conv_attrs_dict['inverse'] = 0
self.conv_attrs_dict = dict(sorted(self.conv_attrs_dict.items(), key=lambda x: (x[0], x[1])))
return CustomSparseConv3d.apply(dense_inp, self.sp_conv_3d.weight, self.sp_conv_3d.bias, self.conv_attrs_dict)
# Dense to Sparse Conversion
dense_inp = dense_inp.permute(0, 2, 3, 4, 1) # N D H W C
indices = torch.stack(torch.meshgrid(torch.arange(dense_inp.shape[0]), torch.arange(dense_inp.shape[1]),
torch.arange(dense_inp.shape[2]), torch.arange(dense_inp.shape[3]),
indexing='ij'), dim=-1).reshape(-1, 4).int()
features = dense_inp.reshape(-1, dense_inp.shape[4])
spatial_shape = dense_inp.shape[1:-1]
batch_size = dense_inp.shape[0]
sp_tensor = spconv.SparseConvTensor(features=features, indices=indices, spatial_shape=spatial_shape,
batch_size=batch_size)
saved_bias_zero = self.sp_conv_3d.bias
if not self.bias_available:
self.sp_conv_3d.bias = None
sp_conv_outs = self.sp_conv_3d(sp_tensor)
dense_outs = sp_conv_outs.dense()
if not self.bias_available:
self.sp_conv_3d.bias = saved_bias_zero
return dense_outs
def forward(self, *args):
'''
Forward pass for Custom SparseConv3d layer
:param args: Either one dense input of format NCDHW or two inputs (indices, features) both in dense form
:return: Dense tensor
'''
if len(args) == 2:
return self.forward_with_indices_features(*args)
return self.forward_with_dense_input(*args)
# pylint: disable=useless-super-delegation
class SparseTensorWrapper(torch.nn.Module):
'''
Custom SparseTensorWrapper class for SparseConvTensor
'''
def __init__(self):
super(SparseTensorWrapper, self).__init__()
def forward_with_indices_and_features(self, coords, voxels):
'''
forward pass with indices and features as inputs
:param coords: Indices input
:param voxels: Features input
:return: Sparse tensor
'''
# While tracing, pass the (coords, voxels) pair through unchanged
if torch.jit.is_tracing():
return coords, voxels
spatial_shape = [coords[:, 1].max()+1, coords[:, 2].max()+1, coords[:, 3].max()+1]
return spconv.SparseConvTensor(
features=voxels,
indices=coords,
spatial_shape=spatial_shape,
batch_size=coords[:, 0].max()+1
)
def forward_with_dense_input(self, dense_inp):
'''
forward pass with single dense input (NCDHW format)
:param dense_inp: Dense input
:return: Sparse tensor
'''
if isinstance(dense_inp, tuple) and len(dense_inp) == 2:
return self.forward_with_indices_and_features(*dense_inp)
# dense_inp is expected to be in N C D H W format
if torch.jit.is_tracing():
return dense_inp
dense_inp = dense_inp.permute(0, 2, 3, 4, 1)
# Considering all indices as dense
indices = torch.stack(torch.meshgrid(torch.arange(dense_inp.shape[0]), torch.arange(dense_inp.shape[1]),
torch.arange(dense_inp.shape[2]), torch.arange(dense_inp.shape[3]),
indexing='ij'), dim=-1).reshape(-1, 4).int()
features = dense_inp.reshape(-1, dense_inp.shape[4])
spatial_shape = dense_inp.shape[1:-1]
return spconv.SparseConvTensor(
features=features,
indices=indices,
spatial_shape=spatial_shape,
batch_size=dense_inp.shape[0]
)
def forward(self, *args):
'''
Forward pass for SparseConvTensor's custom implementation
:param args: Either one dense input of format NCDHW or two inputs (indices, features) both in dense form
:return: Sparse tensor
'''
if len(args) == 2:
return self.forward_with_indices_and_features(*args)
return self.forward_with_dense_input(*args)
class CustomScatterDense(torch.autograd.Function):
'''
Custom Scatter Dense autograd function
'''
@staticmethod
def symbolic(g, dense_inputs, attrs):
'''
Symbolic method (static) for ScatterDense
:param g: ONNX graph object
:param dense_inputs: Dense inputs
:param attrs: ScatterDense attributes
:return: Added op to the graph object
'''
save_attrs = {}
for k, v in attrs.items():
if isinstance(v, str):
save_attrs[k+"_s"] = v
else:
save_attrs[k+"_i"] = v
return g.op("spconv::ScatterDense", dense_inputs, **save_attrs)
@staticmethod
def forward(ctx, dense_inputs, attrs):
'''
forward method (static) for ScatterDense
:param ctx: context object
:param dense_inputs: Dense inputs
:param attrs: ScatterDense attributes
:return: Dense tensor
'''
return dense_inputs
class ScatterDense(torch.nn.Module):
'''
ScatterDense custom implementation
'''
def __init__(self):
super(ScatterDense, self).__init__()
def forward(self, inputs):
'''
Forward pass for ScatterDense
:param inputs: Sparse Inputs
:return: Dense tensor
'''
if torch.jit.is_tracing():
attrs = {
"format": "xyz",
"input_spatial_shape": inputs.shape[2:],
"output_shape": inputs.shape
}
return CustomScatterDense.apply(inputs, attrs)
return inputs.dense() if isinstance(inputs, spconv.SparseConvTensor) else inputs
class ScatterND(torch.nn.Module):
""" ScatterND op implementation """
def __init__(self, reduction: int = 0):
super().__init__()
self.reduction = reduction
def forward(self, data: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor:
"""
Forward-pass routine for ScatterND op
"""
output = torch.clone(data)
if self.reduction == 1:
f = torch.add
elif self.reduction == 2:
f = torch.mul
else:
f = None
indices = indices.type(torch.int64)
idx_list = indices.split(split_size=1, dim=-1)
if f:
output[idx_list] = f(output[idx_list], updates.reshape(output[idx_list].shape))
else:
output[idx_list] = updates.reshape(output[idx_list].shape)
return output
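# Illustrative usage (a minimal sketch): ScatterND follows ONNX ScatterND semantics; in this
# module reduction=0 overwrites, 1 adds, and 2 multiplies the addressed elements:
#   data = torch.zeros(4)
#   out = ScatterND()(data, torch.tensor([[1], [3]]), torch.tensor([5.0, 7.0]))   # tensor([0., 5., 0., 7.])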
class RoiAlign(torch.nn.Module):
""" Custom module for ONNX RoiAlign """
def __init__(self, output_size: Union[int, Tuple[int, int]], spatial_scale: float, sampling_ratio: int):
super().__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio
def forward(self, inp: torch.Tensor, roi: torch.Tensor, batch_indices: torch.Tensor) -> torch.Tensor:
"""
Forward-pass routine for RoiAlign
"""
roi = torch.cat((torch.reshape(batch_indices, (batch_indices.shape[0], 1)), roi), dim=1)
return torchvision.ops.roi_align(inp, roi, self.output_size, self.spatial_scale, self.sampling_ratio)
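# Illustrative usage (a minimal sketch): the module prepends batch_indices to the ROI boxes so
# torchvision.ops.roi_align receives rois in (batch_idx, x1, y1, x2, y2) form:
#   roi_align = RoiAlign(output_size=(2, 2), spatial_scale=1.0, sampling_ratio=2)
#   feat = torch.randn(1, 4, 16, 16)
#   out = roi_align(feat, torch.tensor([[0.0, 0.0, 8.0, 8.0]]), torch.tensor([0.0]))   # out.shape == (1, 4, 2, 2)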
class NonMaxSuppression(torch.nn.Module):
"""
Implementation of NMS Op in the form of nn.Module
"""
def __init__(self, iou_threshold: float, score_threshold: float, max_output_boxes_per_class: int):
super().__init__()
self.iou_threshold = iou_threshold
self.score_threshold = score_threshold
self.max_output_boxes_per_class = max_output_boxes_per_class
@staticmethod
def _modify_y1x1y2x2_to_x1y1x2y2(boxes):
return boxes[:, torch.tensor([1, 0, 3, 2])]
def forward(self, *args) -> torch.Tensor:
"""
Forward-pass routine for NMS op
"""
batches_boxes = args[0]
batch_scores = args[1]
res = []
for index, (boxes, scores) in enumerate(zip(batches_boxes, batch_scores)):
for class_index, classes_score in enumerate(scores):
nms_output = self.perform_nms_per_class(boxes, classes_score)
res_per_class = []
for val in nms_output:
res_per_class.append([index, class_index, val.detach()])
res_per_class = res_per_class[:self.max_output_boxes_per_class]
res.extend(res_per_class)
res = torch.tensor(res, dtype=torch.int64, device=args[0].device)
out = torch.zeros(batch_scores.shape[0] * batch_scores.shape[1] * self.max_output_boxes_per_class, 3,
dtype=torch.int64, device=args[0].device)
indices = torch.arange(0, len(res) * 3, dtype=torch.int64, device=args[0].device)
out.put_(indices, res)
return out
def perform_nms_per_class(self, boxes: torch.Tensor, classes_score: torch.Tensor) -> torch.Tensor:
"""
Performs NMS per class
:param boxes: boxes on which NMS should be performed
:param classes_score: corresponding class scores for the boxes
:return: returns box indices filtered out by NMS
"""
filtered_score_ind = (classes_score > self.score_threshold).nonzero()[:, 0]
filtered_boxes = boxes[filtered_score_ind]
filtered_classes_score = classes_score[filtered_score_ind]
res_ = torchvision.ops.nms(self._modify_y1x1y2x2_to_x1y1x2y2(filtered_boxes), filtered_classes_score, self.iou_threshold)
return filtered_score_ind[res_]
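# Illustrative usage (a minimal sketch): like ONNX NonMaxSuppression, boxes are given in
# (y1, x1, y2, x2) order and the result holds (batch_index, class_index, box_index) rows,
# zero-padded up to batch * classes * max_output_boxes_per_class entries:
#   nms = NonMaxSuppression(iou_threshold=0.5, score_threshold=0.1, max_output_boxes_per_class=2)
#   boxes = torch.tensor([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0]]])
#   scores = torch.tensor([[[0.9, 0.8, 0.7]]])   # shape: (batch, classes, boxes)
#   selected = nms(boxes, scores)                # rows [0, 0, 0] and [0, 0, 2]; the duplicate box 1 is suppressed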
class GatherNd(torch.nn.Module):
""" GatherNd op implementation"""
# This class exists because PyTorch currently has no native GatherND op.
def __init__(self, batch_dim: int):
super().__init__()
self.batch_dims = batch_dim
def forward(self, data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""
Forward-pass routine for GatherNd op
"""
if self.batch_dims == 0:
return self._gather_nd(data, indices)
data_rank = len(data.shape)
assert indices.shape[-1] <= data_rank
batch_dims_shape = []
batch_dims_size = 1
for i in range(self.batch_dims):
batch_dims_shape.append(indices.shape[i])
batch_dims_size *= indices.shape[i]
output_shape = (
batch_dims_shape + list(indices.shape)[self.batch_dims:-1]
if (indices.shape[-1] == data_rank - self.batch_dims)
else batch_dims_shape + list(indices.shape)[self.batch_dims:-1] + list(data.shape)[self.batch_dims + indices.shape[-1]:])
if torch.jit.is_tracing():
return torch.zeros(*output_shape, device=data.device)
output_data_buffer = []
reshaped_indices = indices.reshape(batch_dims_size, -1, indices.shape[-1])
reshaped_data = data.reshape((batch_dims_size,) + data.shape[self.batch_dims:])
for batch_dim in range(reshaped_indices.shape[0]):
for outer_dim in range(reshaped_indices.shape[1]):
gather_index = tuple(reshaped_indices[batch_dim][outer_dim])
output_data_buffer.append(reshaped_data[(batch_dim, *gather_index)])
if output_data_buffer[0].dim() == 0:
return torch.tensor(output_data_buffer, device=data.device).reshape(output_shape)
return torch.cat(output_data_buffer).reshape(output_shape)
@staticmethod
def _gather_nd(data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""
GatherNd operation for batch_dim=0 case
:param data: Tensor to gather values
:param indices: Index tensor to be used to gather values
:return: Tensor after GatherNd operation
"""
data_rank, m = len(data.shape), indices.shape[-1]
assert (
m <= data_rank
), f"m: {m} should be less than or equal to data_rank: {data_rank}"
total_samples = indices.shape[:-1].numel()
output_shape = indices.shape[:-1] + data.shape[m:]
reshaped_indices = torch.split(
tensor=indices.reshape(total_samples, m).transpose(0, 1),
split_size_or_sections=1,
)
return data[reshaped_indices].reshape(output_shape).contiguous()
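# Illustrative usage (a minimal sketch): GatherNd follows ONNX GatherND semantics, where the last
# dimension of indices addresses into data after the shared batch dimensions:
#   data = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
#   GatherNd(batch_dim=0)(data, torch.tensor([[0, 0], [1, 1]]))   # tensor([0., 3.])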
class ScatterElements(torch.nn.Module):
""" ScatterElements op implementation """
def __init__(self, dim: int, reduce: str = None):
super().__init__()
self.dim = dim
self.reduce = reduce
def forward(self, x: Union[torch.Tensor, list],
index: Union[torch.Tensor, list],
src: Union[torch.Tensor, list]):
"""
Forward-pass routine for ScatterElements op
"""
if isinstance(index, list):
index = torch.tensor(index, dtype=torch.int64)
if isinstance(src, list):
src = torch.tensor(src)
if isinstance(x, list):
x = torch.tensor(x, dtype=src.dtype)
if self.reduce:
if isinstance(src, torch.Tensor):
return x.scatter_reduce_(self.dim, index, src, self.reduce)
# If src is a single float value
return x.scatter_(self.dim, index, src, reduce=self.reduce)
return x.scatter_(self.dim, index, src)
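# Illustrative usage (a minimal sketch): ScatterElements mirrors ONNX ScatterElements and
# dispatches to torch's in-place scatter_/scatter_reduce_ ops:
#   se = ScatterElements(dim=0)
#   out = se(torch.zeros(3), torch.tensor([0, 2]), torch.tensor([1.0, 2.0]))   # tensor([1., 0., 2.])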
class OneHot(torch.nn.Module):
""" Custom module for ONNX OneHot """
def __init__(self, num_classes: int, off_value: Union[int, float], on_value: Union[int, float]):
super().__init__()
self.num_classes = num_classes
self.off_value = off_value
self.on_value = on_value
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
Forward-pass routine for OneHot
"""
out = torch.nn.functional.one_hot(inputs, self.num_classes)
if self.off_value != 0 or self.on_value != 1:
out = out * (self.on_value - self.off_value) + self.off_value
return out
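# Illustrative usage (a minimal sketch): OneHot maps class indices to one-hot vectors with
# configurable off/on values, mirroring the ONNX OneHot attributes:
#   onehot = OneHot(num_classes=3, off_value=0.0, on_value=5.0)
#   onehot(torch.tensor([0, 2]))   # tensor([[5., 0., 0.], [0., 0., 5.]])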
class Expand(torch.nn.Module):
"""Custom module for a Expand op"""
def forward(self, tensor: torch.Tensor, *args) -> torch.Tensor:
"""
Forward-pass routine for Expand op
"""
return tensor.expand(*args)
class DynamicLinear(torch.nn.Module):
"""Custom module for Dynamic Linear / FullyConnected Op"""
# pylint:disable=no-self-use
def forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
"""
Forward-pass routine for Dynamic Linear Op
"""
return torch.nn.functional.linear(x, weight, bias)
# TODO: Can be removed once AIMET supports torch >= 2.4
class RmsNorm(torch.nn.Module):
"""Custom module for RmsNorm"""
def __init__(self, input_shape: list, axes: list, epsilon: float):
super().__init__()
self.epsilon = epsilon
self.axes = axes
normalized_shape = tuple(input_shape[i] for i in axes)
self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
self.bias = torch.nn.Parameter(torch.zeros(normalized_shape))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for RmsNorm
"""
squared_mean = torch.mean(x * x, dim=self.axes, keepdim=True)
rms = torch.sqrt(squared_mean + self.epsilon)
res = torch.div(x, rms) * self.weight + self.bias
return res
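# Illustrative usage (a minimal sketch): RmsNorm normalizes by the root-mean-square over the
# given axes and applies a learnable scale and shift (initialized to ones and zeros):
#   rmsnorm = RmsNorm(input_shape=[2, 4, 8], axes=[2], epsilon=1e-5)
#   out = rmsnorm(torch.randn(2, 4, 8))   # out.shape == (2, 4, 8)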