AIMET Tensorflow Quant Analyzer API

AIMET Tensorflow Quant Analyzer analyzes a given Tensorflow model and points out which ops in the model are sensitive to quantization. It checks the model's sensitivity to weight and activation quantization, and performs per op sensitivity and MSE analysis. It also exports per op encoding (min and max) ranges and the statistics histogram for every op.

Top-level API

class aimet_tensorflow.quant_analyzer.QuantAnalyzer(session, start_op_names, output_op_names, forward_pass_callback, eval_callback, use_cuda=True)[source]
QuantAnalyzer tool provides
  1. Model sensitivity to weight and activation quantization

  2. Per layer encoding (min - max range) and PDF analysis

  3. Per op sensitivity analysis

  4. Per op MSE analysis

Parameters
  • session (Session) – The input model, given as a session, to which quantize ops will be added

  • start_op_names (List[str]) – List of starting op names of the model

  • output_op_names (List[str]) – List of output op names of the model

  • forward_pass_callback (CallbackFunc) – A callback function that is expected to run forward passes on a session. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples. This callback internally chooses the number of data samples it wants to use for calculating encodings.

  • eval_callback (CallbackFunc) – A callback function for model evaluation that determines model performance. This callback function is expected to return a scalar value representing the model performance evaluated against the entire test/evaluation dataset.

  • use_cuda (bool) – If True, places quantization ops on GPU. Defaults to True

analyze(quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, rounding_mode='nearest', default_param_bw=8, default_output_bw=8, config_file=None, unlabeled_dataset=None, num_batches=None, results_dir='./tmp/')[source]
Analyze model for quantization and point out sensitive parts/hotspots of the model by performing
  1. model sensitivity to quantization

  2. export per layer encoding (min - max range)

  3. export per layer statistics histogram (PDF) when quant scheme is TF-Enhanced

  4. perform per op sensitivity analysis by enabling and disabling quant ops

  5. per op MSE loss between fp32 and quantized output activations

Parameters
  • quant_scheme (QuantScheme) – Quantization Scheme, currently supported schemes are post_training_tf and post_training_tf_enhanced, defaults to post_training_tf_enhanced

  • rounding_mode (str) – The rounding scheme to use. One of: ‘nearest’ or ‘stochastic’; defaults to ‘nearest’

  • default_param_bw (int) – bitwidth to use for parameter tensors, defaults to 8

  • default_output_bw (int) – bitwidth to use for activation tensors, defaults to 8

  • config_file (Optional[str]) – Path to a config file to use to specify rules for placing quant ops in the model

  • results_dir (str) – Directory to save the results.

  • unlabeled_dataset (Optional[DatasetV1]) – Unlabeled TF dataset, used in the per op MSE loss calculation

  • num_batches (Optional[int]) – Number of batches, used in the per op MSE loss calculation. Approximately 256 samples/images are recommended, so if the data loader's batch size is 64, then 4 batches yield 256 samples/images.

Code Example

Required imports

from typing import Any
import numpy as np
import tensorflow.compat.v1 as tf
from aimet_common.defs import QuantScheme
from aimet_tensorflow.quant_analyzer import QuantAnalyzer, CallbackFunc
# Below import is required just for eval_callback utility
# User can have their own implementation
from Examples.tensorflow.utils.image_net_evaluator import ImageNetEvaluator

Prepare forward pass callback

def forward_pass_callback(session: tf.compat.v1.Session, _: Any = None) -> None:
    """
    NOTE: This is intended to be the user-defined model calibration function.
    AIMET requires the above signature. So if the user's calibration function does not
    match this signature, please create a simple wrapper around this callback function.

    A callback function for model calibration that simply runs forward passes on the model to
    compute encoding (delta/offset). This callback function should use representative data and should
    be a subset of the entire train/validation dataset (~1000 images/samples).

    :param session: Tensorflow session.
    :param _: Argument(s) of this callback function. Up to the user to determine the type of this parameter.
    E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of
    parameters or an object representing something more complex.
    """

Prepare eval callback

def eval_callback(session: tf.compat.v1.Session, _: Any = None) -> float:
    """
    NOTE: This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's evaluation function does not
    match this signature, please create a simple wrapper around this callback function.

    A callback function for model evaluation that determines model performance. This callback function is
    expected to return scalar value representing the model performance evaluated against entire
    test/evaluation dataset.

    :param session: Tensorflow session.
    :param _: Argument(s) of this callback function. Up to the user to determine the type of this parameter.
    E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of
    parameters or an object representing something more complex.
    :return: Scalar value representing the model performance.
    """
    # User action required
    # User should create data loader/iterable using entire test/evaluation dataset, perform forward passes on
    # the model and return single scalar value representing the model performance.
    evaluator = ImageNetEvaluator('/path/to/tfrecords dataset', training_inputs=['keras_learning_phase:0'],
                                  data_inputs=['input_1:0'], validation_inputs=['labels:0'],
                                  image_size=224,
                                  batch_size=32,
                                  format_bgr=True)
    return evaluator.evaluate(session, iterations=None)

Create session

    # User action required
    # User should create a tf session from the model that has to be analyzed
    session = tf.compat.v1.Session()
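    # A minimal sketch (assumption): load a hypothetical frozen graph (.pb) into the session's
    # graph; replace the file path with the user's own model.
    with tf.io.gfile.GFile('/path/to/frozen_model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with session.graph.as_default():
        tf.import_graph_def(graph_def, name='')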
    # User has to define the start_op_names and output_op_names of the model
    start_op_names = ['model_start_op_name']
    output_op_names = ['model_output_op_name']
    # User action required
    # User should pass actual argument(s) of the callback functions.
    forward_pass_callback_fn = CallbackFunc(forward_pass_callback, func_callback_args=None)
    eval_callback_fn = CallbackFunc(eval_callback, func_callback_args=None)

Create QuantAnalyzer object

    quant_analyzer = QuantAnalyzer(session, start_op_names=start_op_names, output_op_names=output_op_names,
                                   forward_pass_callback=forward_pass_callback_fn, eval_callback=eval_callback_fn, use_cuda=False)

Create unlabeled dataset and define num_batches

    # Create unlabeled dataset and define num_batches to perform per op mse analysis
    # Approximately 256 images/samples are recommended for MSE loss analysis. So, if the dataloader
    # has a batch_size of 64, then 4 batches lead to 256 images/samples.
    # User action required
    # User should use an unlabeled dataloader; if the dataloader yields labels as well, the user should discard them.
    dataset_size = 128
    input_data = np.random.rand(dataset_size, 224, 224, 3)
    dataset = tf.data.Dataset.from_tensor_slices(input_data)
    batch_size = 32
    unlabeled_dataset = dataset.batch(batch_size=batch_size)
    num_batches = 4
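
If the available dataset yields (image, label) pairs, the labels can simply be dropped with a map. A minimal alternative sketch, reusing the input_data, dataset_size and batch_size defined above and assuming hypothetical random labels:

    # Assumption: labeled dataset of (image, label) pairs; keep only the images for MSE analysis
    labels = np.random.randint(low=0, high=1000, size=dataset_size)
    labeled_dataset = tf.data.Dataset.from_tensor_slices((input_data, labels))
    unlabeled_dataset = labeled_dataset.map(lambda image, label: image).batch(batch_size=batch_size)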

Run QuantAnalyzer

    quant_analyzer.analyze(quant_scheme=QuantScheme.post_training_tf_enhanced,
                           default_param_bw=8,
                           default_output_bw=8,
                           config_file=None,
                           unlabeled_dataset=unlabeled_dataset,
                           num_batches=num_batches,
                           results_dir="./quant_analyzer_results/")