aimet_torch.model_validator

AIMET provides a model validator utility to help check whether AIMET features can be applied to a PyTorch model. The model validator currently checks for the following conditions:

  • No modules are reused

  • Operations have modules associated with them and are not defined as torch.nn.functional (excluding a set of known operations)

This section presents models that fail the validation checks, shows how to run the model validator, and shows how to fix the models so that the validation checks pass.

Example 1: Model with reused modules

We begin with the following model, which applies the same ReLU module instance (relu1) twice in the forward pass.

import torch


class ModelWithReusedNodes(torch.nn.Module):
    """ Model that reuses a relu module. Expects input of shape (1, 3, 32, 32) """

    def __init__(self):
        super(ModelWithReusedNodes, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=2, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(8)
        self.relu1 = torch.nn.ReLU(inplace=True)
        self.linear = torch.nn.Linear(2592, 10)

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.relu1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
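The in_features value of 2592 passed to torch.nn.Linear above follows from the output size of conv1 on a 32 x 32 input. As a quick check, the sketch below applies the standard Conv2d output-size formula (with dilation 1); the helper name conv2d_output_size is hypothetical and introduced only for this illustration.

```python
def conv2d_output_size(size, kernel_size, stride, padding):
    """Spatial output size of a Conv2d layer, assuming dilation of 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# conv1: kernel_size=2, stride=2, padding=2 applied to a 32 x 32 input
side = conv2d_output_size(32, kernel_size=2, stride=2, padding=2)

# conv1 has 8 output channels, so the flattened feature count is 8 * side * side
flattened = 8 * side * side
print(side, flattened)
```

The ReLU and BatchNorm2d layers do not change the tensor shape, so the flattened size seen by the linear layer is determined by conv1 alone.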

Import the model validator:

from aimet_torch.model_validator.model_validator import ModelValidator

Run the model validator by passing in the model and a model input:

def validate_example_model():

    # Load the model to validate
    model = ModelWithReusedNodes()

    # ModelValidator.validate_model returns True if the model is valid, False otherwise
    return ModelValidator.validate_model(model, model_input=torch.rand(1, 3, 32, 32))

For each validation check run on the model, a log message is printed:

Utils - INFO - Running validator check <function validate_for_reused_modules at 0x7f127685a598>

If the validation check finds any issues with the model, the log will contain information on how to resolve them:

Utils - WARNING - The following modules are used more than once in the model: ['relu1']
AIMET features are not designed to work with reused modules. Please redefine your model to use distinct modules for
each instance.

Finally, at the end of the validation, any failing checks will be logged:

Utils - INFO - The following validator checks failed:
Utils - INFO -     <function validate_for_reused_modules at 0x7f127685a598>

In this case, the validate_for_reused_modules check reports that the relu1 module is used multiple times in the model. We rewrite the model by defining a separate ReLU instance for each usage:

class ModelWithoutReusedNodes(torch.nn.Module):
    """ Model that is fixed to not reuse modules. Expects input of shape (1, 3, 32, 32) """

    def __init__(self):
        super(ModelWithoutReusedNodes, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=2, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(8)
        self.relu1 = torch.nn.ReLU(inplace=True)
        self.relu2 = torch.nn.ReLU(inplace=True)
        self.linear = torch.nn.Linear(2592, 10)

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.relu1(x)
        x = self.bn1(x)
        x = self.relu2(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

Now, after rerunning the model validator, all checks pass:

Utils - INFO - Running validator check <function validate_for_reused_modules at 0x7ff577373598>
Utils - INFO - Running validator check <function validate_for_missing_modules at 0x7ff5703eff28>
Utils - INFO - All validation checks passed.
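To build intuition for what a reused-module check looks for, the following sketch counts how many times each leaf module runs during one forward pass by using PyTorch forward hooks. This is only an illustration under stated assumptions, not AIMET's implementation of validate_for_reused_modules; the names count_module_calls and TinyReusedRelu are hypothetical.

```python
from collections import Counter

import torch


def count_module_calls(model, model_input):
    """Return a Counter mapping leaf-module names to forward-call counts."""
    calls = Counter()
    handles = []
    for name, module in model.named_modules():
        # Only hook named leaf modules (modules without children)
        if name and len(list(module.children())) == 0:
            def hook(_module, _inputs, _output, name=name):
                calls[name] += 1
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(model_input)
    finally:
        # Always remove the hooks so the model is left unchanged
        for handle in handles:
            handle.remove()
    return calls


class TinyReusedRelu(torch.nn.Module):
    """Small hypothetical model that calls the same ReLU instance twice."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 4)
        self.fc2 = torch.nn.Linear(4, 2)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        return self.relu(self.fc2(x))


calls = count_module_calls(TinyReusedRelu(), torch.rand(1, 4))
reused = sorted(name for name, count in calls.items() if count > 1)
print(reused)
```

Any name that appears in reused corresponds to a module instance invoked more than once, which is the pattern the fix above removes by giving each usage its own instance.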

Example 2: Model with functionals

We start with the following model, which calls the functional form of a linear layer (torch.nn.functional.linear) in the forward pass:

import torch
import torch.nn.functional as F


class ModelWithFunctionalLinear(torch.nn.Module):
    """ Model that uses a torch functional linear layer. Expects input of shape (1, 3, 32, 32) """

    def __init__(self):
        super(ModelWithFunctionalLinear, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=2, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(8)
        self.relu1 = torch.nn.ReLU(inplace=True)
        self.relu2 = torch.nn.ReLU(inplace=True)

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.relu1(x)
        x = self.bn1(x)
        x = self.relu2(x)
        x = x.view(x.size(0), -1)
        x = F.linear(x, torch.randn(10, 2592))
        return x

Running the model validator shows the validate_for_missing_modules check failing:

Utils - INFO - Running validator check <function validate_for_missing_modules at 0x7f9dd9bd90d0>
Utils - WARNING - Ops with missing modules: ['matmul_8']
This can be due to several reasons:
1. There is no mapping for the op in ConnectedGraph.op_type_map. Add a mapping for ConnectedGraph to recognize and
be able to map the op.
2. The op is defined as a functional in the forward function, instead of as a class module. Redefine the op as a
class module if possible. Else, check 3.
3. This op is one that cannot be defined as a class module, but has not been added to ConnectedGraph.functional_ops.
Add to continue.
Utils - INFO - The following validator checks failed:
Utils - INFO -      <function validate_for_missing_modules at 0x7f9dd9bd90d0>

The check has identified matmul_8 as an operation with a missing PyTorch module. In this case, the cause is reason #2 in the log: the layer is defined as a functional in the forward function. To resolve the issue, we rewrite the model by defining the layer as a module instead:

class ModelWithoutFunctionalLinear(torch.nn.Module):
    """ Model that is fixed to use a linear module instead of functional. Expects input of shape (1, 3, 32, 32) """

    def __init__(self):
        super(ModelWithoutFunctionalLinear, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2, stride=2, padding=2, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(8)
        self.relu1 = torch.nn.ReLU(inplace=True)
        self.relu2 = torch.nn.ReLU(inplace=True)
        self.linear = torch.nn.Linear(2592, 10)
        with torch.no_grad():
            self.linear.weight = torch.nn.Parameter(torch.randn(10, 2592))

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.relu1(x)
        x = self.bn1(x)
        x = self.relu2(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
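As a sanity check on this kind of rewrite, the sketch below compares a functional linear call with a linear module that holds the same weight. Note that torch.nn.Linear adds a bias term by default, whereas an F.linear call without a bias argument applies none, so the module is created with bias=False here to keep the two numerically identical. The variable names are illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(10, 2592)
x = torch.rand(1, 2592)

# Module form: a bias-free linear layer that holds a copy of the same weight
linear = torch.nn.Linear(2592, 10, bias=False)
with torch.no_grad():
    linear.weight.copy_(weight)

functional_out = F.linear(x, weight)  # functional form, no bias
module_out = linear(x)                # module form

matches = torch.allclose(functional_out, module_out)
print(matches)
```

With the default bias=True, the module output would additionally include a randomly initialized bias, so the two forms would generally differ unless the bias is zeroed or disabled.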

API

class aimet_torch.model_validator.model_validator.ModelValidator

ModelValidator object for validating that AIMET features can be applied to the PyTorch model.

static add_check(validation_check)

Add a validation check function to be used for validating the model. Validation check functions must take the model, model inputs, and kwargs as inputs. The validation check must output True if the model passes the check, and False otherwise.

Parameters:

  • validation_check (Callable): Validation check function for validating the model.
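A custom check might look like the following sketch. The check validate_no_dropout and the helper register_check are hypothetical names introduced for illustration. The signature (model, model_input, **kwargs) is an assumption based on the description above; whether AIMET passes model_input positionally or by keyword is not stated here.

```python
import torch


def validate_no_dropout(model, model_input, **kwargs):
    """Hypothetical check: pass only if the model contains no Dropout modules.

    Takes the model, the model input, and keyword arguments, and returns
    True if the check passes, False otherwise. model_input and kwargs are
    accepted for signature compatibility but are not used by this check.
    """
    dropout_names = [name for name, module in model.named_modules()
                     if isinstance(module, torch.nn.Dropout)]
    if dropout_names:
        print(f"Dropout modules found: {dropout_names}")
        return False
    return True


def register_check():
    """Register the check with AIMET. Requires aimet_torch to be installed."""
    from aimet_torch.model_validator.model_validator import ModelValidator
    ModelValidator.add_check(validate_no_dropout)


# The check itself can be exercised without AIMET
with_dropout = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
without_dropout = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
result_with = validate_no_dropout(with_dropout, torch.rand(1, 4))
result_without = validate_no_dropout(without_dropout, torch.rand(1, 4))
```

After register_check() has been called, the custom check would presumably run alongside the built-in checks on the next call to ModelValidator.validate_model.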

static validate_model(model, model_input, **kwargs)

Validate the PyTorch model by running all validation check functions and returning True if all pass, False otherwise. Keyword arguments can be used to pass specific arguments to particular validation checkers. Currently supported keyword arguments:

  • layers_to_exclude: List of torch.nn.Modules to be excluded in the check for missing modules. These layers and all of their sublayers will not be flagged if they do not have a corresponding PyTorch module.

Parameters:

  • model (Module): PyTorch model to validate

  • model_input (Union[Tensor, Tuple]): Dummy input to the model

Returns:

True if the PyTorch model is valid, False otherwise

Return type:

bool