AIMET TensorFlow Mixed Precision API
Top-level API for Regular AMP
Top-level API for Fast AMP (AMP 2.0)
Note: To enable phase-3, set the attribute GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True.
Currently, only two candidates are supported: ((8, int), (8, int)) and ((16, int), (8, int)).
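As a minimal sketch, phase-3 could be enabled like this before invoking the top-level API. The GreedyMixedPrecisionAlgo import path shown here is an assumption and may differ between AIMET releases:

# Assumed import path; adjust to your AIMET release if it differs
from aimet_tensorflow.keras.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo
from aimet_common.defs import QuantizationDataType

# Enable phase-3 (convert-op reduction)
GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True

# Only these two candidates are supported when phase-3 is enabled
candidates = [((8, QuantizationDataType.int), (8, QuantizationDataType.int)),
              ((16, QuantizationDataType.int), (8, QuantizationDataType.int))]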
Quantizer Groups definition
CallbackFunc Definition
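CallbackFunc pairs a callable with the arguments it should be invoked with; both the eval callback and the forward-pass callback in the code examples below are built this way. A minimal usage sketch, assuming the eval_func defined later in this section:

from aimet_common.defs import CallbackFunc

# Evaluate on the full dataset (None means use all iterations)
eval_callback = CallbackFunc(eval_func, None)
# Forward pass over a limited number of samples, e.g. 500
forward_pass_callback = CallbackFunc(eval_func, 500)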
Code Examples
Required imports
""" Keras Mixed precision code example to be used for documentation generation. """
# Start of import statements
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import random
import numpy as np
# imports specific to resnet50 pretrained model
from tensorflow.keras.applications.resnet import ResNet50, preprocess_input, decode_predictions
# AIMET imports
# Note: module paths follow the aimet_tensorflow.keras API and may vary slightly between AIMET releases
from aimet_common.defs import CallbackFunc, QuantizationDataType
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms
from aimet_tensorflow.keras.mixed_precision import choose_mixed_precision, choose_fast_mixed_precision
from aimet_tensorflow.keras.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo
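The examples below also call a center_crop helper that is not shown in this section. A minimal, hypothetical sketch of such a helper (assuming a 224x224 crop, the input resolution ResNet50 expects):

def center_crop(images, crop_size=224):
    """ Hypothetical helper: center-crop a batch of images to crop_size x crop_size. """
    height, width = images.shape[1], images.shape[2]
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return images[:, top:top + crop_size, left:left + crop_size, :]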
Load ResNet50 model
def get_model():
    """ Helper function to return the pretrained ResNet50 model """
    model = ResNet50(input_shape=None,
                     include_top=True,
                     weights="imagenet",
                     input_tensor=None,
                     pooling=None,
                     classes=1000)
    return model
Eval function
def get_eval_func(dataset_dir, batch_size, num_iterations=50000):
    """
    Helper function that returns an evaluation function which performs the forward pass on the specified model
    with the given dataset parameters.
    :param dataset_dir: Directory from which the dataset images need to be loaded.
    :param batch_size: Batch size to be used in the dataloader.
    :param num_iterations: Optional parameter stating the total number of images to be used.
                           Defaults to 50000, which is the size of the ImageNet validation set.
    :return: An evaluation function which can be used to evaluate the model's accuracy on the given dataset.
    """
    def func_wrapper(model, iterations):
        """ Evaluation function returned by the parent function. Performs the forward pass on the model
        with the given dataset and returns the accuracy. """
        validation_ds = tf.keras.preprocessing.image_dataset_from_directory(
            directory=dataset_dir,
            labels='inferred',
            label_mode='categorical',
            batch_size=batch_size,
            shuffle=False)
        # If no iteration count is specified, evaluate on the full validation set
        if not iterations:
            iterations = num_iterations
        else:
            iterations = iterations * batch_size
        top1 = 0
        total = 0
        for (img, label) in validation_ds:
            img = center_crop(img)
            x = preprocess_input(img)
            preds = model.predict(x, batch_size=batch_size)
            # Convert one-hot labels into class-name strings using the dataset's class names
            label = np.where(label)[1]
            label = [validation_ds.class_names[int(i)] for i in label]
            # Count top-1 matches against the decoded predictions
            cnt = sum(1 for a, b in zip(label, decode_predictions(preds, top=1)) if str(a) == b[0][0])
            top1 += cnt
            total += len(label)
            if total >= iterations:
                break
        return top1 / total
    return func_wrapper
Data Loader Wrapper function
def get_data_loader_wrapper(dataset_dir, batch_size, is_training=False):
    """
    Helper function which returns a wrapper that, when called, creates a data loader.
    :param dataset_dir: Directory from which the dataset images need to be loaded.
    :param batch_size: Batch size to be used in the dataloader.
    :param is_training: Defaults to False. Used to set the shuffle flag for the data loader.
    :return: A wrapper function which will return a dataloader.
    """
    def dataloader_wrapper():
        dataloader = tf.keras.preprocessing.image_dataset_from_directory(
            directory=dataset_dir,
            labels='inferred',
            label_mode='categorical',
            batch_size=batch_size,
            shuffle=is_training,
            image_size=(256, 256))
        # Apply the same center crop and preprocessing used in the evaluation function
        return dataloader.map(lambda x, y: preprocess_input(center_crop(x)))
    return dataloader_wrapper
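Both code examples below call get_quantizated_model, which is not shown in this section. A minimal sketch of what such a helper might look like, assuming QuantizationSimModel and QuantScheme are imported from the AIMET Keras API (the quant scheme and bitwidths are illustrative):

from aimet_common.defs import QuantScheme
from aimet_tensorflow.keras.quantsim import QuantizationSimModel

def get_quantizated_model(model, eval_func):
    """ Illustrative sketch: create a QuantizationSimModel and compute encodings """
    sim = QuantizationSimModel(model=model,
                               quant_scheme=QuantScheme.post_training_tf_enhanced,
                               rounding_mode='nearest',
                               default_output_bw=8,
                               default_param_bw=8)
    # Use a limited number of samples from the evaluation function to compute encodings
    sim.compute_encodings(eval_func, forward_pass_callback_args=500)
    return sim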
Quantization with regular mixed precision
def mixed_precision(dataset_dir):
    """
    Sample function which demonstrates quantization of a ResNet50 model followed by mixed precision
    """
    np.random.seed(1)
    random.seed(1)
    tf.random.set_seed(1)
    batch_size = 32

    # Load the model
    model = get_model()

    # Perform batch norm folding
    _, model = fold_all_batch_norms(model)

    # Get the evaluation function.
    # We will use this function for the forward pass callback as well.
    eval_func = get_eval_func(dataset_dir, batch_size)

    # Calculate the original model's accuracy
    org_top1 = eval_func(model, None)
    print("Original Model Accuracy: ", org_top1)

    # Get the quantized model object
    sim = get_quantizated_model(model, eval_func)

    # Set the candidates for the mixed precision algorithm
    # Candidate format:
    # ((activation bitwidth, activation data type), (param bitwidth, param data type))
    # e.g. ((16, QuantizationDataType.int), (16, QuantizationDataType.int))
    candidate = [((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
                 ((8, QuantizationDataType.int), (8, QuantizationDataType.int))]

    # The allowed accuracy drop is the amount of accuracy we are willing to trade
    # for a lower-precision, faster model.
    # 0.09 means we accept up to a 9% accuracy drop from the baseline.
    allowed_accuracy_drop = 0.09

    eval_callback = CallbackFunc(eval_func, None)
    forward_pass_call_back = CallbackFunc(eval_func, 500)

    # Enable phase-3 (optional)
    # GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True
    # Note: supported candidates are ((8, int), (8, int)) and ((16, int), (8, int))

    # Call the mixed precision wrapper with appropriate parameters
    choose_mixed_precision(sim, candidate, eval_callback, eval_callback, allowed_accuracy_drop,
                           "./cmp_res", True, forward_pass_call_back)

    print("Mixed Precision Model Accuracy: ", eval_func(sim.model, None))
    sim.export(filename_prefix='mixed_precision_quant_model', path='.')
Quantization with fast mixed precision
def fast_mixed_precision(dataset_dir):
    """
    Sample function which demonstrates quantization of a ResNet50 model followed by mixed precision using AMP 2.0
    """
    np.random.seed(1)
    random.seed(1)
    tf.random.set_seed(1)
    batch_size = 32

    # Load the model
    model = get_model()

    # Perform batch norm folding
    _, model = fold_all_batch_norms(model)

    # Get the evaluation function.
    # We will use this function for the forward pass callback as well.
    eval_func = get_eval_func(dataset_dir, batch_size)

    # Calculate the original model's accuracy
    org_top1 = eval_func(model, None)
    print("Original Model Accuracy: ", org_top1)

    # Get the quantized model object
    sim = get_quantizated_model(model, eval_func)

    # Set the candidates for the mixed precision algorithm
    # Candidate format:
    # ((activation bitwidth, activation data type), (param bitwidth, param data type))
    # e.g. ((16, QuantizationDataType.int), (16, QuantizationDataType.int))
    candidate = [((16, QuantizationDataType.int), (8, QuantizationDataType.int)),
                 ((8, QuantizationDataType.int), (8, QuantizationDataType.int))]

    # The allowed accuracy drop is the amount of accuracy we are willing to trade
    # for a lower-precision, faster model.
    # 0.09 means we accept up to a 9% accuracy drop from the baseline.
    allowed_accuracy_drop = 0.09

    data_loader_wrapper = get_data_loader_wrapper(dataset_dir, batch_size)
    eval_callback = CallbackFunc(eval_func, None)
    forward_pass_call_back = CallbackFunc(eval_func, 500)

    # Enable phase-3 (optional)
    # GreedyMixedPrecisionAlgo.ENABLE_CONVERT_OP_REDUCTION = True
    # Note: supported candidates are ((8, int), (8, int)) and ((16, int), (8, int))

    # Call the fast mixed precision wrapper with appropriate parameters
    choose_fast_mixed_precision(sim, candidate, data_loader_wrapper, eval_callback, allowed_accuracy_drop,
                                "./cmp_res", True, forward_pass_call_back)

    print("Mixed Precision Model Accuracy: ", eval_func(sim.model, None))
    sim.export(filename_prefix='mixed_precision_quant_model', path='.')
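To run either example end to end, point it at an ImageNet-style directory of class subfolders. The dataset path below is illustrative:

if __name__ == '__main__':
    # Hypothetical dataset location; replace with the actual ImageNet validation directory
    DATASET_DIR = '/path/to/imagenet/validation'
    mixed_precision(DATASET_DIR)
    # Or, to use the faster AMP 2.0 flow:
    # fast_mixed_precision(DATASET_DIR)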