AIMET TensorFlow Quantization SIM API

Top-level API

class aimet_tensorflow.quantsim.QuantizationSimModel(session, starting_op_names, output_op_names, quant_scheme='tf_enhanced', rounding_mode='nearest', default_output_bw=8, default_param_bw=8, use_cuda=True, config_file=None, default_data_type=QuantizationDataType.int)

Creates a QuantSim model by adding quantization simulation ops to a given model.

This enables:

  1. off-target simulation of inference accuracy

  2. fine-tuning of the model to counter the effects of quantization

Parameters

  • session (Session) – The input model, as a session, to which quantize ops are added

  • starting_op_names (List[str]) – List of starting op names of the model

  • output_op_names (List[str]) – List of output op names of the model

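  • quant_scheme (Union[str, QuantScheme]) – Quantization scheme. The currently supported post-training schemes are post_training_tf and post_training_tf_enhanced, and the default is post_training_tf_enhanced. The range-learning schemes described in the note below are also selected through this parameter.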

  • rounding_mode (str) – The rounding scheme to use. One of: ‘nearest’ or ‘stochastic’, defaults to ‘nearest’.

  • default_output_bw (int) – bitwidth to use for activation tensors, defaults to 8

  • default_param_bw (int) – bitwidth to use for parameter tensors, defaults to 8

  • use_cuda (bool) – If True, places quantization ops on GPU. Defaults to True

  • config_file (Optional[str]) – Path to a config file that specifies rules for placing quantization ops in the model

  • default_data_type (QuantizationDataType) – Default data type to use for quantizing all layer parameters. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 and default_param_bw=16.


Returns

  An object which can be used to perform quantization on a TensorFlow graph


Raises

  ValueError – An error occurred processing one of the input parameters.
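To make the bitwidth and rounding_mode parameters more concrete, here is a plain-Python sketch of what one simulated quantization op does to a single value: snap it to an integer grid defined by an encoding (min, max, bitwidth), then immediately map it back to float. It assumes a common asymmetric scale/offset convention and is not AIMET's kernel; `quantize_dequantize` and its arguments are hypothetical names.

```python
import math
import random


def quantize_dequantize(x, enc_min, enc_max, bitwidth=8, rounding_mode="nearest", rng=None):
    """Hypothetical helper: simulate quantizing one float value, then dequantize it."""
    num_steps = 2 ** bitwidth - 1               # 255 steps for 8 bits
    scale = (enc_max - enc_min) / num_steps     # float width of one integer step
    offset = round(enc_min / scale)             # integer grid position of enc_min
    clamped = min(max(x, enc_min), enc_max)     # out-of-range values saturate
    real_q = clamped / scale - offset           # position on the 0..num_steps grid
    if rounding_mode == "nearest":
        q = round(real_q)
    elif rounding_mode == "stochastic":
        rng = rng or random.Random()
        low = math.floor(real_q)
        # Round up with probability equal to the fractional part (unbiased on average)
        q = low + (1 if rng.random() < real_q - low else 0)
    else:
        raise ValueError("rounding_mode must be 'nearest' or 'stochastic'")
    q = min(max(q, 0), num_steps)               # keep q a valid unsigned integer
    return (q + offset) * scale                 # dequantize back to float
```

With 16-bit settings the same range is split into 65,535 steps instead of 255, so the worst-case rounding error for in-range values shrinks accordingly.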

Note about Quantization Schemes

AIMET offers multiple quantization schemes:
  1. Post Training Quantization: the encodings of the model are computed using the TF or TF-Enhanced scheme.

  2. Trainable Quantization: the min/max values of the encodings are learned during training.
    • Range Learning with TF initialization: uses the TF scheme to initialize the encodings, which are then fine-tuned during training to improve the accuracy of the model.

    • Range Learning with TF-Enhanced initialization: uses the TF-Enhanced scheme to initialize the encodings, which are then fine-tuned during training to improve the accuracy of the model.
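The TF scheme can be pictured as tracking the minimum and maximum that a tensor takes over the calibration data and turning them into an encoding. The sketch below assumes the range is widened to include 0.0 and the offset is rounded to an integer so that zero lands exactly on the grid, a common convention that may differ in detail from AIMET. TF-Enhanced, which searches for a tighter range, and range learning, which then trains min/max, are not shown. `tf_style_encoding` and its output keys are hypothetical.

```python
def tf_style_encoding(observed_values, bitwidth=8):
    """Hypothetical sketch: derive an asymmetric encoding from observed min/max."""
    num_steps = 2 ** bitwidth - 1
    # Widen the observed range so that it always contains 0.0
    enc_min = min(0.0, min(observed_values))
    enc_max = max(0.0, max(observed_values))
    if enc_max == enc_min:
        enc_max = enc_min + 1e-5            # avoid a zero-width range (arbitrary epsilon)
    scale = (enc_max - enc_min) / num_steps
    offset = round(enc_min / scale)         # integer offset, so 0.0 sits exactly on the grid
    return {
        "bitwidth": bitwidth,
        "scale": scale,
        "offset": offset,
        "min": offset * scale,              # grid-aligned min
        "max": (offset + num_steps) * scale,  # grid-aligned max
    }
```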

The following API can be used to compute encodings for the model:

QuantizationSimModel.compute_encodings(forward_pass_callback, forward_pass_callback_args)

Computes encodings for all quantization sim nodes in the model. This is also used to set initial encodings for Range Learning.

Parameters

  • forward_pass_callback (Callable[[Session, Any], None]) – A callback function that is expected to run forward passes on a session. This callback function should use representative data for the forward pass, so that the calculated encodings work for all data samples. The callback itself chooses how many data samples to use for calculating encodings.

  • forward_pass_callback_args – These argument(s) are passed to the forward_pass_callback as-is. It is up to the user to determine the type of this parameter. For example, it could simply be an integer representing the number of data samples to use, or a tuple of parameters, or an object representing something more complex.
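The calling convention described above can be mimicked without TensorFlow: compute_encodings hands the callback the sim's session plus forward_pass_callback_args, untouched, and ignores any return value. In the sketch below, `FakeSim`, the dict standing in for a session, and the `(num_batches, batch_size)` tuple are all hypothetical; only the shape of the call mirrors the documented `Callable[[Session, Any], None]` signature.

```python
from typing import Any, Callable


class FakeSim:
    """Hypothetical stand-in that mimics only the compute_encodings calling convention."""

    def __init__(self, session: Any):
        self.session = session  # AIMET's sim would hold a tf.Session here

    def compute_encodings(self, forward_pass_callback: Callable[[Any, Any], None],
                          forward_pass_callback_args: Any) -> None:
        # The callback receives the session and the user's args exactly as passed in;
        # a real sim would then derive encodings from the statistics it collected.
        forward_pass_callback(self.session, forward_pass_callback_args)


def forward_pass(session, args):
    """Example callback: here the user chose args to be a (num_batches, batch_size) tuple."""
    num_batches, batch_size = args
    for _ in range(num_batches):
        session["samples_seen"] += batch_size  # stands in for one session.run(...) call


fake_session = {"samples_seen": 0}  # a dict plays the role of the session
sim = FakeSim(fake_session)
sim.compute_encodings(forward_pass, forward_pass_callback_args=(16, 64))
```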



The following API can be used to export the model to target:

QuantizationSimModel.export(path, filename_prefix, orig_sess=None)

This method exports the quant-sim model so that it is ready to be run on-target.

Specifically, the following are saved:

  1. The sim-model is exported to a regular TensorFlow meta/checkpoint without any simulation ops

  2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)

Parameters

  • path (str) – Path where the model and encodings are stored

  • filename_prefix (str) – Prefix to use for the filenames of the model and encodings files

  • orig_sess (Optional[Session]) – Optional original session, without quant nodes, to use for export



The encoding format is described in the Quantization Encoding Specification.
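Because the encodings file is JSON, it can be inspected with the standard library. The sketch assumes top-level `activation_encodings` and `param_encodings` maps from tensor names to lists of dictionaries with `bitwidth`, `min`, `max`, `scale` and `offset` keys; treat that layout as an assumption and check it against the Quantization Encoding Specification for your AIMET version. The helper name, tensor names and numbers in the sample are made up.

```python
import json


def summarize_encodings(encodings: dict) -> dict:
    """Hypothetical helper: map each tensor name to (bitwidth, min, max) of its first encoding."""
    summary = {}
    for section in ("activation_encodings", "param_encodings"):
        for tensor_name, entries in encodings.get(section, {}).items():
            first = entries[0]  # each tensor maps to a list of encoding dicts
            summary[tensor_name] = (first["bitwidth"], first["min"], first["max"])
    return summary


# Made-up sample in the assumed layout; real files come from QuantizationSimModel.export
sample_text = """
{
  "activation_encodings": {
    "dense_1/BiasAdd:0": [{"bitwidth": 8, "min": -2.0, "max": 2.0, "scale": 0.0156862745, "offset": -128}]
  },
  "param_encodings": {
    "dense_1/kernel:0": [{"bitwidth": 8, "min": -0.5, "max": 0.5, "scale": 0.0039215686, "offset": -128}]
  }
}
"""
summary = summarize_encodings(json.loads(sample_text))
```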

Code Examples

Required imports

import os

import tensorflow as tf

# Import the TensorFlow quantsim
from aimet_tensorflow import quantsim
from aimet_tensorflow.common import graph_eval
from aimet_tensorflow.utils import graph_saver
from aimet_common.defs import QuantScheme
from tensorflow.examples.tutorials.mnist import input_data  # available in TensorFlow 1.x only

The user should write this function to pass calibration data

def pass_calibration_data(session: tf.Session, args=None):
    """
    The User of the QuantizationSimModel API is expected to write this function based on their data set.
    This is not a working function and is provided only as a guideline.

    :param session: Model's session
    :param args: forward_pass_callback_args, passed through unchanged by compute_encodings (unused here)
    """

    # User action required
    # Replace the following line with your own dataset's validation data loader
    # (for example, an ImageNet validation data loader).
    data_loader = None  # Your Dataset's data loader

    # User action required
    # For computing the activation encodings, around 1000 unlabelled data samples are required.
    # Edit the following 2 lines based on your dataloader's batch size.
    # batch_size * max_batch_counter should be 1024
    batch_size = 64
    max_batch_counter = 16

    input_tensor = None   # input tensor in session
    train_tensor = None   # train tensor in session
    output_tensor = None  # output tensor in session; fetching it makes the forward pass actually run

    current_batch_counter = 0
    for batch_inputs, _ in data_loader:
        feed_dict = {input_tensor: batch_inputs,
                     train_tensor: False}
        session.run(output_tensor, feed_dict=feed_dict)

        current_batch_counter += 1
        if current_batch_counter == max_batch_counter:
            break
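The batch-counting loop in pass_calibration_data can also be written with `itertools.islice`, which stops after `max_batch_counter` batches without a manual counter. This framework-agnostic sketch replaces `session.run` with a hypothetical `run_batch` callable and uses a fake loader of Python lists, so it shows the control flow only.

```python
from itertools import islice
from typing import Callable, Iterable, Sequence, Tuple


def run_calibration(data_loader: Iterable[Tuple[Sequence[float], object]],
                    run_batch: Callable[[Sequence[float]], None],
                    max_batch_counter: int) -> int:
    """Feed at most max_batch_counter batches to run_batch and return the number of samples used."""
    samples_used = 0
    for batch_inputs, _labels in islice(data_loader, max_batch_counter):  # labels are not needed
        run_batch(batch_inputs)  # stands in for one session.run(...) forward pass
        samples_used += len(batch_inputs)
    return samples_used


batch_size, max_batch_counter = 64, 16
# Hypothetical loader yielding 100 batches of dummy inputs and labels
fake_loader = (([0.0] * batch_size, None) for _ in range(100))
seen_batches = []
used = run_calibration(fake_loader, seen_batches.append, max_batch_counter)
```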

Quantize the model and finetune (QAT)

def quantize_model():
    """
    Create the Quantization Simulation and finetune the model.
    """

    # load graph
    sess = graph_saver.load_model_from_meta('models/mnist_save.meta', 'models/mnist_save')

    # Create quantsim model to quantize the network using the default 8 bit params/activations
    sim = quantsim.QuantizationSimModel(sess, starting_op_names=['reshape_input'], output_op_names=['dense_1/BiasAdd'])

    # Compute encodings
    sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=None)

    # Do some finetuning

    # User action required
    # The following line of code illustrates that the model is getting finetuned.
    # Replace the following train() function with your pipeline's train() function.
    train(sim.session)  # hypothetical user-supplied training function, not part of AIMET
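Fine-tuning under quantization is commonly done with a straight-through estimator (STE): the forward pass uses the quantize-dequantized weight, while the gradient treats rounding as the identity and updates an underlying float weight. The toy below fits a single weight with a hand-written gradient in plain Python; it illustrates the STE idea only, is not the training performed by the `train()` placeholder or by AIMET, and every name and number in it is hypothetical.

```python
def qdq(w, step=0.05):
    """Hypothetical fake-quant: snap w to the nearest multiple of step."""
    return round(w / step) * step


def loss(w, xs, ys):
    """Mean-squared error of the model y = qdq(w) * x."""
    return sum((qdq(w) * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def ste_finetune(w, xs, ys, lr=0.01, steps=300):
    """Gradient descent where d qdq(w) / dw is approximated by 1 (straight-through estimator)."""
    for _ in range(steps):
        q = qdq(w)  # forward pass sees the quantized weight
        grad = sum(2 * (q * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= lr * grad  # ...but the float "shadow" weight is what gets updated
    return w


xs = [1.0, 2.0, 3.0]
ys = [0.3337 * x for x in xs]  # made-up targets
w0 = 0.0
w_final = ste_finetune(w0, xs, ys)
```

With a fixed grid, the float weight ends up hovering near the boundary between the two grid points closest to the float optimum (0.30 and 0.35 here), so the quantized weight may alternate between them; this is a known side effect of STE training.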

Quantize and finetune a trained model to learn the encodings (Range Learning)

def quantization_aware_training_range_learning():
    """
    Running Quantize Range Learning Test
    """

    # Allocate the generator you wish to use to provide the network with data
    # (tf_gen is a TFRecord helper module whose import is not shown in this example)
    parser2 = tf_gen.MnistParser(batch_size=100, data_inputs=['reshape_input'])
    generator = tf_gen.TfRecordGenerator(tfrecords=[os.path.join('data', 'mnist', 'validation.tfrecords')],
                                         parser=parser2)

    sess = graph_saver.load_model_from_meta('models/mnist_save.meta', 'models/mnist_save')

    # Create quantsim model to quantize the network using the default 8 bit params/activations
    # quant scheme set to range learning
    sim = quantsim.QuantizationSimModel(sess, ['reshape_input'], ['dense_1/BiasAdd'],
                                        quant_scheme=QuantScheme.training_range_learning_with_tf_init)

    # Initialize the model with encodings
    sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=None)

    # Train the model to fine-tune the encodings
    g = sim.session.graph
    sess = sim.session

    with g.as_default():

        parser2 = tf_gen.MnistParser(batch_size=100, data_inputs=['reshape_input'])
        generator2 = tf_gen.TfRecordGenerator(tfrecords=['data/mnist/validation.tfrecords'], parser=parser2)
        cross_entropy = g.get_operation_by_name('xent')
        train_step = g.get_operation_by_name("Adam")

        # do training: learn weights and quantization encodings simultaneously
        x = sim.session.graph.get_tensor_by_name("reshape_input:0")
        y = g.get_tensor_by_name("labels:0")
        fc1_w = g.get_tensor_by_name("dense_1/MatMul/ReadVariableOp:0")

        perf = graph_eval.evaluate_graph(sess, generator2, ['accuracy'], graph_eval.default_eval_func, 1)
        print('Quantized performance: ' + str(perf * 100))

        ce = g.get_tensor_by_name("xent:0")
        train_step = tf.train.AdamOptimizer(1e-3, name="TempAdam").minimize(ce)

        # The new optimizer adds variables (slots, beta powers); initialize only those
        uninit = set(name.decode() for name in sess.run(tf.report_uninitialized_variables()))
        sess.run(tf.variables_initializer([v for v in tf.global_variables() if v.op.name in uninit]))

        mnist = input_data.read_data_sets('./data', one_hot=True)

        for i in range(100):
            batch = mnist.train.next_batch(50)
            sess.run([train_step, fc1_w], feed_dict={x: batch[0], y: batch[1]})
            if i % 10 == 0:
                perf = graph_eval.evaluate_graph(sess, generator2, ['accuracy'], graph_eval.default_eval_func, 1)
                print('Quantized performance: ' + str(perf * 100))

    # close session
    sess.close()
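Range learning treats the encoding itself as trainable. As a dependency-free stand-in for backpropagating through the quantizer, the sketch below tunes only the max of a symmetric range, starting from the observed absolute maximum (loosely mirroring TF initialization). It uses finite-difference descent on the quantization mean-squared error and keeps a step only when it lowers that error. This substitutes finite differences, a single parameter and a reconstruction loss for the Adam-on-cross-entropy training shown in the example above; all names and data are hypothetical.

```python
def qdq_symmetric(x, enc_max, bitwidth):
    """Hypothetical symmetric quantize-dequantize over [-enc_max, enc_max]."""
    levels = 2 ** (bitwidth - 1) - 1                 # 7 positive levels for 4 bits
    scale = enc_max / levels
    q = max(-levels, min(levels, round(x / scale)))  # round, then clip to the grid
    return q * scale


def quant_mse(values, enc_max, bitwidth):
    """Mean-squared error introduced by quantizing values with the given range."""
    return sum((v - qdq_symmetric(v, enc_max, bitwidth)) ** 2 for v in values) / len(values)


def learn_max(values, enc_max, bitwidth=4, lr=0.5, eps=1e-3, steps=200):
    """Finite-difference descent on enc_max; a step is kept only if it lowers the MSE."""
    best = quant_mse(values, enc_max, bitwidth)
    for _ in range(steps):
        grad = (quant_mse(values, enc_max + eps, bitwidth)
                - quant_mse(values, enc_max - eps, bitwidth)) / (2 * eps)
        candidate = max(2 * eps, enc_max - lr * grad)  # keep enc_max - eps positive
        candidate_loss = quant_mse(values, candidate, bitwidth)
        if candidate_loss < best:
            enc_max, best = candidate, candidate_loss
        else:
            lr *= 0.5  # overshot or flat: retry with a smaller step
    return enc_max, best


values = [i / 100 for i in range(-100, 101)] + [8.0]  # made-up data with one outlier
init_max = max(abs(v) for v in values)                 # start from the observed range
init_loss = quant_mse(values, init_max, 4)
final_max, final_loss = learn_max(values, init_max, bitwidth=4)
```

The outlier at 8.0 is what makes learning the range worthwhile here: shrinking max clips that single value but gives the many small values a finer grid.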