AIMET Keras Quantization SIM API

Top-level API

class aimet_tensorflow.keras.quantsim.QuantizationSimModel(model, quant_scheme='tf_enhanced', rounding_mode='nearest', default_output_bw=8, default_param_bw=8, in_place=False, config_file=None, default_data_type=QuantizationDataType.int)

Implements a mechanism to add quantization simulation ops to a model. This allows off-target simulation of inference accuracy, and also allows the model to be fine-tuned to counter the effects of quantization.

Parameters:
  • model – Model to quantize

  • quant_scheme (Union[QuantScheme, str]) – Quantization scheme. Currently supported schemes are post_training_tf and post_training_tf_enhanced; defaults to post_training_tf_enhanced

  • rounding_mode (str) – The rounding scheme to use. One of: ‘nearest’ or ‘stochastic’, defaults to ‘nearest’.

  • default_output_bw (int) – Bitwidth to use for activation tensors, defaults to 8

  • default_param_bw (int) – Bitwidth to use for parameter tensors, defaults to 8

  • in_place (bool) – If True, the given ‘model’ is modified in place to add quant-sim nodes. This option is only recommended when the user wants to avoid creating a copy of the model

  • config_file (Optional[str]) – Path to a config file to use to specify rules for placing quant ops in the model

  • default_data_type (QuantizationDataType) – Default data type to use for quantizing all layer parameters. Possible options are QuantizationDataType.int and QuantizationDataType.float. Note that default_data_type=QuantizationDataType.float is only supported with default_output_bw=16 and default_param_bw=16.
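To make the roles of the bitwidth and rounding_mode parameters more concrete, here is a minimal NumPy sketch of the quantize-dequantize step that a simulation op conceptually applies to a tensor. It is not AIMET's kernel: the asymmetric scale/offset convention and the helper name quantize_dequantize are assumptions for illustration, and the encoding range (x_min, x_max) is taken as given rather than computed.

```python
import numpy as np


def quantize_dequantize(x, x_min, x_max, bitwidth=8, rounding_mode="nearest", rng=None):
    """Illustrative quantize-dequantize (hypothetical helper, not AIMET's implementation).

    Maps x onto a (2**bitwidth)-level integer grid spanning roughly [x_min, x_max]
    and back to float, so the returned values carry the quantization error.
    """
    x = np.asarray(x, dtype=np.float64)
    num_steps = 2 ** bitwidth - 1
    scale = (x_max - x_min) / num_steps      # step size between adjacent grid points
    offset = np.round(x_min / scale)         # integer, so 0.0 lands exactly on a grid point (if in range)
    grid = x / scale - offset                # position of x on the integer grid [0, num_steps]
    if rounding_mode == "nearest":
        q = np.round(grid)
    elif rounding_mode == "stochastic":
        rng = np.random.default_rng() if rng is None else rng
        floor = np.floor(grid)
        # Round up with probability equal to the fractional part (unbiased on average)
        q = floor + (rng.random(grid.shape) < (grid - floor))
    else:
        raise ValueError(f"unknown rounding_mode: {rounding_mode!r}")
    q = np.clip(q, 0, num_steps)             # saturate values outside the encoding range
    return (q + offset) * scale


# Usage: compare 8-bit and 16-bit simulation error on the same data
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=1000)
err_8 = np.max(np.abs(quantize_dequantize(x, -1.0, 1.0, bitwidth=8) - x))
err_16 = np.max(np.abs(quantize_dequantize(x, -1.0, 1.0, bitwidth=16) - x))
```

With nearest rounding, the error stays within half a quantization step; stochastic rounding trades a larger per-value error for being unbiased on average.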


The following API can be used to compute encodings for the model:

QuantizationSimModel.compute_encodings(forward_pass_callback, forward_pass_callback_args)

Computes encodings for all quantization sim nodes in the model.

Parameters:
  • forward_pass_callback – A callback function that is expected to run forward passes on the model. This callback function should use representative data for the forward pass, so that the calculated encodings work for all data samples.

  • forward_pass_callback_args – These argument(s) are passed to the forward_pass_callback as-is. It is up to the user to determine the type of this parameter. For example, it could simply be an integer representing the number of data samples to use, a tuple of parameters, or an object representing something more complex.
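Since compute_encodings leaves the content of forward_pass_callback_args to the user, the following NumPy sketch illustrates the general calibration idea under simplifying assumptions. A hypothetical MinMaxObserver stands in for a quantization sim node, a plain Python function stands in for the model, and the callback receives the model plus a user-chosen (data, batch_size) tuple. Deriving the encoding from the observed min/max is roughly the idea behind a min/max scheme such as post_training_tf; the names MinMaxObserver and compute_encoding are hypothetical, and AIMET's actual computation (including post_training_tf_enhanced) is not reproduced here.

```python
import numpy as np


class MinMaxObserver:
    """Hypothetical stand-in for a quantization sim node: records the range of tensors passing through."""

    def __init__(self):
        self.min = np.inf
        self.max = -np.inf

    def __call__(self, x):
        self.min = min(self.min, float(np.min(x)))
        self.max = max(self.max, float(np.max(x)))
        return x  # pass-through: observing does not change the forward pass


def compute_encoding(observer, bitwidth=8):
    """Derive an asymmetric min/max encoding from the observed range (simplified sketch)."""
    lo = min(observer.min, 0.0)  # in this sketch, widen the range so that it always contains 0.0
    hi = max(observer.max, 0.0)
    scale = (hi - lo) / (2 ** bitwidth - 1)
    offset = int(np.round(lo / scale))
    return {"min": lo, "max": hi, "scale": scale, "offset": offset, "bitwidth": bitwidth}


def forward_pass_callback(model, args):
    """Example callback: args is a user-chosen (data, batch_size) tuple, passed through as-is."""
    data, batch_size = args
    for start in range(0, len(data), batch_size):
        model(data[start:start + batch_size])


# Usage with a toy "model" whose output is observed by a single sim node
observer = MinMaxObserver()


def toy_model(batch):
    return observer(batch * 2.0)


calibration_data = np.linspace(-1.0, 3.0, 40).reshape(10, 4)
forward_pass_callback(toy_model, (calibration_data, 3))
encoding = compute_encoding(observer)
```

Because the encoding is computed only from what the callback feeds through the model, calibration data that misses part of the real input distribution yields a range that is too narrow, which is why representative data matters.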


The following API can be used to export the model to target:

QuantizationSimModel.export(path, filename_prefix, custom_objects=None, convert_to_pb=True)

This method exports the quant-sim model so that it is ready to be run on-target. Specifically, the following are saved:
  1. The sim-model is exported to a regular Keras model without any simulation ops.
  2. The quantization encodings are exported to a separate JSON-formatted file that can then be imported by the on-target runtime (if desired).

Parameters:
  • path – Path where the exported model and encodings files are stored

  • filename_prefix – Prefix to use for the filenames of the exported model and encodings files

  • custom_objects – If the model contains custom objects, a dict mapping their names to the corresponding objects, which Keras needs in order to load them


The encoding format is described in the Quantization Encoding Specification.
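After export, the JSON encodings file can be inspected with Python's standard json module. The sketch below assumes a layout with top-level "activation_encodings" and "param_encodings" sections, each mapping a tensor name to a list of encoding dicts that carry a "bitwidth" field. That layout, and the helper names load_encodings and bitwidth_summary, are assumptions to be checked against the Quantization Encoding Specification for your AIMET version.

```python
import json
from typing import Any, Dict


def load_encodings(encodings_path: str) -> Dict[str, Any]:
    """Load an exported encodings file (the filename is an assumption, check the export directory)."""
    with open(encodings_path) as f:
        return json.load(f)


def bitwidth_summary(encodings: Dict[str, Any]) -> Dict[str, Dict[str, int]]:
    """Map each section to {tensor_name: bitwidth}, using the first encoding entry per tensor."""
    summary = {}
    for section in ("activation_encodings", "param_encodings"):
        summary[section] = {}
        for tensor_name, entries in encodings.get(section, {}).items():
            # Assumed: a list of (possibly per-channel) encoding dicts; a single dict is also accepted
            entry = entries[0] if isinstance(entries, list) else entries
            summary[section][tensor_name] = entry["bitwidth"]
    return summary
```

A typical call would be bitwidth_summary(load_encodings(...)) with the path of the encodings file that export wrote under path using filename_prefix; the exact file name and extension are not specified here, so look in the export directory.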


Code Examples

Required imports

import numpy as np
import tensorflow as tf

from aimet_tensorflow.keras import quantsim

Quantize with fine-tuning

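def evaluate(model: tf.keras.Model, forward_pass_callback_args):
    """Forward-pass callback for compute_encodings: runs the model on representative data."""
    dummy_x, dummy_y = forward_pass_callback_args
    model.evaluate(dummy_x, dummy_y)


def quantize_model():
    model = tf.keras.applications.resnet50.ResNet50(weights=None, classes=10)
    sim = quantsim.QuantizationSimModel(model)

    # Generate some dummy data
    dummy_x = np.random.randn(10, 224, 224, 3)
    dummy_y = np.random.randint(0, 10, size=(10,))
    dummy_y = tf.keras.utils.to_categorical(dummy_y, num_classes=10)

    # Compile the sim model (model.evaluate() in the callback requires it), then compute encodings
    sim.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                      loss='categorical_crossentropy', metrics=['accuracy'])
    sim.compute_encodings(evaluate, forward_pass_callback_args=(dummy_x, dummy_y))

    # Do some fine-tuning
    sim.model.fit(x=dummy_x, y=dummy_y, epochs=10)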