QuantSim API for TensorFlow
Top-level API
class aimet_tensorflow.quantsim.QuantizationSimModel(session, starting_op_names, output_op_names, quant_scheme='tf_enhanced', rounding_mode='nearest', default_output_bw=8, default_param_bw=8, use_cuda=True, config_file=None)
Creates a QuantSim model by adding quantization simulation ops to a given model.
This enables:
- off-target simulation of inference accuracy
- the model to be fine-tuned to counter the effects of quantization
Parameters
- session (Session) – The input model, as a session, to add quantize ops to
- starting_op_names (List[str]) – List of starting op names of the model
- output_op_names (List[str]) – List of output op names of the model
- quant_scheme (Union[str, QuantScheme]) – Quantization scheme. Currently supported schemes are post_training_tf and post_training_tf_enhanced; defaults to post_training_tf_enhanced
- rounding_mode (str) – The rounding scheme to use. One of 'nearest' or 'stochastic'; defaults to 'nearest'
- default_output_bw (int) – Bitwidth to use for activation tensors; defaults to 8
- default_param_bw (int) – Bitwidth to use for parameter tensors; defaults to 8
- use_cuda (bool) – If True, places quantization ops on the GPU; defaults to True
- config_file (Optional[str]) – Path to a config file that specifies rules for placing quant ops in the model
 
Returns
- An object which can be used to perform quantization on a TensorFlow graph
Raises
- ValueError – An error occurred processing one of the input parameters.
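To give a feel for what a quantization simulation op does to a tensor, the following is a minimal pure-Python sketch of a quantize-dequantize step with 'nearest' and 'stochastic' rounding. It is an illustration only and not AIMET's implementation: the helper name `quantize_dequantize`, the min/max encoding range, and the exact rounding details are assumptions made for this example.

```python
import math
import random


def quantize_dequantize(values, enc_min, enc_max, bitwidth=8, rounding_mode="nearest", rng=None):
    """Hypothetical helper: snap each value onto a uniform grid of 2**bitwidth levels."""
    if rounding_mode not in ("nearest", "stochastic"):
        raise ValueError(f"Unsupported rounding_mode: {rounding_mode}")
    rng = rng if rng is not None else random.Random()
    num_steps = 2 ** bitwidth - 1
    scale = (enc_max - enc_min) / num_steps
    result = []
    for v in values:
        # Fractional position of the clamped value on the integer grid
        pos = (min(max(v, enc_min), enc_max) - enc_min) / scale
        if rounding_mode == "nearest":
            q = math.floor(pos + 0.5)
        else:
            # Round up with probability equal to the fractional part
            q = math.floor(pos) + (1 if rng.random() < pos - math.floor(pos) else 0)
        q = min(max(q, 0), num_steps)
        # Map back to floating point; the result now carries the quantization error
        result.append(q * scale + enc_min)
    return result


x = [-1.0, -0.3, 0.0, 0.42, 1.0]
x_sim = quantize_dequantize(x, enc_min=-1.0, enc_max=1.0, bitwidth=8)
max_err = max(abs(a - b) for a, b in zip(x, x_sim))
print(f"largest quantization error: {max_err:.6f}")
```

With nearest rounding, the error on in-range values is at most half a grid step, which is why a lower bitwidth (fewer, wider steps) tends to cost more accuracy.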
 
The following API can be used to compute encodings for the model:
QuantizationSimModel.compute_encodings(forward_pass_callback, forward_pass_callback_args)
Computes encodings for all quantization sim nodes in the model. This is also used to set initial encodings for Range Learning.
Parameters
- forward_pass_callback (Callable[[Session, Any], None]) – A callback function that is expected to run forward passes on a session. The callback should use representative data for the forward pass, so that the calculated encodings work for all data samples. The callback itself chooses how many data samples to use for calculating encodings.
- forward_pass_callback_args – These argument(s) are passed to the forward_pass_callback as-is. It is up to the user to determine the type of this parameter. For example, it could simply be an integer representing the number of data samples to use, or it could be a tuple of parameters or an object representing something more complex.
 
Returns
- None
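The callback contract described above can be sketched as follows. This is a minimal example under stated assumptions: the factory name `make_forward_pass_callback` and the `batches` iterable are hypothetical, and the tensor names are borrowed from the MNIST example on this page, so they would need to match your own graph. Only `session.graph.get_tensor_by_name` and `session.run` are used on the session.

```python
from itertools import islice


def make_forward_pass_callback(batches, input_name="reshape_input:0",
                               output_name="dense_1/BiasAdd:0"):
    """Hypothetical factory: build a callback with the (session, args) signature."""

    def forward_pass_callback(session, num_batches):
        graph = session.graph
        input_tensor = graph.get_tensor_by_name(input_name)
        output_tensor = graph.get_tensor_by_name(output_name)
        # Run forward passes only; no labels or loss are needed to observe tensor ranges
        for batch in islice(batches, num_batches):
            session.run(output_tensor, feed_dict={input_tensor: batch})

    return forward_pass_callback


# Intended use (not executed here; `sim` and `representative_batches` are assumed to exist):
# callback = make_forward_pass_callback(representative_batches)
# sim.compute_encodings(callback, forward_pass_callback_args=4)
```

Here the second callback argument is an integer batch count, which is one of the options the parameter description allows; a tuple or a richer object would work equally well as long as the callback knows how to interpret it.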
 
The following API can be used to export the model to the target:
QuantizationSimModel.export(path, filename_prefix, orig_sess=None)
This method exports the quant-sim model so it is ready to be run on-target.
Specifically, the following are saved:
- The sim-model is exported to a regular TensorFlow meta/checkpoint without any simulation ops
- The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired)
Parameters
- path (str) – Path where the model checkpoint and encodings are stored
- filename_prefix (str) – Prefix to use for the filenames of the model checkpoint and encodings files
- orig_sess (Optional[Session]) – Optional parameter to pass in the original session, without quant nodes, for export
 
Returns
- None
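Because the encodings are written as JSON, they can be inspected with the standard library after an export. The sketch below is a hedged illustration: the helper names `load_encodings` and `summarize` are hypothetical, the file path in the usage comment is a placeholder (this page does not state the exact file name that export produces), and the code assumes only that the top level of the file is a JSON object.

```python
import json
from pathlib import Path


def load_encodings(encodings_file):
    """Hypothetical helper: read a JSON encodings file into Python objects."""
    with Path(encodings_file).open("r", encoding="utf-8") as f:
        return json.load(f)


def summarize(encodings):
    """Count entries under each top-level key, without assuming a particular schema."""
    summary = {}
    for key, value in encodings.items():
        summary[key] = len(value) if isinstance(value, (dict, list)) else 1
    return summary


# Intended use (placeholder path; adjust to the file written by sim.export):
# sim.export(path='./output', filename_prefix='mnist_quantized')
# print(summarize(load_encodings('./output/<encodings file>')))
```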
 
Code Examples
Required imports
import tensorflow as tf
# Import the TensorFlow quantsim module
from aimet_tensorflow import quantsim
from aimet_tensorflow.common import graph_eval
from aimet_tensorflow.utils import graph_saver
from aimet_common.defs import QuantScheme
from tensorflow.examples.tutorials.mnist import input_data
Quantize with Range Learning
def quantize_model(generator):
    tf.compat.v1.reset_default_graph()
    # load graph
    sess = graph_saver.load_model_from_meta('models/mnist_save.meta', 'models/mnist_save')
    # Callback that runs forward passes on representative data so encodings can be computed
    def forward_callback(session, iterations):
        graph_eval.evaluate_graph(session, generator, ['accuracy'], graph_eval.default_eval_func, iterations)
    # Create quantsim model to quantize the network using the default 8 bit params/activations
    sim = quantsim.QuantizationSimModel(sess, starting_op_names=['reshape_input'], output_op_names=['dense_1/BiasAdd'],
                                        quant_scheme=QuantScheme.post_training_tf_enhanced,
                                        config_file='../../../TrainingExtensions/common/src/python/aimet_common/'
                                                    'quantsim_config/default_config.json')
    # Compute encodings
    sim.compute_encodings(forward_callback, forward_pass_callback_args=1)
    # Do some fine-tuning
    training_helper(sim, generator)
Example Fine-tuning step
def training_helper(sim, generator):
    """A Helper function to fine-tune MNIST model"""
    g = sim.session.graph
    sess = sim.session
    with g.as_default():
        x = g.get_tensor_by_name("reshape_input:0")
        y = g.get_tensor_by_name("labels:0")
        fc1_w = g.get_tensor_by_name("dense_1/MatMul/ReadVariableOp:0")
        ce = g.get_tensor_by_name("xent:0")
        # Using Adam optimizer
        train_step = tf.compat.v1.train.AdamOptimizer(1e-3, name="TempAdam").minimize(ce)
        graph_eval.initialize_uninitialized_vars(sess)
        # Input data for MNIST
        mnist = input_data.read_data_sets('./data', one_hot=True)
        # Using 100 iterations and batch of size 50
        for i in range(100):
            batch = mnist.train.next_batch(50)
            sess.run([train_step, fc1_w], feed_dict={x: batch[0], y: batch[1]})
            if i % 10 == 0:
                # Find accuracy of model every 10 iterations
                perf = graph_eval.evaluate_graph(sess, generator, ['accuracy'], graph_eval.default_eval_func, 1)
                print('Quantized performance: ' + str(perf * 100))
    # close session
    sess.close()
