AIMET TensorFlow AdaRound API
Examples Notebook Link
For an end-to-end notebook showing how to use TensorFlow AdaRound, please see here.
Top-level API
aimet_tensorflow.adaround.adaround_weight.Adaround.apply_adaround(session, starting_op_names, output_op_names, params, path, filename_prefix, default_param_bw=4, default_quant_scheme=<QuantScheme.post_training_tf_enhanced: 2>, default_config_file=None)
Returns a TF session containing the model with optimized weight rounding for every Conv and Linear op. It also saves the corresponding quantization encodings to a separate JSON-formatted file, which QuantSim can then import for inference or quantization-aware training (QAT).
- Parameters
  - session (Session) – TF session with the model to AdaRound
  - starting_op_names (List[str]) – List of starting op names of the model
  - output_op_names (List[str]) – List of output op names of the model
  - params (AdaroundParameters) – Parameters for AdaRound
  - path (str) – Path where the parameter encodings are stored
  - filename_prefix (str) – Prefix to use for the filename of the encodings file
  - default_param_bw (int) – Default bitwidth (4-31) to use for quantizing layer parameters. Default 4
  - default_quant_scheme (QuantScheme) – Quantization scheme. Supported options are QuantScheme.post_training_tf and QuantScheme.post_training_tf_enhanced. Default QuantScheme.post_training_tf_enhanced
  - default_config_file (Optional[str]) – Default configuration file for model quantizers
- Return type: Session
- Returns: TF session with AdaRounded weights. The corresponding parameter encodings are saved as a JSON file at the provided path.
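AdaRound decides, for each weight, whether to round up or down instead of always rounding to nearest. The sketch below illustrates this per scalar weight, following the formulation in the AdaRound paper (Nagel et al., 2020): a learnable variable v is mapped through a rectified sigmoid to a soft rounding value h(v) in [0, 1], and after optimization the weight is rounded up exactly when h(v) >= 0.5. It is not AIMET's implementation. The constants ZETA and GAMMA, the symmetric signed integer grid, and all function names are assumptions made for illustration, and AIMET works on whole weight tensors rather than scalars.

```python
import math

# Stretch parameters of the rectified sigmoid, as given in the AdaRound paper (assumed here)
ZETA, GAMMA = 1.1, -0.1


def rectified_sigmoid(v: float) -> float:
    """Map an unconstrained variable v to a soft rounding value h(v) in [0, 1]."""
    s = 1.0 / (1.0 + math.exp(-v))
    return min(1.0, max(0.0, s * (ZETA - GAMMA) + GAMMA))


def _clamp_to_grid(q: float, bitwidth: int) -> float:
    """Clamp to a signed symmetric integer grid (an assumption made for brevity)."""
    q_min, q_max = -(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1
    return min(q_max, max(q_min, q))


def soft_quantize(w: float, scale: float, v: float, bitwidth: int = 4) -> float:
    """Soft quantize-dequantize used while v is optimized (smooth in v except where h is clipped)."""
    return scale * _clamp_to_grid(math.floor(w / scale) + rectified_sigmoid(v), bitwidth)


def hard_quantize(w: float, scale: float, v: float, bitwidth: int = 4) -> float:
    """Final AdaRounded weight: round up if h(v) >= 0.5, otherwise round down."""
    round_up = 1 if rectified_sigmoid(v) >= 0.5 else 0
    return scale * _clamp_to_grid(math.floor(w / scale) + round_up, bitwidth)
```

With nearest rounding, w = 0.34 at scale 0.1 always becomes 0.3. In this sketch, hard_quantize returns 0.4 instead when the learned v is large (for example v = 5.0), which is the kind of choice AdaRound makes when it lowers the layer's reconstruction error.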
Adaround Parameters
class aimet_tensorflow.adaround.adaround_weight.AdaroundParameters(data_set, num_batches, default_num_iterations=10000, default_reg_param=0.01, default_beta_range=(20, 2), default_warm_start=0.2)
Configuration parameters for AdaRound.
- Parameters
  - data_set (DatasetV2) – TF dataset
  - num_batches (int) – Number of batches of data_set to use
  - default_num_iterations (int) – Number of iterations to AdaRound each layer. Default 10000
  - default_reg_param (float) – Regularization parameter, trading off the rounding loss against the reconstruction loss. Default 0.01
  - default_beta_range (Tuple) – Start and stop beta parameters for annealing the rounding loss, given as (start_beta, end_beta). Default (20, 2)
  - default_warm_start (float) – Warm-up period, as a fraction of the iterations, during which the rounding loss has no effect. Default 0.2 (20%)
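default_reg_param, default_beta_range, and default_warm_start together shape the rounding loss that AdaRound adds to the reconstruction loss. The sketch below uses the regularizer from the AdaRound paper, reg_param * sum(1 - |2h - 1|^beta), which pushes every soft rounding value h toward 0 or 1, and anneals beta from start_beta to end_beta once the warm-up fraction of iterations has passed. The cosine decay of beta is an assumption (AIMET's exact schedule may differ), and the function names are hypothetical.

```python
import math
from typing import Iterable, Tuple


def annealed_beta(cur_iter: int, num_iterations: int,
                  beta_range: Tuple[float, float] = (20, 2), warm_start: float = 0.2) -> float:
    """Anneal beta from start_beta (end of warm-up) to end_beta (last iteration); cosine decay assumed."""
    start_beta, end_beta = beta_range
    warm_end = warm_start * num_iterations
    if num_iterations <= warm_end:
        # Warm-up covers the whole run, so there is nothing to anneal
        return float(start_beta)
    # Relative progress after warm-up, clamped to [0, 1]
    rel = min(1.0, max(0.0, (cur_iter - warm_end) / (num_iterations - warm_end)))
    return end_beta + 0.5 * (start_beta - end_beta) * (1 + math.cos(rel * math.pi))


def rounding_loss(h_values: Iterable[float], cur_iter: int, num_iterations: int,
                  reg_param: float = 0.01, beta_range: Tuple[float, float] = (20, 2),
                  warm_start: float = 0.2) -> float:
    """reg_param * sum(1 - |2h - 1|^beta); zero during the warm-up period."""
    if cur_iter < warm_start * num_iterations:
        return 0.0
    beta = annealed_beta(cur_iter, num_iterations, beta_range, warm_start)
    # Each term is 0 when h is exactly 0 or 1 and largest (1) when h = 0.5
    return reg_param * sum(1.0 - abs(2.0 * h - 1.0) ** beta for h in h_values)
```

Early on, the large beta (20) makes the regularizer nearly flat for h away from 0 and 1, so the reconstruction loss dominates. As beta falls toward 2, the penalty increasingly pushes every h to 0 or 1, that is, to a definite round-down or round-up decision.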
Enum Definition
Quant Scheme Enum
class aimet_common.defs.QuantScheme
Enumeration of quantization schemes.
- post_training_percentile = 6
  For a tensor, adjusted minimum and maximum values are selected based on the percentile value passed. The quantization encodings are calculated using the adjusted minimum and maximum values.
- post_training_tf = 1
  For a tensor, the absolute minimum and maximum values of the tensor are used to compute the quantization encodings.
- post_training_tf_enhanced = 2
  For a tensor, searches for and selects the minimum and maximum values that minimize the quantization noise. The quantization encodings are calculated using the selected minimum and maximum values.
- training_range_learning_with_tf_enhanced_init = 4
  For a tensor, the encoding values are initialized with the post_training_tf_enhanced scheme. The encodings are then learned during training.
- training_range_learning_with_tf_init = 3
  For a tensor, the encoding values are initialized with the post_training_tf scheme. The encodings are then learned during training.
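post_training_tf and post_training_percentile differ only in how they pick the range before the encoding is computed. The sketch below assumes a nearest-rank percentile for post_training_percentile and a common asymmetric convention for the encoding (scale = (max - min) / (2^bw - 1), with a real value x ≈ scale * (q + offset)). AIMET's exact range adjustment may differ, the search performed by post_training_tf_enhanced is not reproduced, and all function names are hypothetical.

```python
import math
from typing import List, Tuple


def tf_range(values: List[float]) -> Tuple[float, float]:
    """post_training_tf: the absolute minimum and maximum of the tensor."""
    return min(values), max(values)


def percentile_range(values: List[float], percentile: float = 99.0) -> Tuple[float, float]:
    """post_training_percentile (sketch): drop outliers beyond a nearest-rank percentile on each side."""
    ordered = sorted(values)
    n = len(ordered)
    k = math.ceil(percentile * n / 100)  # 1-based nearest rank of the upper percentile
    k = min(max(k, 1), n)
    return ordered[n - k], ordered[k - 1]


def compute_encoding(x_min: float, x_max: float, bitwidth: int = 8) -> Tuple[float, int]:
    """Asymmetric encoding (scale, offset) so that a real value x ≈ scale * (q + offset)."""
    # Widen the range to include 0 so that zero is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    if x_max == x_min:
        # All-zero tensor: avoid a division by zero
        return 1.0, 0
    scale = (x_max - x_min) / (2 ** bitwidth - 1)
    offset = round(x_min / scale)
    return scale, offset
```

For the values 0.0 to 99.0, percentile_range(values, 99.0) returns (1.0, 98.0) while tf_range returns (0.0, 99.0), so the percentile scheme spends its 2^bw levels on a slightly narrower range.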
Code Examples
Required imports
```python
import logging
import numpy as np
import tensorflow as tf

from aimet_common.utils import AimetLogger
from aimet_common.defs import QuantScheme
from aimet_tensorflow.examples.test_models import keras_model
from aimet_tensorflow.quantsim import QuantizationSimModel
from aimet_tensorflow.adaround.adaround_weight import Adaround, AdaroundParameters
```
Evaluation function
```python
def dummy_forward_pass(session: tf.compat.v1.Session, _):
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires this signature, so if the user's eval function does not
    match it, please create a simple wrapper.
    :param session: Session with model to be evaluated
    :param _: These argument(s) are passed to the forward_pass_callback as-is. It is up to
        the user to determine the type of this parameter. For example, it could simply be an integer
        representing the number of data samples to use, a tuple of parameters, or an object
        representing something more complex.
        If set to None, forward_pass_callback will be invoked with no parameters.
    :return: In a real evaluation function, a single float (e.g. accuracy) representing the
        model's performance. This dummy version returns the raw model output instead.
    """
    # Random input batch matching the example model's 16x16x3 input
    input_data = np.random.rand(32, 16, 16, 3)
    input_tensor = session.graph.get_tensor_by_name('conv2d_input:0')
    output_tensor = session.graph.get_tensor_by_name('keras_model/Softmax:0')
    # One forward pass through the model
    output = session.run(output_tensor, feed_dict={input_tensor: input_data})
    return output
```
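The docstring above asks for a simple wrapper when a user's evaluation function has a different signature. One possible sketch, assuming a hypothetical user function of the form eval_fn(session, *args): the wrapper accepts the (session, args) pair used by dummy_forward_pass and unpacks a tuple of args into positional arguments. make_forward_pass_callback and my_evaluate are hypothetical names, not AIMET APIs.

```python
from typing import Any, Callable


def make_forward_pass_callback(eval_fn: Callable[..., Any]) -> Callable[[Any, Any], Any]:
    """Adapt eval_fn(session, *args) to the (session, args) callback signature."""
    def forward_pass(session: Any, args: Any) -> Any:
        if args is None:
            # No extra arguments: call the user's function with the session only
            return eval_fn(session)
        if isinstance(args, tuple):
            # A tuple is unpacked into separate positional arguments
            return eval_fn(session, *args)
        # Any other object is passed through as a single argument
        return eval_fn(session, args)
    return forward_pass
```

For example, with a hypothetical my_evaluate(session, dataset, num_batches), the callback could be passed as sim.compute_encodings(make_forward_pass_callback(my_evaluate), (dataset, 2)). One consequence of this design is that a tuple meant as a single argument must itself be wrapped in a tuple.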
After applying AdaRound to the model, the AdaRounded session is returned and the associated parameter encodings are saved to a file, which QuantSim then loads.
```python
def apply_adaround_example():

    AimetLogger.set_level_for_all_areas(logging.DEBUG)
    tf.compat.v1.reset_default_graph()

    # Build the example Keras model in the default graph and create its variable initializer
    _ = keras_model()
    init = tf.compat.v1.global_variables_initializer()

    # Random data stands in for a representative calibration dataset
    dataset_size = 32
    batch_size = 16
    possible_batches = dataset_size // batch_size
    input_data = np.random.rand(dataset_size, 16, 16, 3)
    dataset = tf.data.Dataset.from_tensor_slices(input_data)
    dataset = dataset.batch(batch_size=batch_size)

    session = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph())
    session.run(init)

    # Only 10 iterations per layer to keep the example fast (the default is 10000)
    params = AdaroundParameters(data_set=dataset, num_batches=possible_batches, default_num_iterations=10)
    starting_op_names = ['conv2d_input']
    output_op_names = ['keras_model/Softmax']

    # W4A8: 4-bit weights (parameters), 8-bit activations (outputs)
    param_bw = 4
    output_bw = 8
    quant_scheme = QuantScheme.post_training_tf_enhanced

    # Returns a session with AdaRounded weights and saves their encodings under path with filename_prefix
    adarounded_session = Adaround.apply_adaround(session, starting_op_names, output_op_names, params, path='./',
                                                 filename_prefix='dummy', default_param_bw=param_bw,
                                                 default_quant_scheme=quant_scheme, default_config_file=None)

    # Create QuantSim using adarounded_session
    sim = QuantizationSimModel(adarounded_session, starting_op_names, output_op_names, quant_scheme,
                               default_output_bw=output_bw, default_param_bw=param_bw, use_cuda=False)

    # Set and freeze the parameter encodings so QuantSim uses the same quantization grid as AdaRound,
    # then compute the remaining encodings
    sim.set_and_freeze_param_encodings(encoding_path='./dummy.encodings')
    sim.compute_encodings(dummy_forward_pass, None)

    session.close()
    adarounded_session.close()
```
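Two pairs of values in the example must stay consistent: the file passed to set_and_freeze_param_encodings must match the path and filename_prefix given to apply_adaround, and num_batches should not exceed the batches the dataset yields. The helpers below are a hypothetical sketch. They assume the '<path>/<filename_prefix>.encodings' naming implied by the example (path='./' and filename_prefix='dummy' are loaded as './dummy.encodings'), and they count only full batches, as possible_batches does.

```python
import os


def adaround_encodings_path(path: str, filename_prefix: str) -> str:
    """Encodings file written by apply_adaround, assuming '<filename_prefix>.encodings' naming."""
    return os.path.join(path, filename_prefix + '.encodings')


def full_batches(dataset_size: int, batch_size: int) -> int:
    """Number of complete batches; tf.data's batch() also emits a final partial batch unless drop_remainder=True."""
    return dataset_size // batch_size
```

In apply_adaround_example, the literal './dummy.encodings' could then be written as adaround_encodings_path('./', 'dummy'), so the two stay in sync if the prefix changes.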