AIMET PyTorch Quant Analyzer API

AIMET PyTorch Quant Analyzer analyzes a PyTorch model and points out the layers in the model that are sensitive to quantization. It checks the model's sensitivity to weight and activation quantization and performs per-layer sensitivity and MSE analysis. It also exports per-layer encoding min and max ranges and a statistics histogram for every layer.

Top-level API

Note

This module is also available in the experimental aimet_torch.v2 namespace with the same top-level API. To learn more about the differences between aimet_torch and aimet_torch.v2, please visit the QuantSim v2 Overview.

class aimet_torch.quant_analyzer.QuantAnalyzer(model, dummy_input, forward_pass_callback, eval_callback, modules_to_ignore=None)

The QuantAnalyzer tool provides:

  1. model sensitivity to weight and activation quantization

  2. per-layer sensitivity analysis

  3. per-layer encodings (min-max range)

  4. per-layer PDF analysis

  5. per-layer MSE analysis

Parameters:
  • model (Module) – FP32 model to analyze for quantization.

  • dummy_input (Union[Tensor, Tuple]) – Dummy input to model.

  • forward_pass_callback (CallbackFunc) – A callback function for model calibration that simply runs forward passes on the model to compute encodings (delta/offset). This callback function should use representative data and should be a subset of the entire train/validation dataset (~1000 images/samples).

  • eval_callback (CallbackFunc) – A callback function for model evaluation that determines model performance. This callback function is expected to return a scalar value representing the model performance evaluated against the entire test/evaluation dataset.

  • modules_to_ignore (Optional[List[Module]]) – Excludes certain modules from being analyzed.


QuantAnalyzer.enable_per_layer_mse_loss(unlabeled_dataset_iterable, num_batches)

Enable per layer MSE loss analysis.

Parameters:
  • unlabeled_dataset_iterable (Union[DataLoader, Collection]) – A collection (i.e. iterable with __len__) that iterates over an unlabeled dataset. The values yielded by this iterable are expected to be able to be passed directly to the model.

  • num_batches (int) – Number of batches. Approximately 256 samples/images are recommended; for example, if the batch size of the data loader is 64, then 4 batches yield 256 samples/images.
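The recommended number of batches follows from the batch size. A minimal sketch of that arithmetic, using a hypothetical helper that is not part of AIMET; it rounds up so that at least the target number of samples is covered:

```python
import math


def num_batches_for(target_samples: int, batch_size: int) -> int:
    """Hypothetical helper: batches needed to cover at least `target_samples` samples."""
    if target_samples <= 0 or batch_size <= 0:
        raise ValueError("target_samples and batch_size must be positive")
    # Round up so that a batch size that does not divide the target still reaches it.
    return math.ceil(target_samples / batch_size)


print(num_batches_for(256, 64))  # 4 batches -> 256 samples
print(num_batches_for(256, 48))  # 6 batches -> 288 samples
```

Rounding up is a design choice; rounding down would instead keep the sample count at or below the target.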


QuantAnalyzer.analyze(quant_scheme=QuantScheme.post_training_tf_enhanced, default_param_bw=8, default_output_bw=8, config_file=None, results_dir='./tmp/')

Analyze the model for quantization and point out its sensitive parts (hotspots) by performing:
  1. model sensitivity analysis to quantization,

  2. per-layer sensitivity analysis by enabling and disabling quant wrappers,

  3. export of per-layer encoding min-max ranges,

  4. export of per-layer statistics histograms (PDF) when the quant scheme is TF-Enhanced,

  5. per-layer MSE analysis.

Parameters:
  • quant_scheme (QuantScheme) – Quantization scheme. Supported values are QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced.

  • default_param_bw (int) – Default bitwidth (4-31) to use for quantizing layer parameters.

  • default_output_bw (int) – Default bitwidth (4-31) to use for quantizing layer inputs and outputs.

  • config_file (Optional[str]) – Path to configuration file for model quantizers.

  • results_dir (str) – Directory to save the results.
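The default_param_bw and default_output_bw settings determine how finely each quantizer's range is divided into integer steps. To illustrate what an encoding (delta/offset) is, here is a simplified sketch of asymmetric min-max encoding in plain Python. It assumes the common convention delta = (max - min) / (2^bw - 1) and offset = round(min / delta), and it omits refinements a real quantizer applies (TF-Enhanced, for instance, searches for a better range rather than using the raw min and max). The function names are hypothetical.

```python
from typing import Tuple


def compute_encoding(x_min: float, x_max: float, bitwidth: int) -> Tuple[float, int]:
    """Simplified asymmetric min-max encoding: returns (delta, offset)."""
    if not 4 <= bitwidth <= 31:
        raise ValueError("bitwidth must be in [4, 31]")
    # Extend the range to contain zero so that 0.0 maps exactly onto the grid.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    num_steps = 2 ** bitwidth - 1
    delta = (x_max - x_min) / num_steps if x_max > x_min else 1.0
    offset = round(x_min / delta)
    return delta, offset


def quantize_dequantize(x: float, delta: float, offset: int, bitwidth: int) -> float:
    """Map x onto the integer grid and back to float (fake quantization)."""
    q = round(x / delta) - offset
    q = max(0, min(2 ** bitwidth - 1, q))  # clamp to the representable integer range
    return (q + offset) * delta


delta, offset = compute_encoding(-0.5, 1.5, bitwidth=8)
print(delta, offset, quantize_dequantize(0.3, delta, offset, bitwidth=8))
```

Halving the bitwidth roughly squares the step size ratio, which is why lower bitwidths expose sensitive layers more clearly.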


class aimet_common.utils.CallbackFunc(func, func_callback_args=None)

Class encapsulating a callback function and its argument(s).

Parameters:
  • func (Callable) – Callable function.

  • func_callback_args – Arguments passed to the callable function as-is.
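The callbacks in this document take the model as their first argument and an optional second argument. A minimal sketch of that contract, using a hypothetical stand-in class rather than the real aimet_common.utils.CallbackFunc, and assuming the stored arguments are passed to func unchanged as its second positional argument (which matches the (model, _) signature of the callbacks in the Code Examples):

```python
from typing import Any, Callable, Optional


class CallbackFuncSketch:
    """Hypothetical stand-in mirroring CallbackFunc(func, func_callback_args=None)."""

    def __init__(self, func: Callable, func_callback_args: Optional[Any] = None):
        self.func = func
        self.args = func_callback_args


def run_callback(callback: CallbackFuncSketch, model: Any) -> Any:
    # Assumption: the model goes first and the stored arguments are passed as-is.
    return callback.func(model, callback.args)


def my_eval(model: Any, num_samples: int) -> str:
    return f"evaluated {model} on {num_samples} samples"


cb = CallbackFuncSketch(my_eval, func_callback_args=500)
print(run_callback(cb, "resnet18"))  # evaluated resnet18 on 500 samples
```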


Run specific utility

We can avoid running all the utilities that QuantAnalyzer offers and run only those of interest. To do so, we need a QuantizationSimModel object; we then call the desired QuantAnalyzer utility and pass that object to it.

QuantAnalyzer.check_model_sensitivity_to_quantization(sim)

Perform the sensitivity analysis to weight and activation quantization individually.

Parameters:

sim (QuantizationSimModel) – Quantsim model.

Return type:

Tuple[float, float, float]

Returns:

FP32 eval score, weight-quantized eval score, act-quantized eval score.
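The returned tuple can be compared against the FP32 score to see whether weight or activation quantization causes more degradation. A small sketch, assuming higher eval scores are better (as with accuracy); the function name and the example scores are hypothetical:

```python
from typing import Tuple


def summarize_sensitivity(scores: Tuple[float, float, float]) -> Tuple[float, float, str]:
    """Return (weight_drop, act_drop, culprit) from (fp32, weight-quantized, act-quantized)."""
    fp32, weight_quantized, act_quantized = scores
    weight_drop = fp32 - weight_quantized
    act_drop = fp32 - act_quantized
    if weight_drop > act_drop:
        culprit = "weight"
    elif act_drop > weight_drop:
        culprit = "activation"
    else:
        culprit = "equal"
    return weight_drop, act_drop, culprit


# Made-up scores for illustration only.
w_drop, a_drop, culprit = summarize_sensitivity((0.80, 0.78, 0.70))
print(f"weight drop={w_drop:.3f}, activation drop={a_drop:.3f}, larger source: {culprit}")
```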


QuantAnalyzer.perform_per_layer_analysis_by_enabling_quant_wrappers(sim, results_dir)

NOTE: Option 1

  1. All quant wrappers’ parameter and activation quantizers are disabled.

  2. For every quant wrapper, in order of occurrence:
    • Its parameter and activation quantizers are enabled as per the JSON config file and set to the specified bit-width.

    • The eval score is measured and recorded on a subset of the dataset.

    • The quantizers enabled in this step are disabled again.

  3. A dictionary containing each quant wrapper name and its corresponding eval score is returned.

Parameters:
  • sim (QuantizationSimModel) – Quantsim model.

  • results_dir (str) – Directory to save the results.

Return type:

Dict

Returns:

Layer-wise eval score dictionary: dict[layer_name] = eval_score.
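The Option 1 procedure can be sketched in plain Python. The sketch below uses a hypothetical FakeQuantWrapper with a single on/off flag and a toy eval function in place of a real QuantizationSimModel; it only illustrates the control flow, not AIMET's implementation:

```python
from typing import Callable, Dict, List


class FakeQuantWrapper:
    """Hypothetical stand-in for a quant wrapper with a single enable flag."""

    def __init__(self, name: str):
        self.name = name
        self.enabled = True


def per_layer_by_enabling(wrappers: List[FakeQuantWrapper],
                          eval_fn: Callable[[List[FakeQuantWrapper]], float]) -> Dict[str, float]:
    # Step 1: disable every quantizer.
    for w in wrappers:
        w.enabled = False
    scores: Dict[str, float] = {}
    # Step 2: enable one wrapper at a time, in order of occurrence, and evaluate.
    for w in wrappers:
        w.enabled = True
        scores[w.name] = eval_fn(wrappers)
        w.enabled = False  # return to the all-disabled state
    # Step 3: wrapper name -> eval score.
    return scores


# Toy eval: each enabled wrapper costs a fixed, made-up amount of accuracy.
penalty = {"conv1": 0.01, "fc": 0.10}


def toy_eval(ws: List[FakeQuantWrapper]) -> float:
    return 1.0 - sum(penalty[w.name] for w in ws if w.enabled)


layers = [FakeQuantWrapper("conv1"), FakeQuantWrapper("fc")]
scores = per_layer_by_enabling(layers, toy_eval)
print(scores, "most sensitive:", min(scores, key=scores.get))
```

With this option, the layer whose score drops the most when it alone is quantized is a likely hotspot.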


QuantAnalyzer.perform_per_layer_analysis_by_disabling_quant_wrappers(sim, results_dir)

NOTE: Option 2

  1. All quant wrappers’ parameter and activation quantizers are enabled as per the JSON config file and set to the specified bit-width.

  2. For every quant wrapper, in order of occurrence:
    • Its parameter and activation quantizers are disabled.

    • The eval score is measured and recorded on a subset of the dataset.

    • The quantizers disabled in this step are enabled again.

  3. A dictionary containing each quant wrapper name and its corresponding eval score is returned.

Parameters:
  • sim (QuantizationSimModel) – Quantsim model.

  • results_dir (str) – Directory to save the results.

Return type:

Dict

Returns:

Layer-wise eval score dictionary: dict[layer_name] = eval_score.
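With Option 2, a layer is suspect when disabling its quantizers recovers a large part of the accuracy lost by full quantization. A hypothetical ranking helper, assuming you have also measured the eval score of the fully quantized model (for example, by calling your eval callback on the quantsim model) and that higher scores are better:

```python
from typing import Dict, List, Tuple


def rank_sensitive_layers(scores_when_disabled: Dict[str, float],
                          fully_quantized_score: float,
                          top_k: int = 3) -> List[Tuple[str, float]]:
    """Rank layers by the score recovered when only their quantizers are disabled."""
    gains = {name: score - fully_quantized_score
             for name, score in scores_when_disabled.items()}
    # Largest recovery first: these layers hurt the most when quantized.
    return sorted(gains.items(), key=lambda item: item[1], reverse=True)[:top_k]


# Made-up scores for illustration only.
ranked = rank_sensitive_layers({"conv1": 0.71, "fc": 0.79, "layer1.0.conv1": 0.70},
                               fully_quantized_score=0.70, top_k=2)
print(ranked)
```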


QuantAnalyzer.export_per_layer_encoding_min_max_range(sim, results_dir)

Export the encoding min and max ranges for all weights and activations. After this call, results_dir contains HTML files in the following layout:

results_dir/
    activations.html
    weights.html

If per-channel quantization (PCQ) is enabled, the layout is instead:

results_dir/
    activations.html
    {wrapped_module_name}_{param_name}.html

Parameters:
  • sim (QuantizationSimModel) – Quantsim model.

  • results_dir (str) – Directory to save the results.

Return type:

Tuple[Dict, Dict]

Returns:

Layer-wise min-max ranges for weights and activations.
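The returned dictionaries can be scanned for unusually wide ranges: at a fixed bitwidth, a wider range means a coarser quantization step. The exact value format of these dictionaries is not specified here, so the sketch below assumes each entry maps a name to a (min, max) pair; adapt the unpacking to the structure you actually get. The helper name is hypothetical.

```python
from typing import Dict, List, Tuple


def widest_ranges(min_max: Dict[str, Tuple[float, float]],
                  top_k: int = 3) -> List[Tuple[str, float]]:
    """Return the top_k entries with the widest (max - min) range."""
    # Assumption: each value is a (min, max) pair.
    widths = {name: hi - lo for name, (lo, hi) in min_max.items()}
    return sorted(widths.items(), key=lambda kv: kv[1], reverse=True)[:top_k]


example = {"conv1": (-1.0, 1.0), "fc": (0.0, 20.0), "layer1.0.conv1": (-0.5, 0.5)}
print(widest_ranges(example, top_k=2))
```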


QuantAnalyzer.export_per_layer_stats_histogram(sim, results_dir)

NOTE: Invoke this only when the quantization scheme is TF-Enhanced.

Export a histogram representing the PDF of the statistics collected by each quantizer of every quant wrapper. After this call, results_dir contains HTML files in the following layout for every quantizer of each quant wrapper:

results_dir/
    activations_pdf/
        name_{input/output}_{index}.html
    weights_pdf/
        name/
            param_name_{channel_index}.html

Parameters:
  • sim (QuantizationSimModel) – Quantsim model.

  • results_dir (str) – Directory to save the results.
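Each exported plot is a histogram of the values a quantizer observed during calibration. As an illustration only (not AIMET's implementation, which collects these statistics itself during calibration), a normalized histogram, i.e. a discrete approximation of the PDF, can be computed in plain Python like this; the function name is hypothetical:

```python
from typing import List, Sequence, Tuple


def histogram_pdf(values: Sequence[float], num_bins: int = 10) -> Tuple[List[float], List[float]]:
    """Return bin edges and the probability mass in each bin."""
    if not values or num_bins <= 0:
        raise ValueError("values must be non-empty and num_bins positive")
    lo, hi = min(values), max(values)
    width = (hi - lo) / num_bins or 1.0  # avoid zero-width bins for constant input
    counts = [0] * num_bins
    for v in values:
        idx = min(int((v - lo) / width), num_bins - 1)  # the maximum falls in the last bin
        counts[idx] += 1
    edges = [lo + i * width for i in range(num_bins + 1)]
    return edges, [c / len(values) for c in counts]


edges, probs = histogram_pdf([0.0, 0.1, 0.2, 0.9, 1.0], num_bins=4)
print(edges, probs)
```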


QuantAnalyzer.export_per_layer_mse_loss(sim, results_dir)

NOTE: The same model input data must be passed through both the FP32 model and the quantsim model to tap the output activations of each layer.

Export the MSE loss between the FP32 and quantized output activations of each layer.

Parameters:
  • sim (QuantizationSimModel) – Quantsim model.

  • results_dir (str) – Directory to save the results.

Return type:

Dict

Returns:

Layer-wise MSE loss dictionary: dict[layer_name] = MSE loss.
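The quantity reported per layer is a mean squared error between two activation tensors produced from the same input. A plain-Python sketch of that computation, assuming the per-layer output activations have already been captured (flattened to lists) for the same inputs; the function names are hypothetical:

```python
from typing import Dict, Sequence


def mse(reference: Sequence[float], quantized: Sequence[float]) -> float:
    """Mean squared error between two equally sized activation vectors."""
    if not reference or len(reference) != len(quantized):
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((r - q) ** 2 for r, q in zip(reference, quantized)) / len(reference)


def per_layer_mse(fp32_acts: Dict[str, Sequence[float]],
                  quant_acts: Dict[str, Sequence[float]]) -> Dict[str, float]:
    """dict[layer_name] = MSE between FP32 and quantized outputs of that layer."""
    return {name: mse(fp32_acts[name], quant_acts[name]) for name in fp32_acts}


print(per_layer_mse({"conv1": [0.0, 1.0]}, {"conv1": [0.0, 0.5]}))
```

Feeding different inputs to the two models would mix quantization error with input differences, which is why the note above requires identical input data.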


Code Examples

Required imports

from typing import Any
import torch
from torchvision import models
from aimet_common.defs import QuantScheme
from aimet_torch.model_preparer import prepare_model
from aimet_common.utils import CallbackFunc
from aimet_torch.quant_analyzer import QuantAnalyzer

Prepare forward pass callback

# NOTE: In the actual use cases, the users should implement this part to serve
#       their own goals if necessary.
def forward_pass_callback(model: torch.nn.Module, _: Any = None) -> None:
    """
    NOTE: This is intended to be the user-defined model calibration function.
    AIMET requires the above signature. So if the user's calibration function does not
    match this signature, please create a simple wrapper around this callback function.

    A callback function for model calibration that simply runs forward passes on the model to
    compute encoding (delta/offset). This callback function should use representative data and should
    be a subset of the entire train/validation dataset (~1000 images/samples).

    :param model: PyTorch model.
    :param _: Argument(s) of this callback function. Up to the user to determine the type of this parameter.
    E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of
    parameters or an object representing something more complex.
    """
    # User action required
    # User should create data loader/iterable using representative dataset and simply run
    # forward passes on the model.
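As one possible body for this callback, here is a sketch that assumes func_callback_args is a hypothetical (data_loader, max_batches) tuple and that each batch can be passed straight to the model. It disables gradient tracking with torch.no_grad() when PyTorch is available and falls back to a no-op context so the sketch also runs without it:

```python
import contextlib
from itertools import islice
from typing import Any, Iterable, Tuple

try:
    import torch
    no_grad = torch.no_grad  # gradients are not needed for calibration
except ImportError:  # keep the sketch runnable without PyTorch installed
    no_grad = contextlib.nullcontext


def forward_pass_callback(model: Any, args: Tuple[Iterable, int]) -> None:
    """Run calibration forward passes over at most `max_batches` batches."""
    data_loader, max_batches = args  # assumption: hypothetical args convention
    with no_grad():
        for batch in islice(data_loader, max_batches):
            model(batch)  # outputs are discarded; only the forward pass matters
```

For ~1000 calibration samples with a batch size of 32, for example, max_batches=32 gives 1024 samples. If your loader yields (inputs, labels) pairs, unpack the batch and pass only the inputs to the model.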

Prepare eval callback

# NOTE: In the actual use cases, the users should implement this part to serve
#       their own goals if necessary.
def eval_callback(model: torch.nn.Module, _: Any = None) -> float:
    """
    NOTE: This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's evaluation function does not
    match this signature, please create a simple wrapper around this callback function.

    A callback function for model evaluation that determines model performance. This callback function is
    expected to return a scalar value representing the model performance evaluated against the entire
    test/evaluation dataset.

    :param model: PyTorch model.
    :param _: Argument(s) of this callback function. Up to the user to determine the type of this parameter.
    E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of
    parameters or an object representing something more complex.
    :return: Scalar value representing the model performance.
    """
    # User action required
    # User should create data loader/iterable using entire test/evaluation dataset, perform forward passes on
    # the model and return single scalar value representing the model performance.
    return 0.8  # Placeholder; return the actual evaluation score instead.
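Similarly, a framework-agnostic sketch of an evaluation callback that computes top-1 accuracy. It assumes func_callback_args is a hypothetical (data_loader, max_batches) tuple, that the loader yields (inputs, labels) batches, and that the model returns one score vector per sample; pass max_batches=None to evaluate on the entire dataset:

```python
from itertools import islice
from typing import Any, Iterable, Optional, Sequence, Tuple


def _argmax(scores: Sequence[float]) -> int:
    """Index of the largest score."""
    return max(range(len(scores)), key=scores.__getitem__)


def eval_callback(model: Any, args: Tuple[Iterable, Optional[int]]) -> float:
    """Top-1 accuracy over (inputs, labels) batches."""
    data_loader, max_batches = args  # assumption: hypothetical args convention
    correct = total = 0
    for inputs, labels in islice(data_loader, max_batches):  # None -> all batches
        outputs = model(inputs)  # assumed: one score vector per sample
        for scores, label in zip(outputs, labels):
            correct += int(_argmax(scores) == label)
            total += 1
    return correct / total if total else 0.0
```

With a PyTorch model you would typically wrap the loop in torch.no_grad() and use outputs.argmax(dim=1) instead of the _argmax helper.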

Prepare model and callback functions

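# Run on GPU when available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# `weights=` replaces the deprecated `pretrained=True` argument (torchvision >= 0.13).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(device).eval()
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(*input_shape).to(device)
prepared_model = prepare_model(model)

# User action required
# User should pass actual argument(s) of the callback functions.
forward_pass_callback_fn = CallbackFunc(forward_pass_callback, func_callback_args=None)
eval_callback_fn = CallbackFunc(eval_callback, func_callback_args=None)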

Create QuantAnalyzer object

quant_analyzer = QuantAnalyzer(model=prepared_model,
                               dummy_input=dummy_input,
                               forward_pass_callback=forward_pass_callback_fn,
                               eval_callback=eval_callback_fn)

# User action required
# The data loader must yield unlabeled data; if it also yields labels, discard them.
unlabeled_data_loader = _get_unlabeled_data_loader()  # User-provided helper (not part of AIMET)
# Approximately 256 images/samples are recommended for MSE loss analysis. So, if the data loader
# has a batch_size of 64, then 4 batches yield 256 images/samples.
quant_analyzer.enable_per_layer_mse_loss(unlabeled_dataset_iterable=unlabeled_data_loader, num_batches=4)
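The unlabeled data loader above must yield model inputs only. If your data source yields (inputs, labels) pairs, a small wrapper like the following (a hypothetical class, not part of AIMET) discards the labels while keeping __len__, so it still satisfies the Collection requirement of enable_per_layer_mse_loss:

```python
from typing import Any, Iterator, Sequence


class UnlabeledDataset:
    """Yield only the inputs from a labeled collection of (inputs, labels) pairs."""

    def __init__(self, labeled: Sequence):
        self._labeled = labeled  # e.g. a DataLoader, which also defines __len__

    def __len__(self) -> int:
        return len(self._labeled)

    def __iter__(self) -> Iterator[Any]:
        for inputs, _labels in self._labeled:
            yield inputs  # labels are dropped
```

Because __iter__ creates a fresh generator on every call, the wrapper can be iterated more than once, which matters if the data is consumed by several analysis passes.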

Run QuantAnalyzer

quant_analyzer.analyze(quant_scheme=QuantScheme.post_training_tf_enhanced,
                       default_param_bw=8,
                       default_output_bw=8,
                       config_file=None,
                       results_dir="./quant_analyzer_results/")
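After analyze() returns, the HTML reports described in the export sections above are written under results_dir; which files appear depends on the quant scheme and whether per-layer MSE analysis was enabled. A small sketch (the helper name is hypothetical) that lists every generated report relative to the results directory:

```python
from pathlib import Path
from typing import List


def list_html_reports(results_dir: str) -> List[str]:
    """Return all HTML reports under results_dir, relative to it and sorted."""
    root = Path(results_dir)
    if not root.is_dir():
        return []
    return sorted(str(p.relative_to(root)) for p in root.rglob("*.html"))


for report in list_html_reports("./quant_analyzer_results/"):
    print(report)
```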