Model Preparer API
The AIMET PyTorch ModelPreparer API uses the graph transformation feature available in PyTorch 1.9+ and automates the model definition changes required by AIMET. For example, it converts functionals defined in the forward pass into torch.nn.Module type modules for activation and elementwise functions. It also unrolls reused torch.nn.Module instances into independent modules.
Users are strongly encouraged to run the AIMET PyTorch ModelPreparer API first and then pass the returned model as input to all AIMET quantization features.
The AIMET PyTorch ModelPreparer API requires PyTorch version 1.9 or later.
Top-level API
- aimet_torch.model_preparer.prepare_model(model, modules_to_exclude=None, module_classes_to_exclude=None, concrete_args=None)[source]
Prepare and modify the PyTorch model for AIMET features using the torch.fx symbolic tracing API.
Replaces torch.nn.functional calls with modules of type torch.nn.Module.
Creates new independent torch.nn.Module instances for reused/duplicate modules.
- Parameters:
  - model (Module) – PyTorch model to be modified.
  - modules_to_exclude (Optional[List[Module]]) – List of modules to exclude when tracing.
  - module_classes_to_exclude (Optional[List[Callable]]) – List of module classes to exclude when tracing.
  - concrete_args (Optional[Dict[str, Any]]) – Allows you to partially specialize your function, whether it’s to remove control flow or data structures. If the model has control flow, torch.fx won’t be able to trace the model. See the torch.fx.symbolic_trace API for details, and the sketch after this list.
- Return type:
GraphModule
- Returns:
Modified pytorch Model
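The concrete_args parameter is useful when the forward function contains control flow on non-tensor arguments. Below is a minimal, hypothetical sketch (the ModelWithFlag class and its use_residual argument are made up for illustration; prepare_model is assumed to forward concrete_args to torch.fx.symbolic_trace, as described above):

import torch
from aimet_torch.model_preparer import prepare_model

class ModelWithFlag(torch.nn.Module):
    """ Hypothetical model whose forward contains control flow on a non-tensor argument """
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x, use_residual=True):
        out = self.conv(x)
        if use_residual:  # control flow on a non-tensor argument
            out = out + x
        return out

model = ModelWithFlag().eval()
# Specialize the trace for use_residual=True; the untaken branch is dropped from the graph
prepared_model = prepare_model(model, concrete_args={'use_residual': True})
print(prepared_model)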
Code Examples
Required imports
import torch
import torch.nn.functional as F
from aimet_torch.model_preparer import prepare_model
Example 1: Model with Functional relu
We begin with the following model, which contains two functional relus and one relu tensor method call inside the forward method.
class ModelWithFunctionalReLU(torch.nn.Module):
""" Model that uses functional ReLU instead of nn.Modules. Expects input of shape (1, 3, 32, 32) """
def __init__(self):
super(ModelWithFunctionalReLU, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 6, 5)
self.conv2 = torch.nn.Conv2d(6, 16, 5)
self.fc1 = torch.nn.Linear(9216, 128)
self.fc2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = self.fc2(x).relu()
return x
Run the Model Preparer API on the model.
def model_preparer_functional_example():
# Load the model and keep in eval() mode
model = ModelWithFunctionalReLU().eval()
input_shape = (1, 3, 32, 32)
input_tensor = torch.randn(*input_shape)
# Call to prepare_model API
prepared_model = prepare_model(model)
print(prepared_model)
# Compare the outputs of original and transformed model
assert torch.allclose(model(input_tensor), prepared_model(input_tensor))
After that, we get prepared_model, which is functionally the same as the original model. Users can verify this by comparing the outputs of both models.
prepared_model should now have all three functional relus converted to torch.nn.ReLU modules, which satisfies the model guidelines described in Model Guidelines.
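As an illustrative sanity check (not part of the original example), one can count the torch.nn.ReLU submodules of the prepared model; based on the description above, there should be one per converted relu:

relu_modules = [m for m in prepared_model.modules() if isinstance(m, torch.nn.ReLU)]
assert len(relu_modules) == 3  # one nn.ReLU per converted functional/tensor relu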
Example 2: Model with reused torch.nn.ReLU module
We begin with the following model, which contains a single torch.nn.ReLU module that is used multiple times inside the model's forward function.
class ModelWithReusedReLU(torch.nn.Module):
""" Model that uses single ReLU instances multiple times in the forward. Expects input of shape (1, 3, 32, 32) """
def __init__(self):
super(ModelWithReusedReLU, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 6, 5)
self.conv2 = torch.nn.Conv2d(6, 16, 5)
self.relu = torch.nn.ReLU()
self.fc1 = torch.nn.Linear(9216, 128)
self.fc2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
return x
Run the Model Preparer API on the model.
def model_preparer_reused_example():
# Load the model and keep in eval() mode
model = ModelWithReusedReLU().eval()
input_shape = (1, 3, 32, 32)
input_tensor = torch.randn(*input_shape)
# Call to prepare_model API
prepared_model = prepare_model(model)
print(prepared_model)
# Compare the outputs of original and transformed model
assert torch.allclose(model(input_tensor), prepared_model(input_tensor))
After that, we get prepared_model, which is functionally the same as the original model. Users can verify this by comparing the outputs of both models.
prepared_model should now have separate, independent torch.nn.Module instances, which satisfies the model guidelines described in Model Guidelines.
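As an illustrative sanity check (an assumption based on the description above, not part of the original example), the prepared model should contain a distinct torch.nn.ReLU instance for each of the four uses in the forward function:

relu_modules = [m for m in prepared_model.modules() if isinstance(m, torch.nn.ReLU)]
# The original model registered a single ReLU reused four times; modules() yields
# each distinct instance once, so the prepared model should report at least four.
assert len(relu_modules) >= 4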
Example 3: Model with elementwise Add
We begin with the following model, which contains an elementwise Add operation inside the model's forward function.
class ModelWithElementwiseAddOp(torch.nn.Module):
def __init__(self):
super(ModelWithElementwiseAddOp, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 6, 5, bias=False)
self.conv2 = torch.nn.Conv2d(3, 6, 5)
def forward(self, *inputs):
x1 = self.conv1(inputs[0])
x2 = self.conv2(inputs[1])
x = x1 + x2
return x
Run the Model Preparer API on the model.
def model_preparer_elementwise_add_example():
# Load the model and keep in eval() mode
model = ModelWithElementwiseAddOp().eval()
input_shape = (1, 3, 32, 32)
input_tensor = [torch.randn(*input_shape), torch.randn(*input_shape)]
# Call to prepare_model API
prepared_model = prepare_model(model)
print(prepared_model)
# Compare the outputs of original and transformed model
assert torch.allclose(model(*input_tensor), prepared_model(*input_tensor))
After that, we get prepared_model, which is functionally the same as the original model. Users can verify this by comparing the outputs of both models.
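As an illustrative check (assumption: the elementwise addition is expected to be mapped to one of AIMET's elementwise op modules, e.g. aimet_torch.elementwise_ops.Add), the x1 + x2 in the forward function should now appear as a module in the prepared model:

from aimet_torch import elementwise_ops
add_modules = [m for m in prepared_model.modules() if isinstance(m, elementwise_ops.Add)]
assert len(add_modules) == 1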
Limitations of torch.fx symbolic trace API
Limitations of torch.fx symbolic trace: https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing
1. Dynamic control flow is not supported by torch.fx: loops or if-else statements whose condition may depend on input values cannot be traced. torch.fx can only trace one execution path, and all the branches that weren't traced are ignored. For example, the following simple function, when traced, will fail with a TraceError saying that 'symbolically traced variables cannot be used as inputs to control flow':
def f(x, flag):
if flag:
return x
else:
return x*2
torch.fx.symbolic_trace(f) # Fails!
torch.fx.symbolic_trace(f, concrete_args={'flag': True})  # Works: 'flag' is specialized to True
Workarounds for this problem:
Many cases of dynamic control flow can simply be converted to static control flow, which is supported by torch.fx symbolic tracing. Static control flow is where loops or if-else statements have conditions that cannot change across different model forward passes. Such cases can be traced by removing data dependencies on input values, i.e. by passing concrete values via 'concrete_args' to specialize your forward functions, as in the last line of the snippet above.
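For instance, a loop whose trip count is a Python constant (not derived from tensor values) is static control flow and traces without issue. A minimal sketch, using hypothetical class and attribute names:

import torch
import torch.fx

class StaticLoopModel(torch.nn.Module):
    """ Hypothetical model with a loop whose bound is fixed at construction time """
    def __init__(self, num_blocks=3):
        super().__init__()
        self.num_blocks = num_blocks
        self.convs = torch.nn.ModuleList(
            [torch.nn.Conv2d(3, 3, 3, padding=1) for _ in range(num_blocks)])

    def forward(self, x):
        # Static loop: the trip count never depends on tensor data,
        # so symbolic tracing simply unrolls it into num_blocks conv calls
        for i in range(self.num_blocks):
            x = self.convs[i](x)
        return x

traced = torch.fx.symbolic_trace(StaticLoopModel().eval())
print(traced.graph)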
For truly dynamic control flow, users should wrap such pieces of code at model-level scope using the torch.fx.wrap API, which will preserve them as nodes instead of tracing through them:
@torch.fx.wrap
def custom_function_not_to_be_traced(x, y):
    """ Function which we do not want to be traced, when traced using torch FX API, call to this function will
    be inserted as call_function, and won't be traced through """
    for i in range(2):
        x += x
        y += y
    return x * x + y * y
2. Non-torch functions which do not use the __torch_function__ mechanism are not supported by default in symbolic tracing.
Workaround for this problem:
If we do not want to capture them in symbolic tracing, the torch.fx.wrap() API should be used at module-level scope:
import torch
import torch.fx
from math import sqrt
from aimet_torch.model_preparer import prepare_model

torch.fx.wrap('len')   # call the API at module-level scope.
torch.fx.wrap('sqrt')  # call the API at module-level scope.

class ModelWithNonTorchFunction(torch.nn.Module):
    def __init__(self):
        super(ModelWithNonTorchFunction, self).__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=2, stride=2, padding=2, bias=False)

    def forward(self, *inputs):
        x = self.conv(inputs[0])
        return x / sqrt(len(x))

model = ModelWithNonTorchFunction().eval()
model_transformed = prepare_model(model)
3. Customizing the behavior of tracing by overriding the Tracer.is_leaf_module() API
In symbolic tracing, leaf modules appear as nodes rather than being traced through, and all standard torch.nn modules are the default set of leaf modules. This behavior can be changed by overriding the Tracer.is_leaf_module() API.
The AIMET Model Preparer API exposes the 'modules_to_exclude' argument, which can be used to prevent certain module(s) from being traced through. For example, let's examine the following code snippet, where we don't want to trace CustomModule further:
class CustomModule(torch.nn.Module):
@staticmethod
def forward(x):
return x * torch.nn.functional.softplus(x).sigmoid()
class CustomModel(torch.nn.Module):
def __init__(self):
super(CustomModel, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2)
self.custom = CustomModule()
def forward(self, inputs):
x = self.conv1(inputs)
x = self.custom(x)
return x
model = CustomModel().eval()
prepared_model = prepare_model(model, modules_to_exclude=[model.custom])
print(prepared_model)
In this example, 'self.custom' is preserved as a node and is not traced through.
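Alternatively, exclusion can be done by class via the module_classes_to_exclude argument described above. A minimal sketch (assuming this excludes every instance of the given class from tracing):

prepared_model = prepare_model(model, module_classes_to_exclude=[CustomModule])
print(prepared_model)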
4. Tensor constructors are not traceable
For example, let's examine the following code snippet:
def f(x):
return torch.arange(x.shape[0], device=x.device)
torch.fx.symbolic_trace(f)
Error traceback:
return torch.arange(x.shape[0], device=x.device)
TypeError: arange() received an invalid combination of arguments - got (Proxy, device=Attribute), but expected one of:
* (Number end, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
* (Number start, Number end, Number step, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
The above snippet is problematic because the arguments to torch.arange() are input dependent. Workarounds for this problem:
Use deterministic constructors (hard-coded values) so that the value they produce is embedded as a constant in the graph:
def f(x):
    return torch.arange(10, device=torch.device('cpu'))
Or use the torch.fx.wrap API to wrap torch.arange() and call that instead:
@torch.fx.wrap
def do_not_trace_me(x):
    return torch.arange(x.shape[0], device=x.device)

def f(x):
    return do_not_trace_me(x)

torch.fx.symbolic_trace(f)