Automatic mixed precision

This technique helps choose per-layer integer bit-widths so that a model retains accuracy when run on fixed-point runtimes such as Qualcomm® AI Engine Direct.

As an example, say a particular model does not meet a desired accuracy target when run in INT8. The Auto Mixed Precision (AMP) feature will find a minimal set of layers that need to run at higher precision, INT16 for example, to reach the desired quantized accuracy.

Choosing a higher precision for some layers necessarily involves a trade-off: lower inferences/sec for higher accuracy, and vice versa. The AMP feature generates a Pareto curve that can guide the user in choosing the right operating point for this trade-off.

Context

To perform AMP, a user starts with a PyTorch, TensorFlow, or ONNX model and creates a quantization simulation model (QuantizationSimModel). This QuantSim model, along with an allowable accuracy drop, is passed to the API.

The function modifies the QuantSim model in place, assigning different bit-widths to different quantizers. The resulting QuantSim model can then be exported or evaluated to obtain its quantized accuracy.

[Figure: automatic mixed precision workflow]
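
At a glance, the end-to-end flow looks like the following minimal sketch (PyTorch variant; model, dummy_input, data_loader, and the forward_pass/evaluate callbacks are placeholders defined as in the full workflow example below):

from aimet_common.defs import QuantizationDataType, CallbackFunc
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.v1.mixed_precision import choose_mixed_precision

# Simulate quantization at the starting (highest) precision
sim = QuantizationSimModel(model, dummy_input=dummy_input,
                           default_param_bw=16, default_output_bw=16)
sim.compute_encodings(forward_pass, forward_pass_callback_args=10)

# Candidate (activation, parameter) precisions AMP may assign per layer group
candidates = [((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
              ((8, QuantizationDataType.int), (8, QuantizationDataType.int))]

# AMP modifies `sim` in place; afterwards it can be evaluated or exported
choose_mixed_precision(sim, dummy_input, candidates,
                       CallbackFunc(evaluate, data_loader),
                       CallbackFunc(evaluate, data_loader),
                       allowed_accuracy_drop=0.01, results_dir='./amp',
                       clean_start=True,
                       forward_pass_callback=CallbackFunc(forward_pass, 10))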

Mixed Precision Algorithm

The algorithm involves 4 phases:

[Figure: the four phases of the mixed precision algorithm]

1) Find layer groups

Layer groups are sets of layers grouped together based on certain rules. Grouping reduces the search space over which the mixed precision algorithm operates, and ensures that the algorithm searches only over valid bit-width settings for parameters and activations.

[Figure: example of layer groups]

2) Perform sensitivity analysis (Phase 1)

In this phase the algorithm performs a per-layer-group sensitivity analysis, identifying how sensitive the model is when a lower quantization bit-width is chosen for a particular layer group. The sensitivity analysis yields an accuracy list, which is cached and can be reused by the algorithm in later runs.

Below is an example of a list generated using sensitivity analysis:

[Figure: example accuracy list from sensitivity analysis]
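
Conceptually, phase 1 behaves like the following sketch. This is illustrative only, not AIMET internals; set_group_to_candidate, restore_group, and evaluate are hypothetical helpers:

# Per-layer-group sensitivity analysis (illustrative sketch)
accuracy_list = []
for group in layer_groups:
    for candidate in candidates:
        if candidate == baseline_candidate:
            continue                                # baseline accuracy is already known
        set_group_to_candidate(group, candidate)    # lower only this one group
        accuracy = evaluate(sim.model)              # phase-1 evaluation callback
        accuracy_list.append((group, candidate, accuracy))
        restore_group(group, baseline_candidate)    # restore before the next trial
# The resulting list is cached so later runs (clean_start=False) can reuse it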

3) Create a Pareto-front list (Phase 2)

A Pareto curve is a trade-off curve that describes how accuracy varies with a bit-ops target, and vice versa. The AMP algorithm yields a Pareto-front list in which each entry records the layer group changed at that step, the relative bit-ops (relative to the starting bit-ops), the resulting model accuracy, and the bit-width to which the layer group was changed.

An example of a Pareto list:

[Figure: example Pareto list]

Bit-ops are computed as

\(\mathrm{Bit\text{-}ops} = \mathrm{MAC}(op) \times \mathrm{Bitwidth}(parameter) \times \mathrm{Bitwidth}(activation)\)
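
As an illustrative example (shapes assumed, not from the source), the bit-ops of a single convolution can be computed for each candidate:

# Bit-ops for one conv op under different (activation, parameter) candidates.
# Output size, channel counts, and kernel size are assumed for illustration.
out_h, out_w, in_ch, out_ch, k = 56, 56, 64, 128, 3
macs = out_h * out_w * in_ch * out_ch * k * k   # ~231M multiply-accumulates

for act_bw, param_bw in [(8, 8), (16, 8), (16, 16)]:
    bit_ops = macs * param_bw * act_bw
    print(f"W{param_bw}A{act_bw}: {bit_ops / 1e9:.1f} G bit-ops")
# W8A8 yields ~14.8G bit-ops; W16A16 quadruples this to ~59.2G,
# which is why lowering a group's bit-width reduces relative bit-ops.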

The Pareto list can be used to plot a Pareto curve. A Bokeh plot of the Pareto curve is generated and saved in the results directory.

[Figure: example Pareto curve plot]

Note

A user can pass two different evaluation callbacks for phase 1 and phase 2. Since phase 1 measures the sensitivity of each quantizer group, a smaller representative dataset can be used for phase-1 evaluation, or even an indirect measure such as SQNR, which can be computed faster than the real evaluation metric while correlating well with it.

It is recommended to use the complete dataset for evaluation in phase 2.
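
For example, in the PyTorch variant the built-in SQNR callback (documented in the API section below) can serve as the fast phase-1 callback, with the full evaluation reserved for phase 2; data_loader and evaluate are assumed to be defined as in the workflow below:

from aimet_common.defs import CallbackFunc
from aimet_torch.amp.mixed_precision_algo import EvalCallbackFactory

# Fast, indirect metric (SQNR over a small number of samples) for phase 1
factory = EvalCallbackFactory(data_loader)
eval_callback_for_phase_1 = factory.sqnr(num_samples=128)

# Full-dataset evaluation for phase 2
eval_callback_for_phase_2 = CallbackFunc(evaluate, func_callback_args=data_loader)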

4) Reduce Bit-width Convert Op Overhead (Phase 3)

Convert ops are introduced in the mixed-precision model to transition between ops that are assigned different activation bit-widths or data types (float vs. int). These convert ops contribute to inference time along with the bit-operations of the ops themselves. In this phase the algorithm derives a mixed-precision solution with less convert-op overhead than the original solution while keeping the mixed-precision accuracy intact. The algorithm produces mixed-precision solutions for a range of alpha values (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), where alpha represents the fraction of the original convert-op overhead allowed for the respective solution.
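
Phase 3 is optional and is opted into through a class attribute before calling the API, as the workflow below and the note in the API section show:

from aimet_torch.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo

# Enable phase 3 (convert-op reduction) before calling choose_mixed_precision
GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True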

Use Cases

  1. Choosing a very high accuracy drop (equivalent to setting allowed_accuracy_drop as None):

AIMET allows a user to save intermediate states for computation of the Pareto list. Therefore, if a user computes a Pareto list corresponding to an accuracy drop of None, they can view the complete profile of how model accuracy will vary as bit-ops vary.

Thereafter, a user can visualize the Pareto curve plot and choose an optimal point for accuracy. The algorithm can be re-run with the new accuracy drop to get a sim model with the required accuracy.

Note

The Pareto list is not modified during the second run.

  2. Choosing a lower accuracy drop and then continuing to compute the Pareto list from this point if a larger accuracy drop is acceptable:

To enable this, use the clean_start parameter in the API. If clean_start is set to False, the Pareto list computation resumes from the point where it last left off.

Note

  • It is recommended to set the clean_start parameter to False to use cached results for both use cases.

  • If the model or candidate bit-widths change, the user needs to do a clean start.

Workflow

Code example

Step 1

Required imports

import torch
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
from aimet_torch.batch_norm_fold import fold_all_batch_norms
from aimet_common.defs import QuantizationDataType, CallbackFunc
from aimet_torch.v1.mixed_precision import choose_mixed_precision
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo

Load the model, define forward_pass and evaluation callbacks

# General setup that can be changed as needed
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = torchvision.models.mobilenet_v2(pretrained=True).eval().to(device)

batch_size = 64
PATH_TO_IMAGENET = ...
data = torchvision.datasets.ImageNet(PATH_TO_IMAGENET, split="train")
data_loader = DataLoader(data, batch_size=batch_size)

dummy_input = torch.randn(1, 3, 224, 224).to(device)
fold_all_batch_norms(model, dummy_input.shape)

# Callback function to pass calibration data through the model
def forward_pass(model: torch.nn.Module, batches):
    with torch.no_grad():
        for batch, (images, _) in enumerate(data_loader):
            images = images.to(device)
            model(images)
            if batch >= batches:
                break

# Basic ImageNet evaluation function
def evaluate(model, data_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, labels in tqdm(data_loader):
            data, labels = data.to(device), labels.to(device)
            logits = model(data)
            correct += (logits.argmax(1) == labels).type(torch.float).sum().item()
    accuracy = correct / len(data_loader.dataset)
    return accuracy
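
Optionally, the floating-point baseline accuracy can be measured up front; the allowed accuracy drop passed to AMP is interpreted relative to this baseline:

# Baseline accuracy of the unquantized FP32 model
fp32_accuracy = evaluate(model, data_loader)
print("FP32 model accuracy:", fp32_accuracy)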

Required imports

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import random
import numpy as np
from tensorflow.keras.applications.resnet import ResNet50, preprocess_input, decode_predictions

from aimet_tensorflow.keras.quantsim import QuantizationSimModel
from aimet_common.defs import CallbackFunc, QuantizationDataType, QuantScheme
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms
from aimet_tensorflow.keras.mixed_precision import choose_mixed_precision
from aimet_tensorflow.keras.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo

Load the model, define forward_pass and evaluation callbacks

# Load the model
model = ResNet50(weights="imagenet")

# Perform batch norm folding
_, model = fold_all_batch_norms(model)

def center_crop(image):
    """
    Perform a center crop on the images.
    :param image: Batch of images as tensors to center-crop. Expects an image size of 256 x 256.
    :return: Center-cropped images of size 224 x 224
    """

    img_height = 256
    img_width = 256
    crop_length = 224
    start_x = (img_height - crop_length) // 2
    start_y = (img_width - crop_length) // 2
    cropped_image = image[:,  start_x:(img_width - start_x), start_y:(img_height - start_y), :]
    return cropped_image

def get_eval_func(dataset_dir, batch_size, num_iterations=50000):
    """
    Helper function that returns an evaluation function, which performs the forward pass on the specified model
     with the given dataset parameters
    :param dataset_dir: Directory from which the dataset images need to be loaded.
    :param batch_size: Batch size to be used in the dataloader
    :param num_iterations: Optional parameter stating the total number of images to be used.
    Defaults to 50000, which is the size of the ImageNet validation set.
    :return: Returns an evaluation function which can be used to evaluate the model's accuracy on the preset dataset.
    """

    def func_wrapper(model, iterations):
        """ Evaluation function returned from the parent function. Performs the forward pass on the model with the given dataset and returns the accuracy. """

        validation_ds = tf.keras.preprocessing.image_dataset_from_directory(
            directory=dataset_dir,
            labels='inferred',
            label_mode='categorical',
            batch_size=batch_size,
            shuffle=False)

        # If no iteration count is given, evaluate on the full validation set;
        # otherwise convert the iteration count into a number of images
        if not iterations:
            iterations = num_iterations
        else:
            iterations = iterations * batch_size
        top1 = 0
        total = 0
        for (img, label) in validation_ds:
            img = center_crop(img)
            x = preprocess_input(img)
            preds = model.predict(x, batch_size=batch_size)
            label = np.where(label)[1]
            label = [validation_ds.class_names[int(i)] for i in label]
            cnt = sum([1 for a, b in zip(label, decode_predictions(preds, top=1)) if str(a) == b[0][0]])
            top1 += cnt
            total += len(label)
            if total >= iterations:
                break
        return top1/total
    return func_wrapper

def get_data_loader_wrapper(dataset_dir, batch_size, is_training=False):
    """
    Helper function that returns a wrapper; calling the wrapper creates a data loader.
    :param dataset_dir: Directory from which the dataset images need to be loaded.
    :param batch_size: Batch size to be used in the dataloader
    :param is_training: Defaults to False. Used to set the shuffle flag for the data loader.
    :return: Returns a wrapper function which will return a dataloader.
    """
    def dataloader_wrapper():
        dataloader = tf.keras.preprocessing.image_dataset_from_directory(
            directory=dataset_dir,
            labels='inferred',
            label_mode='categorical',
            batch_size=batch_size,
            shuffle=is_training,
            image_size=(256, 256))

        return dataloader.map(lambda x, y: preprocess_input(center_crop(x)))

    return dataloader_wrapper

# Get the evaluation function.
# We will use this function for the forward-pass callback as well.
batch_size = 32
dataset_dir = ... # path to dataset directory.
eval_func = get_eval_func(dataset_dir, batch_size)

# Calculate the Original Model accuracy
org_top1 = eval_func(model, None)
print("Original Model Accuracy: ", org_top1)

Required imports

import math
import os
import numpy as np
import onnx
import onnxsim
import torch
from datasets import load_dataset
from torchvision import transforms
from torchvision.models import MobileNet_V2_Weights, mobilenet_v2
from tqdm import tqdm
from aimet_onnx.defs import DataLoader
from aimet_onnx.quantsim import QuantizationSimModel
from aimet_common.defs import QuantizationDataType, CallbackFunc
from aimet_onnx.mixed_precision import choose_mixed_precision

Instantiate a PyTorch model, convert to ONNX graph, define forward_pass and evaluation callbacks

pt_model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)

# Modify file_path as you wish, we are using temporary directory for now
file_path = os.path.join('/tmp', 'mobilenet_v2.onnx')
torch.onnx.export(
    pt_model,
    (dummy_input,),
    file_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

# Load exported ONNX model
model = onnx.load_model(file_path)
# End of loading the model

# Prepare model with onnx-simplifier
try:
    model, _ = onnxsim.simplify(model)
except Exception:
    print('ONNX Simplifier failed. Proceeding with unsimplified model')
# End of prepare model

# Set up dataloader
dataset = load_dataset(
    'ILSVRC/imagenet-1k',
    split='validation',
).shuffle()

class CustomDataLoader(DataLoader):
    def __init__(
        self,
        data: np.ndarray,
        batch_size: int,
        iterations: int,
        unlabeled: bool = True,
    ):
        super().__init__(data, batch_size, iterations)
        self._current_iteration = 0
        self._unlabeled = unlabeled

    def __iter__(self):
        self._current_iteration = 0
        return self

    def __next__(self):
        if self._current_iteration < self.iterations:
            start = self._current_iteration * self.batch_size
            end = start + self.batch_size
            self._current_iteration += 1

            batch_data = self._data[start:end]
            if self._unlabeled:
                return np.stack(batch_data['image'])
            else:
                return np.stack(batch_data['image']), np.stack(batch_data['label'])
        else:
            raise StopIteration


preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


# Named `preprocess_batch` to avoid shadowing the `torchvision.transforms` import
def preprocess_batch(examples):
    examples['image'] = [
        preprocess(image.convert('RGB')) for image in examples['image']
    ]
    return examples


dataset.set_transform(preprocess_batch)

BATCH_SIZE = 32
NUM_CALIBRATION_SAMPLES = 1024
NUM_EVAL_SAMPLES = 50000
unlabeled_data_loader = CustomDataLoader(
    dataset, BATCH_SIZE, math.ceil(NUM_CALIBRATION_SAMPLES / BATCH_SIZE)
)
eval_data_loader = CustomDataLoader(
    dataset, BATCH_SIZE, math.ceil(NUM_EVAL_SAMPLES / BATCH_SIZE), unlabeled=False
)
# End of setting up dataloader

def forward_pass(session, _):
    input_name = session.get_inputs()[0].name
    for inputs in tqdm(unlabeled_data_loader):
        session.run(None, {input_name: inputs})

def evaluate(session, _):
    correct_predictions = 0
    total_samples = 0
    input_name = session.get_inputs()[0].name
    for inputs, labels in tqdm(eval_data_loader):
        pred_probs, *_ = session.run(None, {input_name: inputs})
        pred_labels = np.argmax(pred_probs, axis=1)
        correct_predictions += np.sum(pred_labels == labels)
        total_samples += labels.shape[0]

    accuracy = correct_predictions / total_samples
    return accuracy

Step 2

Quantization with mixed precision

default_bitwidth = 16

# ((activation bitwidth, activation data type), (param bitwidth, param data type))
candidates = [((16, QuantizationDataType.int), (16, QuantizationDataType.int)),
             ((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
             ((8, QuantizationDataType.int), (16, QuantizationDataType.int))]
# Allowed accuracy drop in absolute value
allowed_accuracy_drop = 0.5 # Implies 50% drop

eval_callback_for_phase_1 = CallbackFunc(evaluate, func_callback_args=data_loader)
eval_callback_for_phase_2 = CallbackFunc(evaluate, func_callback_args=data_loader)

calibration_batches = 10
forward_pass_call_back = CallbackFunc(forward_pass, func_callback_args=calibration_batches)

# Create quant sim
sim = QuantizationSimModel(model,
                           default_param_bw=default_bitwidth,
                           default_output_bw=default_bitwidth,
                           dummy_input=dummy_input)
sim.compute_encodings(forward_pass, forward_pass_callback_args=calibration_batches)

# Enable phase-3 (optional)
GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True

# Call the mixed-precision algo with clean_start=True, i.e. a new accuracy list and Pareto list will be generated
# If set to False, the Pareto-front list and accuracy list will be loaded from the provided directory path
# An allowed_accuracy_drop can be specified to export the final model with reference to the Pareto list
pareto_front_list = choose_mixed_precision(sim, dummy_input, candidates, eval_callback_for_phase_1,
                                           eval_callback_for_phase_2, allowed_accuracy_drop, results_dir='./data',
                                           clean_start=True, forward_pass_callback=forward_pass_call_back)
print(pareto_front_list)

# Set clean_start to False to start from an existing cache
# Set allowed_accuracy_drop to 0.9 to export the 90% drop point in the Pareto list
allowed_accuracy_drop = 0.9
pareto_front_list = choose_mixed_precision(sim, dummy_input, candidates, eval_callback_for_phase_1,
                                           eval_callback_for_phase_2, allowed_accuracy_drop, results_dir='./data',
                                           clean_start=False, forward_pass_callback=forward_pass_call_back)
print(pareto_front_list)
sim.export("./data", str(allowed_accuracy_drop), dummy_input)

Quantization with regular mixed precision

default_bitwidth = 16
# Set the candidates for the mixed precision algorithm
# Candidate format given below
# ((activation bitwidth, activation data type), (param bitwidth, param data type))
# e.g. ((16, QuantizationDataType.int), (16, QuantizationDataType.int)),
candidate = [((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
             ((8, QuantizationDataType.int), (8, QuantizationDataType.int))]

# get the quantized model object
sim = QuantizationSimModel(model=model,
                           default_output_bw=default_bitwidth,
                           default_param_bw=default_bitwidth,)

sim.compute_encodings(eval_func, forward_pass_callback_args=500)


# The allowed accuracy drop represents the amount of accuracy drop we are accepting
# to trade for a lower precision, faster model.
# 0.09 means we are accepting up to a 9% accuracy drop from the baseline.
allowed_accuracy_drop = 0.09

eval_callback = CallbackFunc(eval_func, None)
forward_pass_callback = CallbackFunc(eval_func, 500)

# Enable phase-3 (optional)
GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True
# Note: supported candidates ((8,int), (8,int)) & ((16,int), (8,int))

# Call the mixed precision wrapper with appropriate parameters
pareto_front_list = choose_mixed_precision(sim, candidate, eval_callback, eval_callback, allowed_accuracy_drop, "./cmp_res",
                                           clean_start=True, forward_pass_callback=forward_pass_callback)

print("Mixed Precision Model Accuracy: ", eval_func(sim.model, None))
sim.export(filename_prefix='mixed_precision_quant_model', path='.')

Quantization with mixed precision

# Define parameters to pass to mixed precision algo
default_bitwidth = 16

# ((activation bitwidth, activation data type), (param bitwidth, param data type))
candidates = [((16, QuantizationDataType.int), (16, QuantizationDataType.int)),
             ((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
             ((8, QuantizationDataType.int), (16, QuantizationDataType.int))]
# Allowed accuracy drop in absolute value
allowed_accuracy_drop = 0.5 # Implies 50% drop

eval_callback_for_phase_1 = CallbackFunc(evaluate, func_callback_args=None)
eval_callback_for_phase_2 = CallbackFunc(evaluate, func_callback_args=None)

forward_pass_callback = CallbackFunc(forward_pass, func_callback_args=None)

# Create quant sim
sim = QuantizationSimModel(model, default_param_bw=default_bitwidth, default_activation_bw=default_bitwidth)
sim.compute_encodings(forward_pass_callback, forward_pass_callback_args=None)

# Call the mixed-precision algo with clean_start=True, i.e. a new accuracy list and Pareto list will be generated
# If set to False, the Pareto-front list and accuracy list will be loaded from the provided directory path
# An allowed_accuracy_drop can be specified to export the final model with reference to the Pareto list
pareto_front_list = choose_mixed_precision(sim, candidates, eval_callback_for_phase_1,
                                           eval_callback_for_phase_2, allowed_accuracy_drop, results_dir='./data',
                                           clean_start=True, forward_pass_callback=forward_pass_callback)
print(pareto_front_list)

# Set clean_start to False to start from an existing cache
# Set allowed_accuracy_drop to 0.9 to export the 90% drop point in the Pareto list
allowed_accuracy_drop = 0.9
pareto_front_list = choose_mixed_precision(sim, candidates, eval_callback_for_phase_1,
                                           eval_callback_for_phase_2, allowed_accuracy_drop, results_dir='./data',
                                           clean_start=False, forward_pass_callback=forward_pass_callback)
print(pareto_front_list)
sim.export("./data", str(allowed_accuracy_drop))

API

Top-level API for Automatic mixed precision

aimet_torch.mixed_precision.choose_mixed_precision(sim, *args, **kwargs)

Note

To enable phase-3 set the attribute GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True

Currently only two candidates are supported - ((8,int), (8,int)) & ((16,int), (8,int))

Quantizer Groups definition

class aimet_torch.amp.quantizer_groups.QuantizerGroup(input_quantizers=<factory>, output_quantizers=<factory>, parameter_quantizers=<factory>, supported_kernel_ops=<factory>)

Group of modules and quantizers

get_active_quantizers(name_to_quantizer_dict)

Find all active tensor quantizers associated with this quantizer group

get_candidate(name_to_quantizer_dict)

Gets activation & parameter bitwidth.

Parameters:
  • name_to_quantizer_dict (Dict) – Gets module from module name

Return type:

Tuple[Tuple[int, QuantizationDataType], Tuple[int, QuantizationDataType]]

Returns:

Tuple of activation, parameter bitwidth and data type

get_input_quantizer_modules()

Helper method to get the module names corresponding to input_quantizers

set_quantizers_to_candidate(name_to_quantizer_dict, candidate)

Sets a quantizer group to a given candidate bitwidth.

Parameters:
  • name_to_quantizer_dict (Dict) – Gets module from module name

  • candidate (Tuple[Tuple[int, QuantizationDataType], Tuple[int, QuantizationDataType]]) – Candidate with activation and parameter bit-widths and data types

Return type:

None

to_list()

Converts quantizer group to a list.

Return type:

List[Tuple[str, str]]

Returns:

List containing input/output quantizers & weight quantizers

CallbackFunc Definition

class aimet_common.defs.CallbackFunc(func, func_callback_args=None)

Class encapsulating a callback function and its arguments

Parameters:
  • func (Callable) – Callable Function

  • func_callback_args – Arguments passed to the callable function

class aimet_torch.amp.mixed_precision_algo.EvalCallbackFactory(data_loader, forward_fn=None)

Factory class for various built-in eval callbacks

Parameters:
  • data_loader (DataLoader) – Data loader to be used for evaluation

  • forward_fn (Optional[Callable[[Module, Any], Tensor]]) – Function that runs forward pass and returns the output tensor. This function is expected to take 1) a model and 2) a single batch yielded from the data loader, and return a single torch.Tensor object which represents the output of the model. The default forward function is roughly equivalent to lambda model, batch: model(batch)
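
For instance, if the data loader yields (images, labels) tuples, a custom forward_fn can unpack the batch before the forward pass (a sketch; the loader's batch shape is an assumption):

# Data loader yields (images, labels); the model only needs the images
factory = EvalCallbackFactory(data_loader,
                              forward_fn=lambda model, batch: model(batch[0]))
sqnr_callback = factory.sqnr(num_samples=128)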

sqnr(num_samples=128)

Returns SQNR eval callback.

Parameters:

num_samples (int) – Number of samples used for evaluation

Return type:

CallbackFunc

Returns:

A callback function that evaluates the input model’s SQNR between fp32 outputs and fake-quantized outputs

Top-level API for Regular AMP (TensorFlow)

Top-level API for Fast AMP (AMP 2.0) (TensorFlow)

Note

To enable phase-3 set the attribute GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True

Currently only two candidates are supported - ((8,int), (8,int)) & ((16,int), (8,int))


Top-level API (ONNX)

Note

It is recommended to use onnx-simplifier before applying mixed-precision.
