EXPERIMENTAL

Architecture Checker API

aimet_torch.arch_checker.arch_checker.ArchChecker.check_model_arch(model, dummy_input, result_dir=None)

Check each node in the model using checks in _node_check_dict. Record only the nodes and failed tests.

Parameters
  • model (Module) – Torch model to be checked.

  • dummy_input (Union[Tensor, Tuple]) – A dummy input to the model. Can be a Tensor or a Tuple of Tensors.

Returns

arch_checker_report – {op.dotted_name_op: NodeErrorReportObject}

Return type

ArchCheckerReport

The AIMET PyTorch Architecture Checker checks a model for sub-optimal constructs and suggests changes that can make the model more performant. The architecture checker currently checks for the following conditions:

  • Convolution layers for optimal channel size.

  • Activation functions that are not performant.

  • Batch Normalization layers that cannot be folded.

  • Intermediate convolution layers with padding in a sequence of convolution layers.
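To make the first condition concrete, here is a minimal sketch of the channel rule as it can be inferred from the log and HTML report in Example 1: a convolution fails _check_conv_channel_larger_than_32 when its input or output channel count is below 32, and fails _check_conv_channel_32_base when either count is not a multiple of 32. The function conv_channel_failures is hypothetical and not part of AIMET, and the exact thresholds AIMET applies are an assumption.

```python
# Hypothetical sketch, not AIMET code: the thresholds are assumed from the
# check names and report text shown in Example 1.
from typing import Set


def conv_channel_failures(in_channels: int, out_channels: int) -> Set[str]:
    """Return the names of the channel checks a convolution would fail."""
    failures: Set[str] = set()
    for channels in (in_channels, out_channels):
        # Assumed rule: channel counts below 32 are flagged.
        if channels < 32:
            failures.add("_check_conv_channel_larger_than_32")
        # Assumed rule: channel counts that are not multiples of 32 are flagged.
        if channels % 32 != 0:
            failures.add("_check_conv_channel_32_base")
    return failures


print(sorted(conv_channel_failures(3, 31)))   # ['_check_conv_channel_32_base', '_check_conv_channel_larger_than_32']
print(sorted(conv_channel_failures(32, 32)))  # []
```

Under these assumptions, Conv2d(3, 31) from Example 1 fails both checks, while the Conv2d(32, 32) layers in Examples 2 and 3 pass.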

In this section, we present models failing the architecture checks, and show how to run the architecture checker.

Example 1: Model with not enough channels

We begin with the following model, which contains a convolution layer with fewer than 32 output channels.

import torch


class ModelWithNotEnoughChannels(torch.nn.Module):
    """ Model with a convolution layer that has fewer than 32 output channels. Expects input of shape (1, 3, 32, 32) """

    def __init__(self):
        super(ModelWithNotEnoughChannels, self).__init__()
        # 31 output channels: fewer than 32 and not a multiple of 32
        self.conv1 = torch.nn.Conv2d(3, 31, kernel_size=2, stride=2, padding=2, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(31)

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.bn1(x)
        return x

Import the architecture checker:

from aimet_torch.arch_checker.arch_checker import ArchChecker

Run the checker on the model by passing in the model as well as the model input:

def example_check_for_number_of_conv_channels():

    model = ModelWithNotEnoughChannels()
    ArchChecker.check_model_arch(model, dummy_input=torch.rand(1, 3, 32, 32))

Because the convolution layer in the model has 3 input channels and 31 output channels, both fewer than 32 and neither a multiple of 32, the following log message appears:

Utils - INFO - Graph/Node: ModelWithNotEnoughChannels.conv1: Conv2d(3, 31, kernel_size=(2, 2), stride=(2, 2), padding=(2, 2), bias=False) fails check: {'_check_conv_channel_32_base', '_check_conv_channel_larger_than_32'}
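If you want to collect these results programmatically from the log, a small parser over the line format shown above is one option. The sketch below is an illustration only: parse_arch_checker_log is a hypothetical helper, not an AIMET API, and it assumes the log format stays exactly as in this example, with the failed checks printed as a Python set literal.

```python
# Hypothetical helper, not part of AIMET: the line format is assumed from the
# example log output above.
import ast
import re
from typing import Optional, Set, Tuple

_LOG_PATTERN = re.compile(
    r"Graph/Node: (?P<node>[^:]+): (?P<module>.*) fails check: (?P<checks>\{.*\})\s*$"
)


def parse_arch_checker_log(line: str) -> Optional[Tuple[str, Set[str]]]:
    """Return (node name, failed check names) for a checker log line, else None."""
    match = _LOG_PATTERN.search(line)
    if match is None:
        return None
    # The failed checks are printed as a Python set literal of strings,
    # so ast.literal_eval can turn them back into a set without running code.
    checks = ast.literal_eval(match.group("checks"))
    return match.group("node"), set(checks)


example = (
    "Utils - INFO - Graph/Node: ModelWithNotEnoughChannels.conv1: "
    "Conv2d(3, 31, kernel_size=(2, 2), stride=(2, 2), padding=(2, 2), bias=False) "
    "fails check: {'_check_conv_channel_32_base', '_check_conv_channel_larger_than_32'}"
)
node, checks = parse_arch_checker_log(example)
print(node)            # ModelWithNotEnoughChannels.conv1
print(sorted(checks))  # ['_check_conv_channel_32_base', '_check_conv_channel_larger_than_32']
```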

An HTML file with the following content is also generated.

HTML report content

Graph/Layer_name | Issue | Recommendation
--- | --- | ---
ModelWithNotEnoughChannels.conv1 | The channel size of input/output tensor of this convolution is smaller than 32 | Try adjusting the channels to multiple of 32 to get better performance.
ModelWithNotEnoughChannels.conv1 | The channel size of input/output tensor of this convolution is not a multiple of 32 | Try adjusting the channels to multiple of 32 to get better performance.
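Following the report's recommendation, the target channel count is the next multiple of 32. The helper below (hypothetical, not part of AIMET) computes it; for this model it would suggest widening conv1 to Conv2d(3, 32, ...) and bn1 to BatchNorm2d(32). This sketch only adjusts the output channels; the 3 input channels come from the model input and are left as they are.

```python
# Hypothetical helper, not part of AIMET: rounds a channel count up to the
# next multiple of `base`, following the report's recommendation.
def round_up_channels(channels: int, base: int = 32) -> int:
    """Return the smallest multiple of `base` that is >= `channels`."""
    if channels <= 0 or base <= 0:
        raise ValueError("channels and base must be positive")
    # Integer ceiling division, then scale back up to a multiple of `base`.
    return ((channels + base - 1) // base) * base


print(round_up_channels(31))  # 32
print(round_up_channels(33))  # 64
print(round_up_channels(64))  # 64
```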

Example 2: Model with non-performant activation

We begin with the following model, which contains a PReLU activation layer.

import torch


class ModelWithPrelu(torch.nn.Module):
    """ Model with a PReLU module. Expects input of shape (1, 32, 32, 32) """

    def __init__(self):
        super(ModelWithPrelu, self).__init__()
        self.conv1 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=2, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(32)
        # PReLU is flagged by the activation check
        self.prelu1 = torch.nn.PReLU()

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.bn1(x)
        x = self.prelu1(x)
        return x

Run the checker on the model by passing in the model as well as the model input:

def example_check_for_non_performant_activations():

    model = ModelWithPrelu()
    ArchChecker.check_model_arch(model, dummy_input=torch.rand(1, 32, 32, 32))

Because the PReLU layer in the model is considered non-performant compared to ReLU, the following log message appears:

Utils - INFO - Graph/Node: ModelWithPrelu.prelu1: PReLU(num_parameters=1) fails check: {'_activation_checks'}
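For reference, PReLU computes max(0, x) + a * min(0, x) with a learnable slope a (a single slope when num_parameters=1, initialized to 0.25 by default in PyTorch), while ReLU is the special case a = 0. The stdlib-only sketch below shows the two definitions side by side; it illustrates the math only and is not how PyTorch or AIMET implement these layers.

```python
# Element-wise definitions of the two activations, for illustration only.
from typing import List


def prelu(xs: List[float], a: float = 0.25) -> List[float]:
    """PReLU with a single slope `a` shared across channels (num_parameters=1)."""
    return [max(0.0, x) + a * min(0.0, x) for x in xs]


def relu(xs: List[float]) -> List[float]:
    """ReLU: the a = 0 special case of PReLU."""
    return [max(0.0, x) for x in xs]


xs = [-2.0, -0.5, 0.0, 1.5]
print(prelu(xs))  # [-0.5, -0.125, 0.0, 1.5]
print(relu(xs))   # [0.0, 0.0, 0.0, 1.5]
```

Because the outputs differ for negative inputs, replacing self.prelu1 with torch.nn.ReLU() changes what the model computes, so some retraining or fine-tuning may be needed afterwards.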

Example 3: Model with standalone batch normalization layer

We begin with the following model, in which an average pooling layer sits between the convolution layer and the batch normalization layer.

import torch


class ModelWithNonfoldableBN(torch.nn.Module):
    """ Model that has non-foldable batch norm. Expects input of shape (1, 32, 32, 32) """

    def __init__(self):
        super(ModelWithNonfoldableBN, self).__init__()
        self.conv1 = torch.nn.Conv2d(32, 32, kernel_size=2, stride=2, padding=2, bias=False)
        self.avg_pool1 = torch.nn.AvgPool2d(3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(32)

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        # The pooling layer separates bn1 from conv1
        x = self.avg_pool1(x)
        x = self.bn1(x)
        return x

Run the checker on the model by passing in the model as well as the model input:

def example_check_for_standalone_bn():

    model = ModelWithNonfoldableBN()
    ArchChecker.check_model_arch(model, dummy_input=torch.rand(1, 32, 32, 32))

Because the AvgPool2d layer sits between the convolution layer and the batch normalization layer, the batch normalization layer cannot be folded into the convolution, and the following log message appears:

Utils - INFO - Graph/Node: ModelWithNonfoldableBN.bn1: BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) fails check: {'_check_batch_norm_fold'}
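The rule behind this check can be sketched as follows, under a simplifying assumption: a batch normalization layer is treated as foldable only when the layer that directly feeds it is a convolution or linear layer, which can absorb the normalization's scale and shift into its weights and bias. Layers are represented as a plain list of type names in execution order. unfoldable_batchnorms is a hypothetical helper, not AIMET's implementation, and it ignores other cases a real folding pass may handle, such as a batch normalization layer that feeds a convolution.

```python
# Hypothetical sketch, not AIMET code: assumes a BatchNorm layer is foldable
# only when its direct predecessor is a convolution or linear layer.
from typing import List, Optional

FOLDABLE_PRODUCERS = {"Conv1d", "Conv2d", "Linear"}


def unfoldable_batchnorms(layers: List[str]) -> List[int]:
    """Return indices of BatchNorm layers whose direct predecessor cannot absorb them."""
    flagged: List[int] = []
    for i, layer in enumerate(layers):
        if layer.startswith("BatchNorm"):
            producer: Optional[str] = layers[i - 1] if i > 0 else None
            if producer not in FOLDABLE_PRODUCERS:
                flagged.append(i)
    return flagged


print(unfoldable_batchnorms(["Conv2d", "AvgPool2d", "BatchNorm2d"]))  # [2]
print(unfoldable_batchnorms(["Conv2d", "BatchNorm2d", "PReLU"]))      # []
```

The first call mirrors ModelWithNonfoldableBN, where the pooling layer separates bn1 from conv1; the second mirrors ModelWithPrelu, where bn1 follows conv1 directly.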