Model Preparer API

The AIMET PyTorch ModelPreparer API uses the graph transformation capability available in PyTorch 1.9+ to automate the model definition changes required by AIMET. For example, it converts functionals defined in the forward pass to torch.nn.Module type modules for activation and elementwise functions. Also, when torch.nn.Module type modules are reused, it unrolls them into independent modules.

Users are strongly encouraged to run the AIMET PyTorch ModelPreparer API first and then pass the returned model as input to all AIMET quantization features.

The AIMET PyTorch ModelPreparer API requires PyTorch version 1.9 or higher.
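If the surrounding pipeline needs to support multiple PyTorch installations, a small runtime guard can make this requirement explicit. A minimal sketch (the version parsing below is illustrative, not part of the API):

import torch

# prepare_model relies on torch.fx, so fail fast on installations
# older than PyTorch 1.9.
major, minor = (int(part) for part in torch.__version__.split(".")[:2])
assert (major, minor) >= (1, 9), (
    f"AIMET ModelPreparer requires PyTorch 1.9+, found {torch.__version__}")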

Top-level API

aimet_torch.model_preparer.prepare_model(model, concrete_args=None)

Prepare and modify the PyTorch model for AIMET features using the torch.FX symbolic tracing API.

#1 Replaces torch.nn.functional calls with torch.nn.Module instances. #2 Creates new, independent torch.nn.Module instances for reused/duplicate modules.

Example #1: Replace torch.nn.functional with torch.nn.Module:

class ModelWithFunctionalReLU(torch.nn.Module):

    def __init__(self):
        super(ModelWithFunctionalReLU, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3)

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = torch.nn.functional.relu(x, inplace=True)
        return x

model = ModelWithFunctionalReLU().eval()
model_transformed = prepare_model(model)

This function replaces the torch.nn.functional ReLU with a torch.nn.Module instance and ensures that the modified and original models are functionally identical.
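Conceptually, the prepared model is equivalent to the hand-written version below (illustrative only; torch.fx generates its own module and attribute names):

class ModelWithModuleReLU(torch.nn.Module):

    def __init__(self):
        super(ModelWithModuleReLU, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3)
        # The functional relu is now a proper torch.nn.Module.
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, *inputs):
        x = self.conv1(inputs[0])
        x = self.relu(x)
        return x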

Example #2: Create new modules for a reused/duplicate module:

class ModelWithDuplicateReLU(torch.nn.Module):

    def __init__(self):
        super(ModelWithDuplicateReLU, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, *inputs):
        x = self.relu(inputs[0])
        x = self.conv1(x)
        x = self.relu(x)
        return x

model = ModelWithDuplicateReLU().eval()
model_transformed = prepare_model(model)

This function creates a new, independent torch.nn.ReLU module for each use of the reused module and ensures that the modified and original models are functionally identical.
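Conceptually, the prepared model is equivalent to the hand-written version below, where each use of the shared ReLU gets its own instance (illustrative only; torch.fx generates its own names):

class ModelWithIndependentReLUs(torch.nn.Module):

    def __init__(self):
        super(ModelWithIndependentReLUs, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3)
        # One independent ReLU per use in the forward pass.
        self.relu = torch.nn.ReLU(inplace=True)
        self.relu_1 = torch.nn.ReLU(inplace=True)

    def forward(self, *inputs):
        x = self.relu(inputs[0])
        x = self.conv1(x)
        x = self.relu_1(x)
        return x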

Limitations of torch.fx symbolic trace API:

#1 Dynamic control flow, where conditions depend on some of the input values, is not supported. This limitation can be overcome by binding concrete values to arguments during symbolic tracing:

def f(x, flag):
    if flag: return x
    else: return x*2

torch.fx.symbolic_trace(f) # Fails!
torch.fx.symbolic_trace(f, concrete_args={'flag': True}) # Passes!
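The same workaround is available through prepare_model via its concrete_args parameter. A sketch, using a hypothetical model whose forward branches on a non-tensor flag:

class ModelWithControlFlow(torch.nn.Module):
    def __init__(self):
        super(ModelWithControlFlow, self).__init__()
        self.conv = torch.nn.Conv2d(3, 32, kernel_size=3)

    def forward(self, x, flag):
        out = self.conv(x)
        # Data-dependent branch: not traceable without concrete_args.
        return out if flag else out * 2

model = ModelWithControlFlow().eval()
# Binding flag=True removes the branch during symbolic tracing.
prepared_model = prepare_model(model, concrete_args={'flag': True})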

#2 Non-torch functions that do not use the __torch_function__ mechanism are not supported by default in symbolic tracing. If we do not want symbolic tracing to trace through them, register them with the torch.fx.wrap() API at module-scope level:

import torch
import torch.fx
from math import sqrt  # sqrt must exist at module scope for the model below
torch.fx.wrap('len')  # call the API at module-level scope.
torch.fx.wrap('sqrt') # call the API at module-level scope.

class ModelWithNonTorchFunction(torch.nn.Module):
    def __init__(self):
        super(ModelWithNonTorchFunction, self).__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=2, stride=2, padding=2, bias=False)

    def forward(self, *inputs):
        x = self.conv(inputs[0])
        return x / sqrt(len(x))

model = ModelWithNonTorchFunction().eval()
model_transformed = prepare_model(model)
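As in the later code examples, the transformation can be sanity-checked by comparing outputs of the original and prepared models (a minimal sketch; the input shape is illustrative):

input_tensor = torch.randn(1, 3, 8, 8)
# The prepared model should be functionally identical to the original.
assert torch.allclose(model(input_tensor), model_transformed(input_tensor))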
Parameters
  • model (Module) – PyTorch model to be modified

  • concrete_args (Optional[Dict[str, Any]]) – Allows you to partially specialize your function, whether to remove control flow or data structures. If the model has control flow, torch.fx will not be able to trace it unless concrete values are bound. See the torch.fx.symbolic_trace API for details.

Return type

GraphModule

Returns

Modified PyTorch model

Code Examples

Required imports


import torch
import torch.nn.functional as F
from aimet_torch.model_preparer import prepare_model

Example 1: Model with functional ReLU

We begin with the following model, which contains three functional ReLUs and one relu() tensor method inside the forward method.

class ModelWithFunctionalReLU(torch.nn.Module):
    """ Model that uses functional ReLU instead of nn.Modules. Expects input of shape (1, 3, 32, 32) """
    def __init__(self):
        super(ModelWithFunctionalReLU, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x).relu()
        return x

Run the Model Preparer API on the model.

def model_preparer_functional_example():

    # Load the model and keep in eval() mode
    model = ModelWithFunctionalReLU().eval()
    input_shape = (1, 3, 32, 32)
    input_tensor = torch.randn(*input_shape)

    # Call to prepare_model API
    prepared_model = prepare_model(model)
    print(prepared_model)

    # Compare the outputs of original and transformed model
    assert torch.allclose(model(input_tensor), prepared_model(input_tensor))

After that, we get prepared_model, which is functionally the same as the original model. Users can verify this by comparing the outputs of both models.

prepared_model should now have all the functional ReLUs and the relu() method converted to torch.nn.ReLU modules, which satisfies the model guidelines described in Model Guidelines.
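One way to confirm the conversion is to list the ReLU submodules of the prepared model (a sketch, assuming prepared_model from the snippet above; torch.fx chooses the generated module names):

# Expect one torch.nn.ReLU entry per ReLU in the forward pass.
relu_modules = [name for name, module in prepared_model.named_modules()
                if isinstance(module, torch.nn.ReLU)]
print(relu_modules)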

Example 2: Model with reused torch.nn.ReLU module

We begin with the following model, which contains a torch.nn.ReLU module that is used multiple times inside the model's forward function.

class ModelWithReusedReLU(torch.nn.Module):
    """ Model that uses single ReLU instances multiple times in the forward. Expects input of shape (1, 3, 32, 32) """
    def __init__(self):
        super(ModelWithReusedReLU, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.relu = torch.nn.ReLU()
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        return x

Run the Model Preparer API on the model.

def model_preparer_reused_example():

    # Load the model and keep in eval() mode
    model = ModelWithReusedReLU().eval()
    input_shape = (1, 3, 32, 32)
    input_tensor = torch.randn(*input_shape)

    # Call to prepare_model API
    prepared_model = prepare_model(model)
    print(prepared_model)

    # Compare the outputs of original and transformed model
    assert torch.allclose(model(input_tensor), prepared_model(input_tensor))

After that, we get prepared_model, which is functionally the same as the original model. Users can verify this by comparing the outputs of both models.

prepared_model should now have a separate, independent torch.nn.ReLU instance for each use, which satisfies the model guidelines described in Model Guidelines.
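Since every use now has its own instance, the ReLU submodules of the prepared model should be distinct objects (a sketch, assuming prepared_model from the snippet above; the count of four matches the four uses in the forward pass):

relus = [module for module in prepared_model.modules()
         if isinstance(module, torch.nn.ReLU)]
# Four uses in forward() should yield four independent instances.
assert len(relus) == 4
assert len(set(id(module) for module in relus)) == 4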

Example 3: Model with elementwise Add

We begin with the following model, which contains an elementwise add operation inside the model's forward function.

class ModelWithElementwiseAddOp(torch.nn.Module):
    def __init__(self):
        super(ModelWithElementwiseAddOp, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5, bias=False)
        self.conv2 = torch.nn.Conv2d(3, 6, 5)

    def forward(self, *inputs):
        x1 = self.conv1(inputs[0])
        x2 = self.conv2(inputs[1])
        x = x1 + x2
        return x

Run the Model Preparer API on the model.

def model_preparer_elementwise_add_example():

    # Load the model and keep in eval() mode
    model = ModelWithElementwiseAddOp().eval()
    input_shape = (1, 3, 32, 32)
    input_tensor = [torch.randn(*input_shape), torch.randn(*input_shape)]

    # Call to prepare_model API
    prepared_model = prepare_model(model)
    print(prepared_model)

    # Compare the outputs of original and transformed model
    assert torch.allclose(model(*input_tensor), prepared_model(*input_tensor))

After that, we get prepared_model, which is functionally the same as the original model. Users can verify this by comparing the outputs of both models.
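prepared_model should now have the elementwise add converted to an equivalent torch.nn.Module instance. One way to see this is to list the non-Conv2d submodules of the prepared model (a sketch, assuming prepared_model from the snippet above; the exact type and name of the generated module are AIMET implementation details):

for name, module in prepared_model.named_modules():
    if name and not isinstance(module, torch.nn.Conv2d):
        # The elementwise add now appears as a dedicated submodule.
        print(name, type(module))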