AIMET TensorFlow AutoQuant API¶
Examples Notebook Link¶
For an end-to-end notebook showing how to use TensorFlow AutoQuant, please see here.
Top-level API¶
class aimet_tensorflow.auto_quant.AutoQuant(allowed_accuracy_drop, unlabeled_dataset, eval_callback, default_param_bw=8, default_output_bw=8, default_quant_scheme=QuantScheme.post_training_tf_enhanced, default_rounding_mode='nearest', default_config_file=None)[source]¶

Integrate and apply post-training quantization techniques.
AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization, and 3) Adaround. These techniques will be applied in a best-effort manner until the model meets the evaluation goal given as allowed_accuracy_drop.
- Parameters
  - allowed_accuracy_drop (float) – Maximum allowed accuracy drop.
  - unlabeled_dataset (DatasetV1) – An unlabeled dataset for encoding computation. By default, this dataset will also be used for Adaround unless otherwise specified via self.set_adaround_params.
  - eval_callback (Callable[[Session, Optional[int]], float]) – A function that maps a tf session and the number of samples to the evaluation score. This callback is expected to return a scalar value representing the model performance evaluated against exactly N samples, where N is the number of samples passed as the second argument of this callback. NOTE: If N is None, the model is expected to be evaluated against the whole evaluation dataset.
  - default_param_bw (int) – Default bitwidth (4-31) to use for quantizing layer parameters.
  - default_output_bw (int) – Default bitwidth (4-31) to use for quantizing layer inputs and outputs.
  - default_quant_scheme (QuantScheme) – Quantization scheme. Supported values are QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced.
  - default_rounding_mode (str) – Rounding mode. Supported options are 'nearest' or 'stochastic'.
  - default_config_file (Optional[str]) – Path to configuration file for model quantizers.
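For illustration, here is a minimal construction sketch with the defaults spelled out; unlabeled_dataset and eval_callback stand for the user-supplied objects defined in the Code Examples section below.

from aimet_common.defs import QuantScheme
from aimet_tensorflow.auto_quant import AutoQuant

# Sketch only: `unlabeled_dataset` and `eval_callback` are user-supplied
# (see the Code Examples section below for concrete definitions).
auto_quant = AutoQuant(allowed_accuracy_drop=0.01,
                       unlabeled_dataset=unlabeled_dataset,
                       eval_callback=eval_callback,
                       default_param_bw=8,
                       default_output_bw=8,
                       default_quant_scheme=QuantScheme.post_training_tf_enhanced,
                       default_rounding_mode='nearest')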
apply(fp32_sess, starting_op_names, output_op_names, results_dir='/tmp', cache_id=None)[source]¶

Apply post-training quantization techniques.
- Parameters
  - fp32_sess (Session) – tf.Session associated with the model to apply PTQ techniques.
  - starting_op_names (List[str]) – List of starting op names of the model.
  - output_op_names (List[str]) – List of output op names of the model.
  - results_dir (str) – Directory to save the results.
- Return type
  Tuple[Session, float, str]
- Returns
  Tuple of (best session, eval score, encoding path).
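A sketch of a typical call; the op names below are placeholders, and the complete flow appears in the Code Examples section:

# Sketch only: op names are placeholders for the model's actual
# input and output op names.
sess, accuracy, encoding_path = auto_quant.apply(fp32_sess,
                                                 starting_op_names=['input_1'],
                                                 output_op_names=['probs/Softmax'])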
set_adaround_params(adaround_params)[source]¶

Set Adaround parameters. If this method is not called explicitly by the user, AutoQuant will use unlabeled_dataset (passed to __init__) for Adaround.

- Parameters
  adaround_params (AdaroundParameters) – Adaround parameters.
- Return type
  None
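For example, a sketch that points Adaround at a dedicated dataset; adaround_dataset and the batch count are placeholders, and a concrete version appears in the Code Examples section:

from aimet_tensorflow.adaround.adaround_weight import AdaroundParameters

# Sketch only: `adaround_dataset` is a placeholder batched tf.data dataset.
adaround_params = AdaroundParameters(adaround_dataset, num_batches=20)
auto_quant.set_adaround_params(adaround_params)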
Code Examples¶
Required imports
from typing import Optional
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet import ResNet50
from aimet_tensorflow.utils.common import iterate_tf_dataset
from aimet_tensorflow.adaround.adaround_weight import AdaroundParameters
from aimet_tensorflow.auto_quant import AutoQuant
from aimet_tensorflow.utils.graph import update_keras_bn_ops_trainable_flag
tf.compat.v1.disable_eager_execution()
Define constants and helper functions
EVAL_DATASET_SIZE = 5000
CALIBRATION_DATASET_SIZE = 2000
BATCH_SIZE = 100
_sampled_datasets = {}
def _create_sampled_dataset(dataset, num_samples):
    # Reuse the cached dataset if one of this sample size already exists.
    if num_samples in _sampled_datasets:
        return _sampled_datasets[num_samples]

    with dataset._graph.as_default():
        SHUFFLE_BUFFER_SIZE = 300 # NOTE: Adjust the buffer size as necessary.
        SHUFFLE_SEED = 22222
        # Shuffle deterministically, sample num_samples points, and batch them.
        dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE, seed=SHUFFLE_SEED)\
                         .take(num_samples)\
                         .batch(BATCH_SIZE)
        _sampled_datasets[num_samples] = dataset
        return dataset
Prepare model and dataset
input_shape = (224, 224, 3)
num_classes = 1000
model = ResNet50(weights='imagenet', input_shape=input_shape)
model = update_keras_bn_ops_trainable_flag(model, False, load_save_path='./')
input_tensor_name = model.input.name
input_op_name, _ = input_tensor_name.split(":")
output_tensor_name = model.output.name
output_op_name, _ = output_tensor_name.split(":")
# NOTE: In actual use cases, a real dataset should be provided by the user.
images = np.random.rand(100, *input_shape)
labels = np.random.randint(num_classes, size=(100,))
image_dataset = tf.compat.v1.data.Dataset.from_tensor_slices(images)\
                                         .repeat()\
                                         .take(EVAL_DATASET_SIZE)
label_dataset = tf.compat.v1.data.Dataset.from_tensor_slices(labels)\
                                         .repeat()\
                                         .take(EVAL_DATASET_SIZE)
eval_dataset = tf.compat.v1.data.Dataset.zip((image_dataset, label_dataset))
Prepare unlabeled dataset
# NOTE: In actual use cases, users should implement this part to serve
# their own goals if necessary.
unlabeled_dataset = image_dataset.batch(BATCH_SIZE)
Prepare eval callback
# NOTE: In actual use cases, users should implement this part to serve
# their own goals if necessary.
def eval_callback(sess: tf.compat.v1.Session,
                  num_samples: Optional[int] = None) -> float:
    if num_samples is None:
        num_samples = EVAL_DATASET_SIZE

    sampled_dataset = _create_sampled_dataset(eval_dataset, num_samples)

    with sess.graph.as_default():
        sess.run(tf.compat.v1.global_variables_initializer())
        input_tensor = sess.graph.get_tensor_by_name(input_tensor_name)
        output_tensor = sess.graph.get_tensor_by_name(output_tensor_name)

        num_correct_predictions = 0
        for images, labels in iterate_tf_dataset(sampled_dataset):
            prob = sess.run(output_tensor, feed_dict={input_tensor: images})
            predictions = np.argmax(prob, axis=1)
            num_correct_predictions += np.sum(predictions == labels)

    return int(num_correct_predictions) / num_samples
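If a different metric is needed, the callback can be adapted; for instance, a top-5 accuracy variant might look like the following sketch, reusing the same helpers and assumptions as above:

def eval_callback_top5(sess: tf.compat.v1.Session,
                       num_samples: Optional[int] = None) -> float:
    # Sketch only: a top-5 variant of the top-1 callback above.
    if num_samples is None:
        num_samples = EVAL_DATASET_SIZE

    sampled_dataset = _create_sampled_dataset(eval_dataset, num_samples)

    with sess.graph.as_default():
        sess.run(tf.compat.v1.global_variables_initializer())
        input_tensor = sess.graph.get_tensor_by_name(input_tensor_name)
        output_tensor = sess.graph.get_tensor_by_name(output_tensor_name)

        num_correct = 0
        for images, labels in iterate_tf_dataset(sampled_dataset):
            prob = sess.run(output_tensor, feed_dict={input_tensor: images})
            top5 = np.argsort(prob, axis=1)[:, -5:]         # indices of the 5 largest scores
            num_correct += np.sum(top5 == labels[:, None])  # label found among top-5

    return int(num_correct) / num_samples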
Create AutoQuant object
auto_quant = AutoQuant(allowed_accuracy_drop=0.01,
                       unlabeled_dataset=unlabeled_dataset,
                       eval_callback=eval_callback)
(Optional) Set Adaround parameters
Use the following guideline for setting the num_batches parameter: the number of batches determines how much data is used to evaluate the model while calculating the quantization encodings. Typically, AdaRound should use around 2000 samples; for example, with a batch size of 32, set num_batches to 64. If your batch size differs, adjust num_batches accordingly.
ADAROUND_DATASET_SIZE = 2000
adaround_dataset = _create_sampled_dataset(image_dataset, ADAROUND_DATASET_SIZE)
adaround_params = AdaroundParameters(adaround_dataset,
                                     num_batches=ADAROUND_DATASET_SIZE // BATCH_SIZE)
auto_quant.set_adaround_params(adaround_params)
Run AutoQuant
sess, accuracy, encoding_path = \
    auto_quant.apply(tf.compat.v1.keras.backend.get_session(),
                     starting_op_names=[input_op_name],
                     output_op_names=[output_op_name])
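The returned session holds the best quantized model AutoQuant found; a short sketch of consuming the results:

# `sess` is the best quantized tf.Session found, `accuracy` its evaluation
# score, and `encoding_path` the path to the exported encodings file.
print(f"Quantized model accuracy: {accuracy:.4f}")
print(f"Encodings written to: {encoding_path}")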