AIMET TensorFlow Mixed Precision API

Top-level API for Regular AMP

aimet_tensorflow.keras.mixed_precision.choose_mixed_precision(sim, candidates, eval_callback_for_phase1, eval_callback_for_phase2, allowed_accuracy_drop, results_dir, clean_start, forward_pass_callback, amp_search_algo=AMPSearchAlgo.Binary, phase1_optimize=True)[source]

High-level API to perform in-place mixed precision evaluation on the given sim model. A pareto front list is created, and a curve of accuracy vs. bit ops is saved under the results directory.

Parameters:
  • sim (QuantizationSimModel) – Quantized sim model

  • candidates (List[Tuple[Tuple[int, QuantizationDataType], Tuple[int, QuantizationDataType]]]) –

    List of tuples of all possible bitwidth values for activations and parameters.
    For example, if the possible combinations are
    ((activation bitwidth 8, activation data type int), (parameter bitwidth 16, parameter data type int)) and
    ((activation bitwidth 16, activation data type float), (parameter bitwidth 16, parameter data type float)),
    then candidates should be
    [((8, QuantizationDataType.int), (16, QuantizationDataType.int)),
    ((16, QuantizationDataType.float), (16, QuantizationDataType.float))]

  • eval_callback_for_phase1 (CallbackFunc) – An object of the CallbackFunc class which takes in an eval function (callable) and its parameters. This evaluation callback is used to measure the sensitivity of each quantizer group during phase 1. Phase 1 involves finding the accuracy list/sensitivity of each module, so a user may want to run phase 1 with a smaller dataset.

  • eval_callback_for_phase2 (CallbackFunc) – An object of the CallbackFunc class which takes in an eval function (callable) and its parameters. This evaluation callback is used to get the accuracy of the quantized model for the phase 2 calculations. Phase 2 involves finding the pareto front curve.

  • allowed_accuracy_drop (Optional[float]) – Maximum allowed drop in accuracy from the FP32 baseline. The pareto front curve is plotted only up to the point where the allowed accuracy drop is met. To get a complete plot for picking points on the curve, set the allowed accuracy drop to None.

  • results_dir (str) – Path to save results and cache intermediate results

  • clean_start (bool) – If True, any cached information from previous runs is deleted before starting the mixed-precision analysis. If False, prior cached information is used where applicable. Note that it is the user's responsibility to set this flag to True if anything in the model or quantization parameters has changed compared to the previous run.

  • forward_pass_callback (CallbackFunc) – An object of the CallbackFunc class which takes in a forward pass function (callable) and its parameters. The forward pass callback is used to compute quantization encodings.

  • amp_search_algo (AMPSearchAlgo) – A valid value from the AMPSearchAlgo enum. Defines the search algorithm to be used for phase 2 of AMP. Defaults to AMPSearchAlgo.Binary.

  • phase1_optimize (bool) – If set to False, the default phase 1 logic is executed; otherwise the optimized phase 1 logic is executed.

Return type:

Optional[List[Tuple[int, float, QuantizerGroup, int]]]

Returns:

Pareto front list, where each entry is a tuple of (relative bit ops w.r.t. the baseline candidate, eval score, quantizer group, candidate used in that step). The pareto front list can be used to plot a pareto front curve, which shows how bit ops vary with accuracy. If the allowed accuracy drop is set to None (i.e., unrestricted), the complete curve can be used to pick points and re-run. Returns None if the mixed precision algorithm exits early.
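
For reference, a minimal sketch of how the returned pareto front list might be consumed, assuming sim, candidates, and the callbacks have already been constructed as described above (the accuracy drop and results directory values are illustrative):

pareto_front = choose_mixed_precision(sim, candidates, eval_callback_for_phase1,
                                      eval_callback_for_phase2, allowed_accuracy_drop=0.1,
                                      results_dir="./amp_results", clean_start=True,
                                      forward_pass_callback=forward_pass_callback)
if pareto_front is not None:
    for relative_bit_ops, eval_score, quantizer_group, candidate in pareto_front:
        # Each entry records the bit ops relative to the baseline candidate, the eval score
        # at that step, and the quantizer group that was switched to `candidate`
        print(relative_bit_ops, eval_score, quantizer_group, candidate)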


Top-level API for Fast AMP (AMP 2.0)

aimet_tensorflow.keras.mixed_precision.choose_fast_mixed_precision(sim, candidates, data_loader_wrapper, eval_callback_for_phase2, allowed_accuracy_drop, results_dir, clean_start, forward_pass_callback, forward_pass_callback_2=None, amp_search_algo=AMPSearchAlgo.Binary, phase1_optimize=True)[source]

High-level API to perform in-place mixed precision evaluation on the given sim model. A pareto front list is created, and a curve of accuracy vs. bit ops is saved under the results directory.

Parameters:
  • sim (QuantizationSimModel) – Quantized sim model

  • candidates (List[Tuple[Tuple[int, QuantizationDataType], Tuple[int, QuantizationDataType]]]) –

    List of tuples of all possible bitwidth values for activations and parameters.
    For example, if the possible combinations are
    ((activation bitwidth 8, activation data type int), (parameter bitwidth 16, parameter data type int)) and
    ((activation bitwidth 16, activation data type float), (parameter bitwidth 16, parameter data type float)),
    then candidates should be
    [((8, QuantizationDataType.int), (16, QuantizationDataType.int)),
    ((16, QuantizationDataType.float), (16, QuantizationDataType.float))]

  • data_loader_wrapper (Callable) – A callable which, when called, returns a data loader to be used for the phase 1 forward pass.

  • eval_callback_for_phase2 (CallbackFunc) – An object of the CallbackFunc class which takes in an eval function (callable) and its parameters. This evaluation callback is used to get the accuracy of the quantized model for the phase 2 calculations. Phase 2 involves finding the pareto front curve.

  • allowed_accuracy_drop (Optional[float]) – Maximum allowed drop in accuracy from the FP32 baseline. The pareto front curve is plotted only up to the point where the allowed accuracy drop is met. To get a complete plot for picking points on the curve, set the allowed accuracy drop to None.

  • results_dir (str) – Path to save results and cache intermediate results

  • clean_start (bool) – If True, any cached information from previous runs is deleted before starting the mixed-precision analysis. If False, prior cached information is used where applicable. Note that it is the user's responsibility to set this flag to True if anything in the model or quantization parameters has changed compared to the previous run.

  • forward_pass_callback (CallbackFunc) – An object of the CallbackFunc class which takes in a forward pass function (callable) and its parameters. The forward pass callback is used to compute quantization encodings.

  • forward_pass_callback_2 (Optional[Callable]) – Forward pass callback function which takes a model and inputs, performs a forward pass on it, and returns the output numpy ndarray of the last layer. Can be kept None if the model works with the standard model.predict() forward pass.

  • amp_search_algo (AMPSearchAlgo) – A valid value from the AMPSearchAlgo enum. Defines the search algorithm to be used for phase 2 of AMP. Defaults to AMPSearchAlgo.Binary.

  • phase1_optimize (bool) – If set to False, the default phase 1 logic is executed; otherwise the optimized phase 1 logic is executed.

Return type:

Optional[List[Tuple[int, float, QuantizerGroup, int]]]

Returns:

Pareto front list, where each entry is a tuple of (relative bit ops w.r.t. the baseline candidate, eval score, quantizer group, candidate used in that step). The pareto front list can be used to plot a pareto front curve, which shows how bit ops vary with accuracy. If the allowed accuracy drop is set to None (i.e., unrestricted), the complete curve can be used to pick points and re-run. Returns None if the mixed precision algorithm exits early.


Note: To enable phase-3, set the attribute GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True

Currently only two candidates are supported: ((8, int), (8, int)) and ((16, int), (8, int))
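
As a minimal sketch, phase-3 can be enabled before running AMP as shown below; the import path is assumed from the AIMET Keras package layout used in the code examples further down.

from aimet_tensorflow.keras.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo

# Enable convert-op reduction (phase-3) before calling choose_mixed_precision / choose_fast_mixed_precision
GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True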


Quantizer Groups definition


CallbackFunc Definition

class aimet_common.defs.CallbackFunc(func, func_callback_args=None)[source]

Class encapsulating a callback function and its arguments

Parameters:
  • func (Callable) – Callable Function

  • func_callback_args – Arguments passed to the callable function
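
For example, an evaluation callback that passes no extra arguments and a forward pass callback limited to 500 batches might be built as follows; eval_func here is a placeholder for the user's own evaluation callable, as in the code examples below.

from aimet_common.defs import CallbackFunc

eval_callback = CallbackFunc(eval_func, func_callback_args=None)
forward_pass_callback = CallbackFunc(eval_func, func_callback_args=500)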

Code Examples

Required imports

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import random
import numpy as np

# imports specific to resnet50 pretrained model
from tensorflow.keras.applications.resnet import ResNet50, preprocess_input, decode_predictions

# AIMET imports
from aimet_common.defs import CallbackFunc, QuantizationDataType, QuantScheme
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
from aimet_tensorflow.keras.mixed_precision import choose_mixed_precision, choose_fast_mixed_precision
from aimet_tensorflow.keras.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo
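
Center crop helper

The evaluation and data loader functions below call a center_crop helper that is not shown in this snippet. The following is a minimal sketch of such a helper, assuming a 224x224 center crop over batched image tensors; it is an illustration, not part of the AIMET API.

def center_crop(images, crop_size=224):
    """Assumed helper: center-crop a batch of images to crop_size x crop_size."""
    height, width = images.shape[1], images.shape[2]
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return images[:, top:top + crop_size, left:left + crop_size, :]
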
Load ResNet50 model

def get_model():
    """Helper function to return the model"""
    model = ResNet50(
        input_shape=None,
        include_top=True,
        weights="imagenet",
        input_tensor=None,
        pooling=None,
        classes=1000)
    return model

Eval function

def get_eval_func(dataset_dir, batch_size, num_iterations=50000):
    """
    Helper function that returns an evaluation function which performs the forward pass on the specified model
    with the given dataset parameters
    :param dataset_dir: Directory from which the dataset images need to be loaded.
    :param batch_size: Batch size to be used in the dataloader
    :param num_iterations: Optional parameter stating the total number of images to be used.
    Defaults to 50000, which is the size of the ImageNet validation set.
    :return: An evaluation function which can be used to evaluate the model's accuracy on the given dataset.
    """

    def func_wrapper(model, iterations):
        """ Evaluation function returned by the parent function. Performs the forward pass on the model with the given dataset and returns the accuracy. """

        validation_ds = tf.keras.preprocessing.image_dataset_from_directory(
            directory=dataset_dir,
            labels='inferred',
            label_mode='categorical',
            batch_size=batch_size,
            shuffle=False)

        # `iterations` is interpreted as a number of batches; convert it to an
        # image count, or fall back to the full validation set when it is None
        if not iterations:
            iterations = num_iterations
        else:
            iterations = iterations * batch_size
        top1 = 0
        total = 0
        for (img, label) in validation_ds:
            img = center_crop(img)
            x = preprocess_input(img)
            preds = model.predict(x, batch_size=batch_size)
            label = np.where(label)[1]
            label = [validation_ds.class_names[int(i)] for i in label]
            cnt = sum([1 for a, b in zip(label, decode_predictions(preds, top=1)) if str(a) == b[0][0]])
            top1 += cnt
            total += len(label)
            if total >= iterations:
                break
        return top1/total
    return func_wrapper
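
A short usage sketch of the helper above (the dataset path is hypothetical, and model is a tf.keras.Model such as the one returned by get_model()):

eval_func = get_eval_func(dataset_dir="/path/to/imagenet/val", batch_size=32)
top1 = eval_func(model, 50)  # evaluate top-1 accuracy on 50 batches of the validation set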

Data Loader Wrapper function

def get_data_loader_wrapper(dataset_dir, batch_size, is_training=False):
    """
    Helper function which returns a method that, when called, will give a data loader.
    :param dataset_dir: Directory from which the dataset images need to be loaded.
    :param batch_size: Batch size to be used in the dataloader
    :param is_training: Defaults to False. Used to set the shuffle flag for the data loader.
    :return: A wrapper function which will return a dataloader.
    """
    def dataloader_wrapper():
        dataloader = tf.keras.preprocessing.image_dataset_from_directory(
            directory=dataset_dir,
            labels='inferred',
            label_mode='categorical',
            batch_size=batch_size,
            shuffle=is_training,
            image_size=(256, 256))

        return dataloader.map(lambda x, y: preprocess_input(center_crop(x)))

    return dataloader_wrapper
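
Quantization helper

The mixed precision examples below call a get_quantizated_model helper that is not defined in this snippet. The following is a minimal sketch of what such a helper could look like, assuming the AIMET Keras QuantizationSimModel API; the quant scheme, default bitwidths, and the 500-batch forward pass are illustrative choices, not requirements.

def get_quantizated_model(model, eval_func):
    """Assumed helper: create a QuantizationSimModel and compute encodings.
    The name matches the call sites in the examples below."""
    sim = QuantizationSimModel(model,
                               quant_scheme=QuantScheme.post_training_tf_enhanced,
                               default_param_bw=8,
                               default_output_bw=8)
    # Forward pass on a limited number of batches to compute the quantization encodings
    sim.compute_encodings(eval_func, 500)
    return sim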

Quantization with regular mixed precision

def mixed_precision(dataset_dir):
    """
    Sample function which demonstrates the quantization on a Resnet50 model followed by mixed precision
    """
    np.random.seed(1)
    random.seed(1)
    tf.random.set_seed(1)

    batch_size = 32

    # Load the model
    model = get_model()

    # Perform batch norm folding
    _, model = fold_all_batch_norms(model)

    # Get the evaluation function.
    # We will use this function for the forward pass callback as well.
    eval_func = get_eval_func(dataset_dir, batch_size)

    # Calculate the Original Model accuracy
    org_top1 = eval_func(model, None)
    print("Original Model Accuracy: ", org_top1)

    # get the quantized model object
    sim = get_quantizated_model(model, eval_func)

    # Set the candidates for the mixed precision algorithm
    # Candidate format given below
    # ((activation bitwidth, activation data type), (param bitwidth, param data type))
    # e.g. ((16, QuantizationDataType.int), (16, QuantizationDataType.int)),
    candidate = [((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
                 ((8, QuantizationDataType.int), (8, QuantizationDataType.int))]

    # The allowed accuracy drop represents the amount of accuracy drop we are accepting
    # to trade for a lower precision, faster model.
    # 0.09 means we are accepting up to a 9% (0.09) accuracy drop from the baseline.
    allowed_accuracy_drop = 0.09

    eval_callback = CallbackFunc(eval_func, None)
    forward_pass_call_back = CallbackFunc(eval_func, 500)

    # Enable phase-3 (optional)
    # GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True
    # Note: supported candidates ((8,int), (8,int)) & ((16,int), (8,int))

    # Call the mixed precision wrapper with appropriate parameters
    choose_mixed_precision(sim, candidate, eval_callback, eval_callback, allowed_accuracy_drop, "./cmp_res", True, forward_pass_call_back)
    print("Mixed Precision Model Accuracy: ", eval_func(sim.model, None))
    sim.export(filename_prefix='mixed_precision_quant_model', path='.')

Quantization with fast mixed precision

def fast_mixed_precision(dataset_dir):
    """
    Sample function which demonstrates the quantization on a Resnet50 model followed by mixed precision using AMP 2.0
    """
    np.random.seed(1)
    random.seed(1)
    tf.random.set_seed(1)

    batch_size = 32

    # Load the model
    model = get_model()

    # Perform batch norm folding
    _, model = fold_all_batch_norms(model)

    # Get the evaluation function.
    # We will use this function for the forward pass callback as well.
    eval_func = get_eval_func(dataset_dir, batch_size)

    # Calculate the Original Model accuracy
    org_top1 = eval_func(model, None)
    print("Original Model Accuracy: ", org_top1)

    # get the quantized model object
    sim = get_quantizated_model(model, eval_func)

    # Set the candidates for the mixed precision algorithm
    # Candidate format given below
    # ((activation bitwidth, activation data type), (param bitwidth, param data type))
    # e.g. ((16, QuantizationDataType.int), (16, QuantizationDataType.int)),
    candidate = [((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
                 ((8, QuantizationDataType.int), (8, QuantizationDataType.int))]

    # The allowed accuracy drop represents the amount of accuracy drop we are accepting
    # to trade for a lower precision, faster model.
    # 0.09 means we are accepting up to a 9% (0.09) accuracy drop from the baseline.
    allowed_accuracy_drop = 0.09

    data_loader_wrapper = get_data_loader_wrapper(dataset_dir, batch_size)

    eval_callback = CallbackFunc(eval_func, None)
    forward_pass_call_back = CallbackFunc(eval_func, 500)

    # Enable phase-3 (optional)
    # GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True
    # Note: supported candidates ((8,int), (8,int)) & ((16,int), (8,int))

    # Call the fast mixed precision wrapper with appropriate parameters
    choose_fast_mixed_precision(sim, candidate, data_loader_wrapper, eval_callback, allowed_accuracy_drop, "./cmp_res", True, forward_pass_call_back)
    print("Mixed Precision Model Accuracy: ", eval_func(sim.model, None))
    sim.export(filename_prefix='mixed_precision_quant_model', path='.')