AIMET Keras Quant Analyzer API

AIMET Keras Quant Analyzer analyzes a Keras model and points out the layers that are sensitive to quantization. It checks the model's sensitivity to weight and activation quantization and performs per-layer sensitivity and MSE analysis. It also exports the per-layer encoding ranges (min and max) and a statistics histogram for every layer.
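To make the exported statistics concrete, here is a minimal sketch in plain Python of what a per-layer min/max range and a normalized statistics histogram (PDF) summarize. It is an illustration, not AIMET's implementation: the layer names, the activation values, and the use of equal-width bins are all hypothetical.

```python
from typing import Dict, List, Tuple


def layer_range(values: List[float]) -> Tuple[float, float]:
    """Return the (min, max) range observed for one layer's tensor."""
    return min(values), max(values)


def normalized_histogram(values: List[float], num_bins: int = 4) -> List[float]:
    """Bin values into equal-width bins over [min, max]; normalize counts to sum to 1."""
    lo, hi = layer_range(values)
    width = (hi - lo) / num_bins or 1.0  # guard against a constant tensor
    counts = [0] * num_bins
    for v in values:
        idx = min(int((v - lo) / width), num_bins - 1)  # v == max goes into the last bin
        counts[idx] += 1
    return [c / len(values) for c in counts]


# Hypothetical per-layer activations collected during calibration.
activations: Dict[str, List[float]] = {
    "conv1": [-1.0, -0.5, 0.0, 0.5, 1.0, 1.0, 3.0],
    "dense": [0.0, 0.1, 0.2, 0.3],
}

# name -> ((min, max), histogram)
stats = {name: (layer_range(v), normalized_histogram(v)) for name, v in activations.items()}
```

A long tail, like the single 3.0 in "conv1", stretches the min/max range while holding little probability mass, which is the kind of pattern these exports help reveal.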

Top-level API

class aimet_tensorflow.keras.quant_analyzer.QuantAnalyzer(model, forward_pass_callback, eval_callback)

The QuantAnalyzer tool provides:

  1. model sensitivity to weight and activation quantization

  2. per-layer sensitivity analysis

  3. per-layer encodings (min - max range)

  4. per-layer PDF (statistics histogram) analysis, and

  5. per-layer MSE analysis

Parameters
  • model (Model) – FP32 model to analyze for quantization.

  • forward_pass_callback (CallbackFunc) – A callback function for model calibration that simply runs forward passes on the model to compute encodings (delta/offset). This callback should use representative data, typically a subset of the entire train/validation dataset (~1000 images/samples).

  • eval_callback (CallbackFunc) – A callback function for model evaluation that determines model performance. This callback is expected to return a scalar value representing the model performance evaluated against the entire test/evaluation dataset.
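Both callbacks are typed as CallbackFunc, which pairs a user function with an optional argument object that is handed back to the function together with the model. The class below is a hypothetical stand-in written only to illustrate that (model, args) calling convention; it is not AIMET's CallbackFunc, and its attribute names are assumptions.

```python
from typing import Any, Callable, Optional


class SimpleCallbackFunc:
    """Hypothetical stand-in that pairs a callback with its extra argument(s)."""

    def __init__(self, func: Callable[[Any, Any], Any], func_callback_args: Optional[Any] = None):
        self.func = func
        self.func_callback_args = func_callback_args

    def __call__(self, model: Any) -> Any:
        # Invoke the user function with the (model, args) signature described above.
        return self.func(model, self.func_callback_args)


def toy_eval(model: Callable[[Any], float], num_samples: Any = None) -> float:
    """Toy 'evaluation': returns a scalar computed from the model and the extra argument."""
    return float(model(num_samples))


eval_cb = SimpleCallbackFunc(toy_eval, func_callback_args=10)
score = eval_cb(lambda n: n * 0.5)  # the 'model' here is just a callable
```

Keeping the extra arguments on the wrapper lets the tool invoke every callback the same way, without knowing what the arguments mean.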

analyze(quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, rounding_mode='nearest', default_param_bw=8, default_output_bw=8, config_file=None, results_dir='./tmp/')
Analyzes the model for quantization and points out its sensitive parts/hotspots by
  1. checking model sensitivity to quantization,

  2. performing per-layer sensitivity analysis by enabling and disabling quant wrappers,

  3. exporting per-layer encoding min - max ranges,

  4. exporting per-layer statistics histograms (PDF) when the quant scheme is TF-Enhanced, and

  5. performing per-layer MSE analysis.
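The per-layer sensitivity step can be pictured as toggling quantization one layer at a time and re-running evaluation. The sketch below assumes a toy evaluate function that maps the set of quantized layer names to an accuracy; the layer names, the per-layer accuracy drops, and the result keys are hypothetical and are not the names AIMET uses in its reports.

```python
from typing import Callable, Dict, FrozenSet, List


def per_layer_sensitivity(
    layers: List[str],
    evaluate: Callable[[FrozenSet[str]], float],
) -> Dict[str, Dict[str, float]]:
    """For each layer, score the model with (a) only that layer quantized and
    (b) every layer except that one quantized."""
    all_layers = frozenset(layers)
    results = {}
    for name in layers:
        results[name] = {
            "only_this_enabled": evaluate(frozenset({name})),
            "only_this_disabled": evaluate(all_layers - {name}),
        }
    return results


# Hypothetical accuracy drop each layer causes when quantized.
DROP = {"conv1": 0.02, "conv2": 0.10, "fc": 0.01}
FP32_ACC = 0.76


def toy_evaluate(quantized: FrozenSet[str]) -> float:
    # Toy model: accuracy falls by the sum of the drops of the quantized layers.
    return FP32_ACC - sum(DROP[n] for n in quantized)


sens = per_layer_sensitivity(list(DROP), toy_evaluate)
# The layer that hurts most when quantized on its own is the main hotspot.
most_sensitive = min(sens, key=lambda n: sens[n]["only_this_enabled"])
```

Here "conv2" comes out as the hotspot; in a real model the two views can disagree because quantization errors from different layers interact.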

Parameters
  • quant_scheme (QuantScheme) – Quantization scheme. Supported values are QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced.

  • rounding_mode (str) – The rounding scheme to use. One of ‘nearest’ or ‘stochastic’; defaults to ‘nearest’.

  • default_param_bw (int) – Default bitwidth (4-31) to use for quantizing layer parameters.

  • default_output_bw (int) – Default bitwidth (4-31) to use for quantizing layer inputs and outputs.

  • config_file (Optional[str]) – Path to configuration file for model quantizers.

  • results_dir (str) – Directory to save the results.
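The bitwidth parameters and rounding_mode determine how a min/max range becomes an encoding and how values are mapped onto it. The following is a generic sketch of asymmetric min/max quantization with delta/offset, assuming the common convention offset = round(min / delta). It is meant to illustrate the roles of default_param_bw/default_output_bw and rounding_mode. It does not reproduce AIMET's exact arithmetic or the TF-Enhanced range search.

```python
import math
import random
from typing import Optional, Tuple


def compute_encoding(x_min: float, x_max: float, bitwidth: int = 8) -> Tuple[float, int]:
    """Return (delta, offset) for an asymmetric encoding covering [x_min, x_max]."""
    # Widen the range to include 0 so that zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    delta = (x_max - x_min) / (2 ** bitwidth - 1)
    offset = round(x_min / delta)  # typically negative for ranges that straddle 0
    return delta, offset


def quantize_dequantize(x: float, delta: float, offset: int, bitwidth: int = 8,
                        rounding_mode: str = "nearest",
                        rng: Optional[random.Random] = None) -> float:
    """Map x to an integer grid index, clamp it, and map it back to a float."""
    scaled = x / delta
    if rounding_mode == "nearest":
        rounded = round(scaled)
    elif rounding_mode == "stochastic":
        # Round up with probability equal to the fractional part (unbiased on average).
        rng = rng or random.Random()
        low = math.floor(scaled)
        rounded = low + (1 if rng.random() < scaled - low else 0)
    else:
        raise ValueError(f"unsupported rounding_mode: {rounding_mode!r}")
    q = min(max(rounded - offset, 0), 2 ** bitwidth - 1)  # clamp to [0, 2^bw - 1]
    return (q + offset) * delta


delta, offset = compute_encoding(-1.0, 1.0, bitwidth=8)
samples = [-1.0, -0.3, 0.0, 0.42, 1.0]
nearest = [quantize_dequantize(x, delta, offset) for x in samples]
rng = random.Random(0)
stochastic = [quantize_dequantize(x, delta, offset, rounding_mode="stochastic", rng=rng)
              for x in samples]
```

With 8 bits the grid has 255 steps, so every in-range value lands within one step (delta) of its original; fewer bits widen delta and the error accordingly.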

Code Examples

Required imports

from typing import Any

import numpy as np
import tensorflow as tf

from aimet_common.defs import QuantScheme
from aimet_common.utils import CallbackFunc
from aimet_tensorflow.keras.model_preparer import prepare_model
from aimet_tensorflow.keras.quant_analyzer import QuantAnalyzer

Prepare toy dataset to run example code

NUM_SAMPLES = 256
NUM_CLASSES = 1000
INPUT_SHAPES = (224, 224, 3)

images = np.random.rand(NUM_SAMPLES, *INPUT_SHAPES)
labels = np.eye(NUM_CLASSES)[np.random.choice(NUM_CLASSES, NUM_SAMPLES)]

image_dataset = tf.data.Dataset.from_tensor_slices(images)
label_dataset = tf.data.Dataset.from_tensor_slices(labels)

eval_dataset = tf.data.Dataset.zip((image_dataset, label_dataset)).batch(32)
unlabeled_dataset = eval_dataset.map(lambda image, label: image)

Prepare forward pass callback

# NOTE: In actual use cases, users should implement this part
#       to suit their own needs.
def forward_pass_callback(model: tf.keras.Model, _: Any = None) -> None:
    """
    NOTE: This is intended to be the user-defined model calibration function.
    AIMET requires the above signature. So if the user's calibration function does not
    match this signature, please create a simple wrapper around this callback function.

    A callback function for model calibration that simply runs forward passes on the model to
    compute encodings (delta/offset). This callback function should use representative data, typically
    a subset of the entire train/validation dataset (~1000 images/samples).

    :param model: tf.keras.Model object.
    :param _: Argument(s) of this callback function. Up to the user to determine the type of this parameter.
    E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of
    parameters or an object representing something more complex.
    """
    # User action required
    # User should create data loader/iterable using representative dataset and simply run
    # forward passes on the model.
    _ = model.predict(unlabeled_dataset)
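The docstring above suggests wrapping a calibration routine whose signature differs from (model, args). A minimal sketch of such an adapter follows. my_calibrate and its batches/verbose parameters are hypothetical placeholders for a user's own function, and the "model" in the usage line is any callable, so the sketch runs without TensorFlow or AIMET.

```python
from typing import Any, Callable, List


def my_calibrate(model: Callable[[List[float]], Any], batches: List[List[float]],
                 verbose: bool = False) -> int:
    """Hypothetical user calibration routine with its own signature."""
    seen = 0
    for batch in batches:
        model(batch)  # forward pass only; outputs are discarded
        seen += len(batch)
    if verbose:
        print(f"calibrated on {seen} samples")
    return seen


def forward_pass_wrapper(model: Callable[[List[float]], Any], args: Any = None) -> None:
    """Adapter with the (model, args) signature; args carries the calibration batches."""
    my_calibrate(model, batches=args or [], verbose=False)


recorded: List[List[float]] = []
forward_pass_wrapper(recorded.append, [[0.1, 0.2], [0.3]])
```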

Prepare eval callback

# NOTE: In actual use cases, users should implement this part
#       to suit their own needs.
def eval_callback(model: tf.keras.Model, _: Any = None) -> float:
    """
    NOTE: This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's evaluation function does not
    match this signature, please create a simple wrapper around this callback function.

    A callback function for model evaluation that determines model performance. This callback function is
    expected to return a scalar value representing the model performance evaluated against the entire
    test/evaluation dataset.

    :param model: tf.keras.Model object.
    :param _: Argument(s) of this callback function. Up to the user to determine the type of this parameter.
    E.g. could be simply an integer representing the number of data samples to use. Or could be a tuple of
    parameters or an object representing something more complex.
    :return: Scalar value representing the model performance.
    """
    # User action required
    # User should create data loader/iterable using entire test/evaluation dataset, perform forward passes on
    # the model and return single scalar value representing the model performance.

    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.CategoricalCrossentropy(),
                  metrics=[tf.keras.metrics.CategoricalAccuracy()])

    _, acc = model.evaluate(eval_dataset)
    return acc
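The scalar returned here comes from Keras' CategoricalAccuracy, which counts a sample as correct when the argmax of the prediction matches the argmax of the one-hot label. A plain-Python equivalent of that computation, run on made-up labels and predictions, is sketched below for illustration only.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest entry (first one on ties)."""
    return max(range(len(row)), key=row.__getitem__)


def categorical_accuracy(y_true: List[Sequence[float]], y_pred: List[Sequence[float]]) -> float:
    """Fraction of samples whose predicted top class matches the one-hot label."""
    hits = sum(argmax(t) == argmax(p) for t, p in zip(y_true, y_pred))
    return hits / len(y_true)


labels = [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0]]
preds = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.3, 0.4, 0.3], [0.2, 0.6, 0.2]]
acc = categorical_accuracy(labels, preds)  # 3 of 4 correct -> 0.75
```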

Prepare model

model = tf.keras.applications.ResNet50()
prepared_model = prepare_model(model)

Create QuantAnalyzer object

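# Wrap the callbacks defined above in CallbackFunc objects (no extra arguments needed here).
forward_pass_callback_fn = CallbackFunc(forward_pass_callback, func_callback_args=None)
eval_callback_fn = CallbackFunc(eval_callback, func_callback_args=None)

quant_analyzer = QuantAnalyzer(model=prepared_model,
                               forward_pass_callback=forward_pass_callback_fn,
                               eval_callback=eval_callback_fn)

# Approximately 256 images/samples are recommended for MSE loss analysis. For example, with a
# batch size of 64, 4 batches give 256 images/samples. The toy dataset above uses a batch size
# of 32, so num_batches=4 covers 128 samples here (num_batches=8 would cover all 256).
quant_analyzer.enable_per_layer_mse_loss(unlabeled_dataset=unlabeled_dataset, num_batches=4)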
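Per-layer MSE analysis compares each layer's FP32 output with its quantized counterpart. The sketch below is a simplification under stated assumptions. It quantizes each hypothetical layer output in isolation with a plain 4-bit min/max quantizer and nearest rounding. AIMET's analysis, by contrast, is computed on the unlabeled dataset passed to enable_per_layer_mse_loss.

```python
from typing import Dict, List


def quant_dequant(values: List[float], bitwidth: int) -> List[float]:
    """Uniform min/max quantize-dequantize with nearest rounding."""
    lo, hi = min(values), max(values)
    delta = (hi - lo) / (2 ** bitwidth - 1) or 1.0  # avoid division by zero for constant tensors
    return [lo + round((v - lo) / delta) * delta for v in values]


def mse(a: List[float], b: List[float]) -> float:
    """Mean squared error between two equally sized lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# Hypothetical FP32 outputs captured for two layers.
fp32_outputs: Dict[str, List[float]] = {
    "conv1": [-8.0, -3.3, 0.7, 5.1, 8.0],
    "fc": [0.0, 0.33, 0.5, 0.91, 1.0],
}

per_layer_mse = {name: mse(out, quant_dequant(out, bitwidth=4))
                 for name, out in fp32_outputs.items()}
```

Layers whose output range is wide relative to the available quantization levels, like "conv1" here, show the larger error.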

Run QuantAnalyzer

quant_analyzer.analyze(quant_scheme=QuantScheme.post_training_tf_enhanced,
                       default_param_bw=8,
                       default_output_bw=8,
                       config_file=None,
                       results_dir="./quant_analyzer_results/")