AIMET Torch SparseConvolution custom ONNX export

This page shows how SparseConvolution-based models can be used in AIMET. The supported SparseConvolution library is traveller59's spconv library. Please note the following:

  • Only SparseConvolution3D is currently supported.

  • The CPU build of the SpConv library is not fully stable: the spconv module has been observed to return different outputs for identical inputs within the same runtime.

  • If an SpConv layer has a bias, run the model on GPU, because SpConv supports bias only on GPU.
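Because the CPU build of SpConv may not return identical outputs for identical inputs, it can be worth checking repeatability before calibrating. The sketch below uses a hypothetical helper, `outputs_are_repeatable` (not part of AIMET or spconv): it runs a callable several times on the same inputs and compares every result with the first one using a caller-supplied comparison.

```python
from typing import Any, Callable, Sequence


def outputs_are_repeatable(
    fn: Callable[..., Any],
    inputs: Sequence[Any],
    runs: int = 3,
    same: Callable[[Any, Any], bool] = lambda a, b: a == b,
) -> bool:
    """Hypothetical helper: True if fn(*inputs) returns the same result on every run.

    For PyTorch outputs, pass same=torch.allclose (or torch.equal for exact matches).
    """
    if runs < 2:
        raise ValueError("need at least two runs to compare")
    reference = fn(*inputs)
    # Compare every later run against the first one
    return all(same(reference, fn(*inputs)) for _ in range(runs - 1))
```

With the model defined later on this page, a check could look like `outputs_are_repeatable(model, (indices, features), same=torch.allclose)`, ideally run under `torch.no_grad()`.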

Custom API for the spconv modules

The following API creates a sparse tensor from indices and features given in dense form:

class aimet_torch.nn.modules.custom.SparseTensorWrapper

Custom SparseTensorWrapper class for SparseConvTensor

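Conceptually, a sparse tensor pairs each active voxel's coordinates with its feature vector and records the grid extent; spconv's `SparseConvTensor` is built from `features`, `indices`, `spatial_shape` and `batch_size`. The pure-Python sketch below only illustrates that pairing. `SparseTensorSketch` and `make_sparse` are hypothetical names, and inferring the spatial shape and batch size from the largest indices is an assumption; AIMET's `SparseTensorWrapper` may derive them differently.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class SparseTensorSketch:
    """Illustrative stand-in for the fields of spconv.SparseConvTensor (not AIMET's class)."""
    features: List[List[float]]      # one row of C channel values per active voxel
    indices: List[Tuple[int, ...]]   # one (batch, d, h, w) row per active voxel
    spatial_shape: Tuple[int, ...]   # (D, H, W)
    batch_size: int


def make_sparse(coords: Sequence[Sequence[int]],
                voxels: Sequence[Sequence[float]]) -> SparseTensorSketch:
    """Pair each coordinate row with its feature row.

    Assumption: spatial_shape and batch_size are inferred as (max index + 1) per axis.
    """
    if len(coords) != len(voxels):
        raise ValueError("coords and voxels must have the same number of rows")
    indices = [tuple(int(v) for v in row) for row in coords]
    batch_size = max(row[0] for row in indices) + 1
    spatial_shape = tuple(max(row[axis] for row in indices) + 1 for axis in (1, 2, 3))
    return SparseTensorSketch([list(f) for f in voxels], indices, spatial_shape, batch_size)
```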

The following API creates a dense tensor from a sparse tensor:

class aimet_torch.nn.modules.custom.ScatterDense

ScatterDense custom implementation

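Going the other way, a sparse-to-dense scatter writes each feature row into its coordinate and leaves every other voxel at zero. The sketch below is conceptual only: `scatter_dense` is a hypothetical helper using nested lists, not AIMET's `ScatterDense` code. It produces a channels-first N, C, D, H, W layout, which is the layout `torch.nn.Conv3d` expects and the reason `ScatterDense` can feed a normal Conv3d layer in the example model.

```python
from typing import List, Sequence, Tuple


def scatter_dense(indices: Sequence[Tuple[int, int, int, int]],
                  features: Sequence[Sequence[float]],
                  spatial_shape: Tuple[int, int, int],
                  batch_size: int) -> List:
    """Conceptual sparse-to-dense scatter (hypothetical helper, not AIMET's ScatterDense).

    Returns nested lists in N, C, D, H, W order; voxels without an index stay 0.0.
    """
    channels = len(features[0]) if features else 0
    d_size, h_size, w_size = spatial_shape
    # Fresh zero-filled grid for every (batch, channel) pair
    dense = [[[[[0.0] * w_size for _ in range(h_size)] for _ in range(d_size)]
              for _ in range(channels)] for _ in range(batch_size)]
    for (n, d, h, w), feat in zip(indices, features):
        for c, value in enumerate(feat):
            dense[n][c][d][h][w] = value
    return dense
```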

Code Example

Imports

import torch
import spconv.pytorch as spconv
import aimet_torch.nn.modules.custom  # importing the submodule makes aimet_torch.nn.modules.custom.* accessible
from aimet_torch.quantsim import QuantizationSimModel, QuantScheme
from aimet_torch.pro.model_preparer import prepare_model

Create or load a model with SpConv3D module(s)

class SpConvModel(torch.nn.Module):
    def __init__(self):
        super(SpConvModel, self).__init__()

        # "SparseTensorWrapper" needs to be used to convert a dense tensor to a sparse tensor
        self.spconv_tensor = aimet_torch.nn.modules.custom.SparseTensorWrapper()

        # First SparseConv3D layer
        self.spconv1 = spconv.SparseConv3d(in_channels=3, out_channels=9, kernel_size=2,
                                           bias=False)

        # Second SparseConv3D layer
        self.spconv2 = spconv.SparseConv3d(in_channels=9, out_channels=5, kernel_size=3, bias=False)

        # Normal Conv3D layer
        self.normal_conv3d = torch.nn.Conv3d(in_channels=5, out_channels=3, kernel_size=3, bias=True)

        # "ScatterDense" needs to be used to convert a sparse tensor to a dense tensor
        self.spconv_scatter_dense = aimet_torch.nn.modules.custom.ScatterDense()

        # Adding ReLU activation
        self.relu1 = torch.nn.ReLU()

    def forward(self, coords, voxels):
        '''
        Forward function for the test SpConvModel
        :param coords: Dense indices
        :param voxels: Dense features
        :return: SpConvModel output (dense tensor)
        '''

        # Convert dense indices and features to sparse tensor
        sp_tensor = self.spconv_tensor(coords, voxels)

        # Output from SparseConv3D layer 1
        sp_outputs1 = self.spconv1(sp_tensor)

        # Output from SparseConv3D layer 2
        sp_outputs2 = self.spconv2(sp_outputs1)

        # Convert Sparse tensor output to a dense tensor output
        sp_outputs2_dense = self.spconv_scatter_dense(sp_outputs2)

        # Output from Normal Conv3D layer
        sp_outputs = self.normal_conv3d(sp_outputs2_dense)

        # Output from ReLU
        sp_outputs_relu = self.relu1(sp_outputs)

        return sp_outputs_relu
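None of the three convolutions in `SpConvModel` uses padding, so each one shrinks the spatial grid. Assuming `spconv.SparseConv3d` follows the same output-size formula as `torch.nn.Conv3d` with default stride 1, padding 0 and dilation 1, and that `ScatterDense` returns the full spatial extent of the sparse output, a 10x10x10 input should come out as a 1x3x5x5x5 tensor. The helper names below are hypothetical; they only do the arithmetic.

```python
def conv_out_size(size: int, kernel: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Standard convolution output-size formula for one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def spconv_model_output_shape(batch: int = 1, spatial: int = 10) -> tuple:
    # Kernel sizes of spconv1, spconv2 and normal_conv3d (assumed stride 1, no padding)
    for kernel in (2, 3, 3):
        spatial = conv_out_size(spatial, kernel)
    # normal_conv3d has out_channels=3; ReLU keeps the shape unchanged
    return (batch, 3, spatial, spatial, spatial)


print(spconv_model_output_shape())  # → (1, 3, 5, 5, 5)
```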

Obtain model inputs

dense_tensor_sp_inputs = torch.randn(1, 3, 10, 10, 10) # generate a random NCDHW tensor
dense_tensor_sp_inputs = dense_tensor_sp_inputs.permute(0, 2, 3, 4, 1) # convert NCDHW to NDHWC

# Creating dense indices
indices = torch.stack(torch.meshgrid(torch.arange(dense_tensor_sp_inputs.shape[0]), torch.arange(dense_tensor_sp_inputs.shape[1]),
                                     torch.arange(dense_tensor_sp_inputs.shape[2]), torch.arange(dense_tensor_sp_inputs.shape[3]),
                                     indexing='ij'), dim=-1).reshape(-1, 4).int()

# Creating dense features
# reshape (not view): the permuted tensor is non-contiguous, so view() would raise an error
features = dense_tensor_sp_inputs.reshape(-1, dense_tensor_sp_inputs.shape[4])
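The meshgrid call enumerates every voxel of the 1x10x10x10 grid, so row i of `indices` must describe the voxel whose channels sit in row i of `features`. The pure-Python sketch below (hypothetical helper names) reproduces that row-major (n, d, h, w) ordering with `itertools.product` and the matching flat-row formula. It also shows why `features` is flattened from the NDHWC layout: with channels last, each voxel's C values form one contiguous row.

```python
from itertools import product


def dense_coords(n: int, d: int, h: int, w: int) -> list:
    """Row-major (n, d, h, w) coordinates, the same order as
    torch.stack(torch.meshgrid(..., indexing='ij'), dim=-1).reshape(-1, 4)."""
    return [list(idx) for idx in product(range(n), range(d), range(h), range(w))]


def flat_row(coord, d: int, h: int, w: int) -> int:
    """Row of the (-1, C) feature matrix that holds the channels of voxel (n, d, h, w)."""
    n_i, d_i, h_i, w_i = coord
    return ((n_i * d + d_i) * h + h_i) * w + w_i
```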

Apply model preparer pro

model = SpConvModel().eval()

# "./spconv_output" is a placeholder output directory for the prepared model artifacts
prepared_model = prepare_model(model, dummy_input=(indices, features), path="./spconv_output",
                               onnx_export_args=dict(operator_export_type=
                                                     torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
                                                     opset_version=16),
                               converter_args=['--input_dtype', "indices.1", "int32", '--input_dtype',
                                               "features.1", "float32", '--expand_sparse_op_structure',
                                               '--preserve_io', 'datatype', 'indices.1'])
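The `converter_args` list is easier to keep consistent when it is assembled from a mapping of input names to dtypes. The sketch below is a hypothetical helper (`build_converter_args` is not an AIMET function); it only emits the flags already used in the `prepare_model` call, and the input names `indices.1` and `features.1` are copied from that call. If the exported graph names its inputs differently, those strings would need to change as well.

```python
from typing import Dict, List, Sequence


def build_converter_args(input_dtypes: Dict[str, str],
                         preserve_dtype_inputs: Sequence[str] = (),
                         expand_sparse_ops: bool = True) -> List[str]:
    """Hypothetical helper that assembles a converter_args list like the one above."""
    args: List[str] = []
    for name, dtype in input_dtypes.items():  # dicts keep insertion order
        args += ["--input_dtype", name, dtype]
    if expand_sparse_ops:
        args.append("--expand_sparse_op_structure")
    if preserve_dtype_inputs:
        args += ["--preserve_io", "datatype", *preserve_dtype_inputs]
    return args
```

For example, `build_converter_args({"indices.1": "int32", "features.1": "float32"}, preserve_dtype_inputs=["indices.1"])` yields the same list that is passed to `prepare_model` above.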

Apply QuantSim (or any other AIMET features)

qsim = QuantizationSimModel(prepared_model, dummy_input=(indices, features),
                            quant_scheme=QuantScheme.post_training_tf)
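`QuantScheme.post_training_tf` derives each encoding from the observed minimum and maximum. The sketch below shows the general asymmetric min/max arithmetic under the assumption that the range is widened to include zero; the function names are hypothetical, and AIMET's actual implementation may differ in details such as tie rounding and symmetric encodings.

```python
def minmax_encoding(x_min: float, x_max: float, bitwidth: int = 8):
    """Asymmetric min/max encoding in the style of a TF post-training scheme (sketch).

    Returns (scale, offset, adjusted_min, adjusted_max).
    """
    # Widen the range so that 0.0 is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    levels = 2 ** bitwidth - 1
    scale = (x_max - x_min) / levels if x_max > x_min else 1.0
    offset = round(x_min / scale)  # integer grid position of the minimum
    adj_min = offset * scale
    return scale, offset, adj_min, adj_min + levels * scale


def quantize_dequantize(x: float, scale: float, offset: int, bitwidth: int = 8) -> float:
    """Round x onto the integer grid, clamp to [0, 2**bitwidth - 1], and map back to float."""
    q = round(x / scale) - offset
    q = max(0, min(2 ** bitwidth - 1, q))
    return (q + offset) * scale
```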

Compute Encodings

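def dummy_forward_pass(model, inputs):
    # compute_encodings calls this as dummy_forward_pass(model, (indices, features))
    with torch.no_grad():
        model(*inputs)

qsim.compute_encodings(dummy_forward_pass, (indices, features))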
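A single random input exercises the pipeline, but encodings are normally calibrated on several representative batches. Since `compute_encodings` calls the callback as `callback(model, callback_args)`, the arguments can just as well be an iterable of `(indices, features)` tuples. `calibration_forward_pass` below is a hypothetical name; the torch import fallback exists only so the sketch runs where PyTorch is absent.

```python
import contextlib
from typing import Any, Iterable, Sequence

try:
    import torch
    _no_grad = torch.no_grad
except ImportError:  # lets this sketch run where PyTorch is not installed
    _no_grad = contextlib.nullcontext


def calibration_forward_pass(model: Any, batches: Iterable[Sequence[Any]]) -> None:
    """Forward-pass callback for compute_encodings (sketch).

    Each element of batches is an (indices, features) tuple fed to the model without gradients.
    """
    with _no_grad():
        for batch in batches:
            model(*batch)
```

Usage would then look like `qsim.compute_encodings(calibration_forward_pass, [(indices, features)])`, with more tuples appended for additional calibration data.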

QuantSim Exports

# "./spconv_output" is a placeholder output directory for the exported artifacts
qsim.export("./spconv_output", "exported_sp_conv_model", dummy_input=(indices, features),
            onnx_export_args=dict(operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK))
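After export, a quick sanity check is to count the encodings that were written. The sketch assumes the export produced a JSON file named `<prefix>.encodings` with top-level `activation_encodings` and `param_encodings` sections; that layout is an assumption, so adjust the keys if your AIMET version writes a different format. `summarize_encodings` is a hypothetical helper.

```python
import json
from pathlib import Path
from typing import Dict


def summarize_encodings(export_dir: str, prefix: str = "exported_sp_conv_model") -> Dict[str, int]:
    """Count entries per section of an exported encodings file (sketch, assumed layout)."""
    path = Path(export_dir) / f"{prefix}.encodings"
    with path.open() as f:
        data = json.load(f)
    return {section: len(data.get(section, {}))
            for section in ("activation_encodings", "param_encodings")}
```

Pass the same directory that was given to `qsim.export`; zero activation encodings for a model that clearly has quantized layers would suggest calibration did not run.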