Manual mixed precision¶
Context¶
To set a model in mixed precision, an AIMET user would otherwise have to find the correct quantizer(s) and change them to the new settings, which requires complex, error-prone graph traversals. The Manual Mixed Precision (MMP) Configurator hides this complexity by providing easy-to-use APIs to configure the model in mixed precision. The user can change the precision of a layer by directly specifying the layer and the intended precision, and also receives a report to analyze how the change was achieved.
MMP configurator provides the following mechanisms to change the precision in a model:
Change the precision of a leaf layer
Change the precision of a non-leaf layer (layer composed of multiple leaf layers)
Change the precision of all the layers in the model of a certain type
Change the precision of model input tensors (or only a subset of input tensors)
Change the precision of model output tensors (or only a subset of output tensors)
Workflow¶
Setup¶
import torch
from torchvision.models import mobilenet_v2
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.v2.mixed_precision import MixedPrecisionConfigurator
input_shape = (1, 3, 224, 224)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dummy_input = torch.randn(input_shape).to(device)
model = mobilenet_v2(pretrained=True).eval().to(device)
# create the sim object. Feel free to change the default settings as you wish
quant_sim = QuantizationSimModel(model,
dummy_input=dummy_input,
default_param_bw=8,
default_output_bw=8,
config_file=get_path_for_per_channel_config())
# create the MMP configurator object
mp_configurator = MixedPrecisionConfigurator(quant_sim)
MMP API options¶
MMP provides the following APIs to change the precision. The APIs can be called in any order, but in case of conflicts, the latest request takes precedence over earlier ones.
Note
The requests are processed at the granularity of the model's leaf layers.
If one of the below APIs is called multiple times for the same layer but with a different precision in each call, only the latest request is serviced.
This rule holds even when the requests come from two different APIs. For example, if the user sets a non-leaf layer (L1) to precision (P1) and a leaf layer inside L1 (L2) to precision (P2), the request is serviced by setting all the layers in L1 to P1, except layer L2, which is set to P2 (see the sketch below).
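For example, using the sim and configurator created in the setup above (the layer indices mirror the workflow examples below and are only illustrative):
# L1: a non-leaf layer; every leaf layer inside it is requested at int16
mp_configurator.set_precision(quant_sim.model.features[3].conv[1], activation='int16', param={'weight': 'int16'})
# L2: a leaf layer inside L1; this later request wins for this layer only
mp_configurator.set_precision(quant_sim.model.features[3].conv[1][0], activation='int8', param={'weight': 'int8'})
# After apply(), every leaf layer in features[3].conv[1] runs at int16,
# except features[3].conv[1][0], which runs at int8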
Set precision of a leaf layer¶
mp_configurator.set_precision(quant_sim.model.features[1].conv[0][0], activation='int16', param={'weight': 'int16'})
Set precision of a non-leaf layer¶
mp_configurator.set_precision(quant_sim.model.features[3].conv[1], activation='int16', param={'weight': 'int16'})
Set precision based on layer type¶
mp_configurator.set_precision(torch.nn.AvgPool2d, activation='int16')
Set model input precision¶
mp_configurator.set_model_input_precision(activation='int16')
Note that if a model has more than one input tensor (say the structure is [In1, In2]) but only one of them (say In2) needs to be configured to a new precision (say P1), the user can achieve this by passing activation=[None, P1] to the above API.
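For instance, with two input tensors where only the second should move to int16:
mp_configurator.set_model_input_precision(activation=[None, 'int16'])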
Set model output precision¶
mp_configurator.set_model_output_precision(activation='int16')
Note that if a model has more than one output tensor (say the structure is [Out1, Out2, Out3]) but only one of them (say Out2) needs to be configured to a new precision (say P1), the user can achieve this by passing activation=[None, P1, None] to the above API.
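For instance, with three output tensors where only the second should move to int16:
mp_configurator.set_model_output_precision(activation=[None, 'int16', None])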
Apply the profile¶
All the above set-precision calls are processed at once when the apply(...) API below is called.
mp_configurator.apply()
Note
The above call generates a report detailing how each user request was inferred, propagated to other layers, and eventually realized.
API¶
Top-level API for Manual mixed precision
- class aimet_torch.v2.mixed_precision.MixedPrecisionConfigurator(sim)[source]¶
Mixed Precision Configurator helps set up a mixed precision profile in the QuantSim object. The user is expected to follow the below steps to set the sim in mixed precision (a minimal sketch follows the parameter list below):
Create the QuantSim object
Create the MixedPrecisionConfigurator object by passing in the QuantSim object
Make a series of set_precision/set_model_input_precision/set_model_output_precision calls
Call the apply() method, optionally passing in a log file and the strict flag
Run compute_encodings on the above QuantSim object
Export the encodings/onnx artifacts
- Parameters:
sim (QuantizationSimModel) – QuantSim object
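A minimal sketch of the above steps, assuming the MobileNetV2 setup from the workflow section (forward_pass_fn is a placeholder calibration callback and the export path is illustrative, not part of the API):
mp_configurator = MixedPrecisionConfigurator(quant_sim)                # step 2
mp_configurator.set_precision(torch.nn.AvgPool2d, activation='int16')  # step 3
mp_configurator.apply()                                                # step 4
quant_sim.compute_encodings(forward_pass_fn, None)                     # step 5: calibrate with representative data
quant_sim.export('./output', 'mobilenet_v2_mmp', dummy_input=dummy_input.cpu())  # step 6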
- set_precision(arg, activation, param=None)[source]¶
- Parameters:
arg (Union[torch.nn.Module, Type[torch.nn.Module]]) – Module instance (torch.nn.Module) or the type of the module
activation (Union[List[Literal['int16', 'int8', 'int4', 'fp16']], Literal['int16', 'int8', 'int4', 'fp16']]) – A string representing the activation dtype of the module input(s)
param (Optional[Dict[str, Literal['int16', 'int8', 'int4', 'fp16']]]) – Dict with the name of the param as key and its dtype as value
If ‘module’ is a leaf module (i.e. not composed of other torch.nn.Module objects), the specified settings are applied to that module.
If ‘module’ is a non-leaf module (composed of other torch.nn.Module objects), the specified settings are applied to all the leaf modules inside it.
If ‘module’ is a module type, all the modules in the model of that type are set to the specified activation and param settings.
If the same ‘module’ is specified through multiple set_precision(…) calls, only the latest one is applied.
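A short sketch mirroring the workflow calls above:
# leaf module: the setting applies to this module alone
mp_configurator.set_precision(quant_sim.model.features[1].conv[0][0], activation='int16', param={'weight': 'int16'})
# module type: the setting applies to every torch.nn.AvgPool2d in the model
mp_configurator.set_precision(torch.nn.AvgPool2d, activation='int16')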
- set_model_input_precision(activation)[source]¶
Activation precision to be set on the model inputs
- Parameters:
activation (Union[List[Optional[Literal['int16', 'int8', 'int4', 'fp16']]], Tuple[Optional[Literal['int16', 'int8', 'int4', 'fp16']]], Literal['int16', 'int8', 'int4', 'fp16']]) – Activation dtypes for the inputs of the model
- set_model_output_precision(activation)[source]¶
Activation precision to be set on the model outputs
- Parameters:
activation (Union[List[Optional[Literal['int16', 'int8', 'int4', 'fp16']]], Tuple[Optional[Literal['int16', 'int8', 'int4', 'fp16']]], Literal['int16', 'int8', 'int4', 'fp16']]) – Activation dtypes for the outputs of the model
- apply(log_file='./mmp_log.txt', strict=True)[source]¶
Apply the mixed precision settings specified through the set_precision/set_model_input_precision/set_model_output_precision calls to the QuantSim object
- Parameters:
log_file (Union[IO, str, None]) – Log file to store the logs; can be either a string representing the path or an IO object to write the logs into
strict (bool) – Boolean flag indicating whether to fail (strict=True) on incorrect/conflicting inputs made by the user or to take a best-effort approach (strict=False) to realize the MP settings
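For example, to write the report to a custom path (shown path is illustrative) and take the best-effort path instead of failing on conflicting requests:
mp_configurator.apply(log_file='./my_mmp_log.txt', strict=False)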