AIMET TensorFlow Quantization SIM API
Top-level API
class aimet_tensorflow.quantsim.QuantizationSimModel(session, starting_op_names, output_op_names, quant_scheme='tf_enhanced', rounding_mode='nearest', default_output_bw=8, default_param_bw=8, use_cuda=True, config_file=None)
Creates a QuantSim model by adding quantization simulation ops to a given model.
This enables:
- off-target simulation of inference accuracy
- fine-tuning the model to counter the effects of quantization
- Parameters
session (Session) – The input model, as a session, to add quantize ops to
starting_op_names (List[str]) – List of starting op names of the model
output_op_names (List[str]) – List of output op names of the model
quant_scheme (Union[str, QuantScheme]) – Quantization scheme. Currently supported schemes are post_training_tf and post_training_tf_enhanced; defaults to post_training_tf_enhanced
rounding_mode (str) – The rounding scheme to use. One of ‘nearest’ or ‘stochastic’; defaults to ‘nearest’
default_output_bw (int) – Bitwidth to use for activation tensors; defaults to 8
default_param_bw (int) – Bitwidth to use for parameter tensors; defaults to 8
use_cuda (bool) – If True, places quantization ops on GPU; defaults to True
config_file (Optional[str]) – Path to a config file that specifies rules for placing quant ops in the model
- Returns
An object which can be used to perform quantization on a TensorFlow graph
- Raises
ValueError: An error occurred processing one of the input parameters.
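To make "simulation" concrete, the pure-Python sketch below shows what a quantization simulation op conceptually does to a tensor: it quantizes each value to an integer grid and immediately dequantizes it back to a float ("fake quantization"), so the rest of the graph sees the rounding and clamping error. It assumes asymmetric quantization driven by a min/max encoding and round-to-nearest; `compute_scale_offset` and `quantize_dequantize` are hypothetical names, and this is not AIMET's implementation. It also shows how the bitwidth (as set by default_output_bw / default_param_bw) controls the step size.

```python
from typing import List, Sequence, Tuple


def compute_scale_offset(enc_min: float, enc_max: float, bitwidth: int) -> Tuple[float, int]:
    """Derive a scale and integer offset for asymmetric quantization (illustrative)."""
    # Widen the range to include 0.0 so that zero is exactly representable.
    enc_min, enc_max = min(enc_min, 0.0), max(enc_max, 0.0)
    num_steps = 2 ** bitwidth - 1
    scale = (enc_max - enc_min) / num_steps
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero range: any positive scale works
    offset = round(enc_min / scale)  # non-positive integer offset
    return scale, offset


def quantize_dequantize(values: Sequence[float], enc_min: float, enc_max: float,
                        bitwidth: int = 8) -> List[float]:
    """Quantize each value to the integer grid, clamp it, and map it back to float."""
    scale, offset = compute_scale_offset(enc_min, enc_max, bitwidth)
    num_steps = 2 ** bitwidth - 1
    out = []
    for v in values:
        q = round(v / scale) - offset       # integer grid index
        q = max(0, min(num_steps, q))       # clamp to [0, 2^bw - 1]
        out.append((q + offset) * scale)    # dequantize back to float
    return out


values = [-1.0, -0.5, 0.0, 0.3, 1.0]
for bw in (4, 8):
    simulated = quantize_dequantize(values, -1.0, 1.0, bitwidth=bw)
    max_err = max(abs(a - b) for a, b in zip(values, simulated))
    print(f"{bw}-bit: max abs error = {max_err:.4f}")
```

For values inside the encoding range the error is at most half a step, and each extra bit halves the step, which is why lower bitwidths typically cost more accuracy.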
QuantSim simulates the behavior of a quantized model on hardware. It supports configuring the quantization scheme, bitwidths, hardware configuration, and rounding mode, so that different target configurations can be simulated.
Constructor
- Parameters
model - Model to add simulation ops to
input_shapes - List of input shapes to the model
quant_scheme - Quantization scheme. Supported options for Post Training Quantization are ‘tf_enhanced’ or ‘tf’, or the QuantScheme enum values QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced. Supported options for Range Learning are QuantScheme.training_range_learning_with_tf_init or QuantScheme.training_range_learning_with_tf_enhanced_init
rounding_mode - Rounding mode. Supported options are ‘nearest’ or ‘stochastic’
default_output_bw - Default bitwidth (4-31) to use for quantizing layer inputs and outputs
default_param_bw - Default bitwidth (4-31) to use for quantizing layer parameters
in_place - If True, the given ‘model’ is modified in place to add quant-sim nodes. This option is suggested only when the user wants to avoid creating a copy of the model
config_file - Configuration file for model quantizers
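The two rounding modes differ in how a value between two grid points is mapped to an integer. The sketch below illustrates the idea in pure Python; `round_value` is a hypothetical helper, not part of AIMET, and the round-half-up tie-breaking used for ‘nearest’ is an assumption. Stochastic rounding rounds up with probability equal to the fractional part, so it is unbiased on average.

```python
import math
import random
from typing import Optional


def round_value(x: float, rounding_mode: str = "nearest",
                rng: Optional[random.Random] = None) -> int:
    """Round x to an integer grid index using the given rounding mode (illustrative)."""
    if rounding_mode == "nearest":
        return math.floor(x + 0.5)  # round half up (tie-breaking rule is an assumption)
    if rounding_mode == "stochastic":
        if rng is None:
            rng = random.Random()
        lower = math.floor(x)
        # Round up with probability equal to the fractional part, so E[result] == x.
        return lower + (1 if rng.random() < x - lower else 0)
    raise ValueError(f"Unsupported rounding mode: {rounding_mode!r}")


rng = random.Random(0)
samples = [round_value(2.3, "stochastic", rng) for _ in range(10_000)]
print(round_value(2.3), sum(samples) / len(samples))
```

Nearest rounding always maps 2.3 to 2, while the average of many stochastic roundings stays close to 2.3.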
- Note about Quantization Schemes
AIMET offers multiple quantization schemes:
- Post Training Quantization: The encodings of the model are computed using the TF or TF-Enhanced scheme
- Trainable Quantization: The min/max values of the encodings are learned during training
  - Range Learning with TF initialization: Uses the TF scheme to initialize the encodings; these encodings are then fine-tuned during training to improve the accuracy of the model
  - Range Learning with TF-Enhanced initialization: Uses the TF-Enhanced scheme to initialize the encodings; these encodings are then fine-tuned during training to improve the accuracy of the model
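The contrast between a plain min/max initialization and an outlier-aware one can be sketched as follows, under clearly stated simplifications. `tf_encoding` just takes the observed min/max (widened to include zero), which is how a TF-style encoding is commonly described. `searched_encoding` is a hypothetical stand-in that tries shrunken ranges and keeps the one with the lowest quantization mean squared error; AIMET's actual TF-Enhanced algorithm differs in its details. In Range Learning, initial values like these would then be fine-tuned during training.

```python
import math
from typing import Sequence, Tuple


def tf_encoding(samples: Sequence[float]) -> Tuple[float, float]:
    """TF-style encoding (simplified): observed min/max, widened to include 0.0."""
    return min(min(samples), 0.0), max(max(samples), 0.0)


def quantization_mse(samples: Sequence[float], enc_min: float, enc_max: float,
                     bitwidth: int) -> float:
    """Mean squared error introduced by quantize-dequantize on [enc_min, enc_max]."""
    num_steps = 2 ** bitwidth - 1
    scale = (enc_max - enc_min) / num_steps
    if scale == 0.0:
        return sum(x * x for x in samples) / len(samples)
    total = 0.0
    for x in samples:
        q = min(num_steps, max(0, round((x - enc_min) / scale)))  # clamp to the grid
        total += (x - (enc_min + q * scale)) ** 2
    return total / len(samples)


def searched_encoding(samples: Sequence[float], bitwidth: int = 8,
                      num_candidates: int = 50) -> Tuple[float, float]:
    """Hypothetical stand-in for an 'enhanced' scheme: shrink the TF range and keep
    the candidate with the lowest MSE (not AIMET's TF-Enhanced algorithm)."""
    full_min, full_max = tf_encoding(samples)
    best = (full_min, full_max)
    best_err = quantization_mse(samples, full_min, full_max, bitwidth)
    for i in range(1, num_candidates):
        frac = i / num_candidates
        cand_min, cand_max = full_min * frac, full_max * frac
        err = quantization_mse(samples, cand_min, cand_max, bitwidth)
        if err < best_err:
            best, best_err = (cand_min, cand_max), err
    return best


# Mostly small activations plus one large outlier.
samples = [math.sin(i) for i in range(1000)] + [100.0]
print("TF range:      ", tf_encoding(samples))
print("Searched range:", searched_encoding(samples, bitwidth=8))
```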
The following API can be used to compute encodings for the model:
QuantizationSimModel.compute_encodings(forward_pass_callback, forward_pass_callback_args)
Computes encodings for all quantization sim nodes in the model. This is also used to set initial encodings for Range Learning.
- Parameters
forward_pass_callback (Callable[[Session, Any], None]) – A callback function that is expected to run forward passes on a session. This callback function should use representative data for the forward pass, so the calculated encodings work for all data samples. The callback itself chooses the number of data samples it uses for calculating encodings.
forward_pass_callback_args – These argument(s) are passed to forward_pass_callback as-is. It is up to the user to determine the type of this parameter, e.g. simply an integer representing the number of data samples to use, a tuple of parameters, or an object representing something more complex.
- Returns
None
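The calling convention described above can be illustrated without TensorFlow. In the sketch below, `ToySession`, `ToySim`, and `forward_pass` are hypothetical stand-ins: the toy sim hands its session and the user's args to the callback as-is, the callback decides how much data to run, and the sim derives a min/max encoding from the activations it observed. Real QuantSim does considerably more; only the control flow is meant to match.

```python
from typing import Any, Callable, List, Tuple


class ToySession:
    """Hypothetical stand-in for a TF session: a 'model' computing y = 2x
    that records every activation it produces."""

    def __init__(self) -> None:
        self.observed: List[float] = []

    def run(self, batch: List[float]) -> List[float]:
        activations = [2.0 * x for x in batch]
        self.observed.extend(activations)
        return activations


class ToySim:
    """Hypothetical mini quant-sim mimicking the compute_encodings calling convention."""

    def __init__(self) -> None:
        self.session = ToySession()
        self.encoding: Tuple[float, float] = (0.0, 0.0)

    def compute_encodings(self, forward_pass_callback: Callable[[Any, Any], None],
                          forward_pass_callback_args: Any) -> None:
        self.session.observed.clear()
        # The callback drives the forward passes; the args are passed through as-is.
        forward_pass_callback(self.session, forward_pass_callback_args)
        if not self.session.observed:
            raise ValueError("forward_pass_callback did not run any forward passes")
        self.encoding = (min(self.session.observed), max(self.session.observed))


def forward_pass(session: ToySession, num_batches: int) -> None:
    """User callback: here the args value decides how many batches to use."""
    representative_data = [[-1.0, 0.5], [0.25, 1.5], [3.0, -2.0]]
    for batch in representative_data[:num_batches]:
        session.run(batch)


sim = ToySim()
sim.compute_encodings(forward_pass, forward_pass_callback_args=2)
print(sim.encoding)  # -> (-2.0, 3.0)
```

Because the args are opaque to the sim, the same callback could instead take a tuple or a config object, as the parameter description notes.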
The following API can be used to export the model to target:
QuantizationSimModel.export(path, filename_prefix, orig_sess=None)
This method exports the quant-sim model so it is ready to be run on-target.
Specifically, the following are saved:
- The sim-model is exported to a regular TensorFlow meta/checkpoint without any simulation ops
- The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)
- Parameters
path (str) – Path where to store the model and encodings
filename_prefix (str) – Prefix to use for the filenames of the model and encodings files
orig_sess (Optional[Session]) – Optional param to pass in the original session without quant nodes for export
- Returns
None
Encoding format is described in the Quantization Encoding Specification
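For illustration only, the sketch below writes a JSON encodings file with separate activation and parameter sections, one asymmetric encoding per tensor. The helper names, the `.encodings` suffix, the tensor names, and the exact field names are assumptions based on our reading of the encoding format; consult the Quantization Encoding Specification for the authoritative layout.

```python
import json
import os
import tempfile
from typing import Dict, Tuple


def make_encoding(enc_min: float, enc_max: float, bitwidth: int = 8) -> dict:
    """Build one asymmetric encoding entry (field names are assumptions)."""
    if enc_max <= enc_min:
        raise ValueError("enc_max must be greater than enc_min")
    scale = (enc_max - enc_min) / (2 ** bitwidth - 1)
    offset = round(enc_min / scale)
    return {"bitwidth": bitwidth, "is_symmetric": "False", "max": enc_max,
            "min": enc_min, "offset": offset, "scale": scale}


def write_encodings(path: str, filename_prefix: str,
                    activation_ranges: Dict[str, Tuple[float, float]],
                    param_ranges: Dict[str, Tuple[float, float]],
                    bitwidth: int = 8) -> str:
    """Hypothetical helper: write a JSON encodings file and return its path."""
    encodings = {
        "activation_encodings": {name: [make_encoding(lo, hi, bitwidth)]
                                 for name, (lo, hi) in activation_ranges.items()},
        "param_encodings": {name: [make_encoding(lo, hi, bitwidth)]
                            for name, (lo, hi) in param_ranges.items()},
    }
    out_file = os.path.join(path, filename_prefix + ".encodings")
    with open(out_file, "w") as f:
        json.dump(encodings, f, indent=2)
    return out_file


with tempfile.TemporaryDirectory() as tmp_dir:
    out = write_encodings(tmp_dir, "mnist_quantized",
                          activation_ranges={"dense_1/BiasAdd:0": (-4.0, 6.0)},
                          param_ranges={"dense_1/kernel:0": (-1.0, 1.0)})
    with open(out) as f:
        loaded = json.load(f)
    print(json.dumps(loaded["param_encodings"], indent=2))
```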
Code Example #1 Post Training Quantization
Required imports
```python
import tensorflow as tf

# Import the TensorFlow quantsim
from aimet_tensorflow import quantsim
from aimet_tensorflow.common import graph_eval
from aimet_tensorflow.utils import graph_saver
from aimet_common.defs import QuantScheme
from tensorflow.examples.tutorials.mnist import input_data
```
Quantize and fine-tune a trained model
```python
def quantize_model(generator):
    tf.compat.v1.reset_default_graph()

    # Load the trained graph
    sess = graph_saver.load_model_from_meta('models/mnist_save.meta', 'models/mnist_save')

    def forward_callback(session, iterations):
        # Run forward passes over representative data so encodings can be computed
        graph_eval.evaluate_graph(session, generator, ['accuracy'], graph_eval.default_eval_func, iterations)

    # Create quantsim model to quantize the network using the default 8-bit params/activations
    sim = quantsim.QuantizationSimModel(sess, starting_op_names=['reshape_input'], output_op_names=['dense_1/BiasAdd'],
                                        quant_scheme=QuantScheme.post_training_tf_enhanced,
                                        config_file='../../../TrainingExtensions/common/src/python/aimet_common/'
                                                    'quantsim_config/default_config.json')

    # Compute encodings
    sim.compute_encodings(forward_callback, forward_pass_callback_args=1)

    # Do some fine-tuning
    training_helper(sim, generator)
```
Example Fine-tuning step
```python
def training_helper(sim, generator):
    """A helper function to fine-tune the MNIST model"""
    g = sim.session.graph
    sess = sim.session
    with g.as_default():
        x = g.get_tensor_by_name("reshape_input:0")
        y = g.get_tensor_by_name("labels:0")
        fc1_w = g.get_tensor_by_name("dense_1/MatMul/ReadVariableOp:0")

        ce = g.get_tensor_by_name("xent:0")
        # Using Adam optimizer
        train_step = tf.compat.v1.train.AdamOptimizer(1e-3, name="TempAdam").minimize(ce)
        graph_eval.initialize_uninitialized_vars(sess)

        # Input data for MNIST
        mnist = input_data.read_data_sets('./data', one_hot=True)

        # Using 100 iterations and batch of size 50
        for i in range(100):
            batch = mnist.train.next_batch(50)
            sess.run([train_step, fc1_w], feed_dict={x: batch[0], y: batch[1]})
            if i % 10 == 0:
                # Find accuracy of model every 10 iterations
                perf = graph_eval.evaluate_graph(sess, generator, ['accuracy'], graph_eval.default_eval_func, 1)
                print('Quantized performance: ' + str(perf * 100))

    # Close session
    sess.close()
```
Code Example #2 Trainable Quantization
Required imports
```python
import os
import tensorflow as tf

# Import the TensorFlow quantsim
from aimet_tensorflow import quantsim
from aimet_tensorflow.common import tfrecord_generator as tf_gen
from aimet_tensorflow.common import graph_eval
from aimet_tensorflow.utils import graph_saver
from aimet_common.defs import QuantScheme
from tensorflow.examples.tutorials.mnist import input_data
```
Evaluation function to be used for computing initial encodings
```python
def evaluate_model(sess: tf.compat.v1.Session, eval_iterations: int, use_cuda: bool) -> float:
    """
    This is intended to be the user-defined model evaluation function.
    AIMET requires the above signature. So if the user's eval function does not
    match this signature, please create a simple wrapper.

    Note: Honoring the number of iterations is not absolutely necessary.
    However, if all evaluations run over an entire epoch of validation data,
    the runtime for AIMET compression will be higher.

    :param sess: TensorFlow session
    :param eval_iterations: Number of iterations to use for evaluation.
            None for entire epoch.
    :param use_cuda: If true, evaluate using GPU acceleration
    :return: single float number (accuracy) representing model's performance
    """
    # evaluate_model should run data through the model and return an accuracy score.
    # If the model does not have nodes to measure accuracy, they will need to be added to the graph.
    return .5
```
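Example #2 passes a forward_pass callback into quantization_aware_training_range_learning without showing how it is built. Since compute_encodings calls forward_pass_callback with the session and forward_pass_callback_args, one possible approach (an assumption, not a prescribed AIMET pattern) is to bind use_cuda with functools.partial, so that the callback args become eval_iterations. The sketch uses a hypothetical `dummy_evaluate_model` with the same signature so it runs without TensorFlow; in real code you would bind evaluate_model instead.

```python
import functools
from typing import Any, List, Tuple

calls: List[Tuple[Any, Any, bool]] = []


def dummy_evaluate_model(sess: Any, eval_iterations: Any, use_cuda: bool) -> float:
    """Hypothetical stand-in with the same signature as evaluate_model above."""
    calls.append((sess, eval_iterations, use_cuda))
    return .5


# Binding use_cuda leaves exactly (session, eval_iterations) as positional arguments,
# matching the (session, forward_pass_callback_args) call made by compute_encodings.
forward_pass = functools.partial(dummy_evaluate_model, use_cuda=True)

# Simulate the call compute_encodings would make with forward_pass_callback_args=1.
forward_pass("fake-session", 1)
print(calls)
```

The accuracy value returned by the evaluation function is not needed for computing encodings; the forward passes themselves are what matter.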
Quantize and fine-tune a trained model to learn min max ranges
```python
def quantization_aware_training_range_learning(forward_pass):
    """
    Running Quantize Range Learning Test
    """
    tf.compat.v1.reset_default_graph()

    # Allocate the generator you wish to use to provide the network with data
    parser2 = tf_gen.MnistParser(batch_size=100, data_inputs=['reshape_input'])
    generator = tf_gen.TfRecordGenerator(tfrecords=[os.path.join('data', 'mnist', 'validation.tfrecords')],
                                         parser=parser2)

    sess = graph_saver.load_model_from_meta('models/mnist_save.meta', 'models/mnist_save')

    # Create quantsim model to quantize the network using the default 8-bit params/activations
    # quant scheme set to range learning
    sim = quantsim.QuantizationSimModel(sess, ['reshape_input'], ['dense_1/BiasAdd'],
                                        quant_scheme=QuantScheme.training_range_learning_with_tf_init)

    # Initialize the model with encodings
    sim.compute_encodings(forward_pass, forward_pass_callback_args=1)

    # Train the model to fine-tune the encodings
    g = sim.session.graph
    sess = sim.session

    with g.as_default():
        parser2 = tf_gen.MnistParser(batch_size=100, data_inputs=['reshape_input'])
        generator2 = tf_gen.TfRecordGenerator(tfrecords=['data/mnist/validation.tfrecords'], parser=parser2)

        # Do training: learn weights and quantization ranges (encodings) simultaneously
        x = g.get_tensor_by_name("reshape_input:0")
        y = g.get_tensor_by_name("labels:0")
        fc1_w = g.get_tensor_by_name("dense_1/MatMul/ReadVariableOp:0")

        perf = graph_eval.evaluate_graph(sess, generator2, ['accuracy'], graph_eval.default_eval_func, 1)
        print('Quantized performance: ' + str(perf * 100))

        ce = g.get_tensor_by_name("xent:0")
        train_step = tf.compat.v1.train.AdamOptimizer(1e-3, name="TempAdam").minimize(ce)
        graph_eval.initialize_uninitialized_vars(sess)
        mnist = input_data.read_data_sets('./data', one_hot=True)

        for i in range(100):
            batch = mnist.train.next_batch(50)
            sess.run([train_step, fc1_w], feed_dict={x: batch[0], y: batch[1]})
            if i % 10 == 0:
                perf = graph_eval.evaluate_graph(sess, generator2, ['accuracy'], graph_eval.default_eval_func, 1)
                print('Quantized performance: ' + str(perf * 100))

    # Close session
    sess.close()
```