Manual mixed precision¶
Context¶
To use mixed precision effectively, you must identify which quantizers to run at higher precision, which normally requires complex, error-prone graph traversals. The AIMET manual mixed precision (MMP) configurator hides this complexity behind easy-to-use APIs for configuring the model in mixed precision. You change the precision of a layer by directly specifying the layer and the intended precision. The MMP configurator also analyzes and reports how the mixed precision was achieved.
MMP configurator enables you to change the precision of the following within a model:
A leaf layer
A non-leaf layer (a layer composed of multiple leaf layers)
All layers of a certain type
Model input tensors or a subset of input tensors
Model output tensors or a subset of output tensors
Workflow¶
Prerequisites¶
Manual mixed precision is supported only on PyTorch models.
Setup¶
import torch
from torchvision.models import mobilenet_v2
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.v2.mixed_precision import MixedPrecisionConfigurator
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = mobilenet_v2(pretrained=True).eval().to(device)

input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape).to(device)

# Create the sim object. Feel free to change the default settings as you wish.
quant_sim = QuantizationSimModel(model,
                                 dummy_input=dummy_input,
                                 default_param_bw=8,
                                 default_output_bw=8,
                                 config_file=get_path_for_per_channel_config())
# create the MMP configurator object
mp_configurator = MixedPrecisionConfigurator(quant_sim)
Step 1: Applying MMP API options¶
Note
All requests are resolved at the level of the model's leaf layers.
MMP provides the following APIs to change a layer's precision. The APIs can be called in any order; when requests conflict, the latest request overrides the earlier one. For example:
If one of the following APIs is called multiple times with different precisions for the same layer, only the latest call is serviced.
The last request takes precedence even when the requests come from two different APIs. For example, if you set a non-leaf layer L1 to precision P1 and then set a leaf layer L2, inside L1, to precision P2, all layers in L1 are set to P1 except L2, which is set to P2 (as sketched below).
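For instance, a minimal sketch of this override behavior (module paths assume the MobileNetV2 model from the setup above):
# The non-leaf call sets every leaf layer inside features[3].conv[1] to int8 ...
mp_configurator.set_precision(quant_sim.model.features[3].conv[1], activation='int8', param={'weight': 'int8'})
# ... and the later leaf call overrides just the first conv inside it to int16
mp_configurator.set_precision(quant_sim.model.features[3].conv[1][0], activation='int16', param={'weight': 'int16'})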
Set precision of a leaf layer¶
mp_configurator.set_precision(quant_sim.model.features[1].conv[0][0], activation='int16', param={'weight': 'int16'})
Set precision of a non-leaf layer¶
mp_configurator.set_precision(quant_sim.model.features[3].conv[1], activation='int16', param={'weight': 'int16'})
Set precision based on layer type¶
mp_configurator.set_precision(torch.nn.AvgPool2d, activation='int16')
Set model input precision¶
mp_configurator.set_model_input_precision(activation='int16')
If a model has more than one input tensor (for example, the structure is [In1, In2]), you can set just one of them (say In2) to a new precision (say P1) by passing activation=[None, P1] to the above API.
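For instance, assuming a hypothetical model with two input tensors:
# Leave the first input at its default precision; set only the second to int16
mp_configurator.set_model_input_precision(activation=[None, 'int16'])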
Set model output precision¶
mp_configurator.set_model_output_precision(activation='int16')
If a model has more than one output tensor (for example, the structure is [Out1, Out2, Out3]), you can set just one of them (say Out2) to a new precision (say P1) by passing activation=[None, P1, None] to the above API.
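For instance, assuming a hypothetical model with three output tensors:
# Only the second output is set to int16
mp_configurator.set_model_output_precision(activation=[None, 'int16', None])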
Step 2: Applying the profile¶
All of the set precision family of calls from step 1 are processed together when the following apply(...) API is called.
mp_configurator.apply()
The apply call generates a report detailing how each request was inferred, propagated to other layers, and eventually realized.
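After apply(), the usual QuantSim workflow continues: compute encodings with a calibration callback and export the artifacts. A minimal sketch, where the calibration callback body and the output path are placeholders to replace with your own:
def calibration_fn(model, _):
    # Placeholder calibration: run representative data through the model
    with torch.no_grad():
        model(dummy_input)

quant_sim.compute_encodings(calibration_fn, None)
quant_sim.export('/tmp/mmp_output', 'mobilenet_v2_mmp', dummy_input=dummy_input.cpu())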
API¶
Top-level API for Manual mixed precision
- class aimet_torch.v2.mixed_precision.MixedPrecisionConfigurator(sim)[source]¶
Mixed Precision Configurator helps set up a mixed precision profile in the QuantSim object. Follow the steps below to run the sim in mixed precision:
Create QuantSim object
Create the MixedPrecisionConfigurator object by passing in the QuantSim object
Make a series of set_precision/set_model_input_precision/set_model_output_precision calls
Call the apply() method, optionally passing in a log file and strict flag
Run compute_encodings on the above QuantSim object
Export the encodings/onnx artifacts
- Parameters:
sim (QuantizationSimModel) – QuantSim object
- set_precision(arg, activation, param=None)[source]¶
- Parameters:
arg (Union[Module, Type[Module]]) – Module can be of type torch.nn.Module or the type of the module.
activation (Union[List[Literal['int16', 'int8', 'int4', 'fp16']], Literal['int16', 'int8', 'int4', 'fp16']]) – A string representing the activation dtype of the module input(s).
param (Optional[Dict[str, Literal['int16', 'int8', 'int4', 'fp16']]]) – Dict with the name of the param as key and its dtype as value.
If ‘module’ is a leaf module (that is, it is not composed of other torch.nn.Module instances), the specified settings are applied to that module.
If ‘module’ is a non-leaf module (composed of other torch.nn.Module instances), the specified settings are applied to all leaf modules within it.
If ‘module’ is a module type, all modules in the model of that type are set to the specified activation and param settings.
If the same ‘module’ is specified through multiple set_precision(…) calls, the latest call is applied.
Example (a minimal sketch; module paths assume the MobileNetV2 setup above):
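# Set one leaf convolution to int16 activations and weights
mp_configurator.set_precision(quant_sim.model.features[1].conv[0][0], activation='int16', param={'weight': 'int16'})
# Set every ReLU6 layer in the model to int8 activations
mp_configurator.set_precision(torch.nn.ReLU6, activation='int8')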
- set_model_input_precision(activation)[source]¶
Activation precision to set on the model inputs
- Parameters:
activation (Union[List[Optional[Literal['int16', 'int8', 'int4', 'fp16']]], Tuple[Optional[Literal['int16', 'int8', 'int4', 'fp16']]], Literal['int16', 'int8', 'int4', 'fp16']]) – Activation dtypes for inputs of the model
- set_model_output_precision(activation)[source]¶
Activation precision to set on the model outputs
- Parameters:
activation (Union[List[Optional[Literal['int16', 'int8', 'int4', 'fp16']]], Tuple[Optional[Literal['int16', 'int8', 'int4', 'fp16']]], Literal['int16', 'int8', 'int4', 'fp16']]) – Activation dtypes for outputs of the model
- apply(log_file='./mmp_log.txt', strict=True)[source]¶
Apply the MP settings specified through the set_precision/set_model_input_precision/set_model_output_precision calls to the QuantSim object.
- Parameters:
log_file (Union[IO, str, None]) – Log file to store the logs. log_file can either be a string representing the path or an IO object to write the logs into.
strict (bool) – Boolean flag indicating whether to fail (strict=True) on incorrect or conflicting inputs, or (strict=False) to take a best-effort approach to realizing the MP settings.
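For instance, a sketch of a best-effort apply with a custom log file (the path is illustrative):
# Conflicting or unrealizable requests are logged rather than raising an error
mp_configurator.apply(log_file='./mmp_report.txt', strict=False)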