Model Preparer API

The AIMET Keras ModelPreparer API prepares Keras models that are not built with the Keras Functional or Sequential API. Specifically, it targets models created with the Keras subclassing feature. The ModelPreparer API converts such a subclassed model into a Keras Functional API model. This conversion is required because the AIMET Keras Quantization API expects a Keras Functional API model as input.

Users are strongly encouraged to run the AIMET Keras ModelPreparer API first and to use the returned model as the input to all AIMET quantization features. Using the AIMET Keras ModelPreparer API is mandatory if the model is created with the Keras subclassing feature, if any of its submodules are created via subclassing, or if the model uses any custom layers that inherit from the Keras Layer class.
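The conditions above can be roughly checked in code. The following is a sketch under stated assumptions: built-in Keras classes are assumed to be defined in modules whose names start with keras., tf_keras. or tensorflow., and needs_model_preparer is a hypothetical helper, not part of AIMET. It only inspects the class of each layer, so layers from third-party packages are conservatively reported as custom.

```python
import tensorflow as tf

# Assumption: built-in Keras layer/model classes live in modules whose names
# start with one of these prefixes (Keras 2, Keras 3, tf_keras, or the legacy
# tensorflow.python.keras package).
_BUILTIN_MODULE_PREFIXES = ("keras.", "tf_keras.", "tensorflow.")


def _is_builtin(obj) -> bool:
    """True if the class of obj was defined inside Keras/TensorFlow itself."""
    return type(obj).__module__.startswith(_BUILTIN_MODULE_PREFIXES)


def needs_model_preparer(model: tf.keras.Model) -> bool:
    """Hypothetical heuristic: does the model use subclassing anywhere?"""
    if not _is_builtin(model):
        return True  # the model itself is a tf.keras.Model subclass
    for layer in model.layers:
        if not _is_builtin(layer):
            return True  # custom Layer subclass or subclassed submodel
        if isinstance(layer, tf.keras.Model) and needs_model_preparer(layer):
            return True  # a nested Functional/Sequential model contains subclassing
    return False
```

If the helper returns True, the model falls into one of the cases above and should be passed through prepare_model before any AIMET quantization feature is applied.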

Top-level API

aimet_tensorflow.keras.model_preparer.prepare_model(original_model, input_layer=None)

This function prepares a Keras model before it is used with other AIMET features. Specifically, it converts the model into a purely Functional API model and copies over the original model's weights.

Parameters
  • original_model (Model) – The original model to be prepared

  • input_layer (Union[InputLayer, List[InputLayer], None]) – The input layer to be used for the new model. By default, the input layer is set to None. If the beginning portion of the model is subclassed, then the input layer must be passed in.

Returns
  The prepared model if needed, or the original model

Return type
  Model
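To make "convert the model into a purely Functional API model and copy over the weights" concrete, the sketch below rebuilds a tiny subclassed model by hand as a Functional model and copies its weights across. This only illustrates the idea; it is not how prepare_model is implemented, and TwoDense and to_functional are hypothetical names.

```python
import numpy as np
import tensorflow as tf


class TwoDense(tf.keras.Model):
    """Hypothetical subclassed model used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.dense_1 = tf.keras.layers.Dense(8, activation="relu")
        self.dense_2 = tf.keras.layers.Dense(2)

    def call(self, x, **kwargs):
        return self.dense_2(self.dense_1(x))


def to_functional(model: TwoDense, input_shape) -> tf.keras.Model:
    """Hand-written Functional rebuild of TwoDense (not AIMET's code)."""
    inputs = tf.keras.Input(shape=input_shape)
    dense_1 = tf.keras.layers.Dense(8, activation="relu")
    dense_2 = tf.keras.layers.Dense(2)
    functional = tf.keras.Model(inputs=inputs, outputs=dense_2(dense_1(inputs)))
    # Copy the weights of the subclassed layers into the new Functional layers.
    dense_1.set_weights(model.dense_1.get_weights())
    dense_2.set_weights(model.dense_2.get_weights())
    return functional


sample = np.random.rand(4, 16).astype("float32")
subclassed = TwoDense()
subclassed(sample)  # call once so the subclassed model creates its weights
functional = to_functional(subclassed, input_shape=(16,))
print(np.allclose(subclassed(sample), functional(sample), atol=1e-6))  # True
```

Doing this by hand requires knowing the internals of every subclassed layer; the point of prepare_model is to perform this rebuild automatically.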

Code Examples

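Required imports

import numpy as np
import tensorflow as tf
from aimet_tensorflow.keras.model_preparer import prepare_model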

Example 1: Model with Two Subclassed Layers

We begin with a model that has two subclassed layers: TokenAndPositionEmbedding and TransformerBlock. This model is taken from the Keras Transformer text classification example.

class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super(TokenAndPositionEmbedding, self).__init__()
        self.token_emb = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = tf.keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x, **kwargs):
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        x = x + positions
        return x

class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(TransformerBlock, self).__init__()
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential(
            [tf.keras.layers.Dense(ff_dim, activation="relu"), tf.keras.layers.Dense(embed_dim),]
        )
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, inputs, training=None, **kwargs):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

def get_text_classificaiton_model() -> tf.keras.Model:
    vocab_size = 20000
    maxlen = 200

    # Random token ids in [0, vocab_size), used to build the model
    random_input = np.random.randint(0, vocab_size, size=(10, maxlen))

    embed_dim = 32  # Embedding size for each token
    num_heads = 2  # Number of attention heads
    ff_dim = 32  # Hidden layer size in feed forward network inside transformer

    inputs = tf.keras.layers.Input(shape=(maxlen,))
    embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
    x = embedding_layer(inputs)
    transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
    x = transformer_block(x)
    x = tf.keras.layers.GlobalAveragePooling1D()(x)
    x = tf.keras.layers.Dropout(0.1)(x)
    x = tf.keras.layers.Dense(20, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.1)(x)
    outputs = tf.keras.layers.Dense(2, activation="softmax")(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    _ = model(random_input)
    return model

Run the model preparer API by passing in the model.

def model_preparer_two_subclassed_layers() -> tf.keras.Model:
    model = get_text_classificaiton_model()
    model = prepare_model(model)
    return model

The model preparer API will return a Keras Functional API model. We can now use this model as input to the AIMET Keras Quantization API.
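Because prepare_model copies over the original model's weights, a quick sanity check is to compare the outputs of the original and prepared models on the same input before moving on. A minimal sketch, assuming both models accept a training keyword argument; outputs_match is a hypothetical helper, not part of AIMET.

```python
import numpy as np
import tensorflow as tf


def outputs_match(original: tf.keras.Model, prepared: tf.keras.Model,
                  sample, atol: float = 1e-5) -> bool:
    """Hypothetical sanity check: compare inference-mode outputs of two models."""
    # training=False disables the Dropout layers, so both forward passes are deterministic.
    expected = np.asarray(original(sample, training=False))
    actual = np.asarray(prepared(sample, training=False))
    return expected.shape == actual.shape and np.allclose(expected, actual, atol=atol)
```

For Example 1, the original model would come from get_text_classificaiton_model(), the prepared model from prepare_model(original), and a suitable sample from np.random.randint(0, 20000, size=(10, 200)).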

Example 2: Model with Subclassed Layer as First Layer

def get_subclass_model_with_functional_layers() -> tf.keras.Model:
    inputs = tf.keras.Input(shape=(64,))
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(inputs)
    binary_classifier = tf.keras.Model(inputs=inputs, outputs=outputs)

    class MyFunctionalModel(tf.keras.Model):
        def __init__(self):
            super().__init__(name='my_functional_model')
            self.dense = tf.keras.layers.Dense(64, activation="relu")
            self.classifier = binary_classifier

        def call(self, inputs, **kwargs):
            features = self.dense(inputs)
            return self.classifier(features)

    model = MyFunctionalModel()
    return model

Run the model preparer API by passing in the model and an Input layer. This is a case in which the model preparer API requires an Input layer: the model itself is subclassed, so the beginning portion of the model is subclassed.

def model_preparer_subclassed_model_with_functional_layers():
    model = get_subclass_model_with_functional_layers()
    model = prepare_model(model, input_layer=tf.keras.Input(shape=(64,))) # Note: input layer is passed in
    return model

The model preparer API will return a Keras Functional API model. We can now use this model as input to the AIMET Keras Quantization API.

Limitations

The AIMET Keras ModelPreparer API has the following limitations:

  • If the model starts with a subclassed layer, the AIMET Keras ModelPreparer API needs a Keras Input layer as input. This is because the Keras Functional API requires an Input layer as the first layer of the model. The AIMET Keras ModelPreparer API raises an exception if the model starts with a subclassed layer and no Input layer is provided.

  • The AIMET Keras ModelPreparer API can convert subclassed layers that contain arithmetic expressions in their call function. However, the converted model represents these operations as TFOpLambda layers, which are not currently supported by the AIMET Keras Quantization API. Where possible, write the call function of a subclassed layer so that it resembles Keras Functional API code, that is, a sequence of calls to Keras layers.

    For example, if a subclass layer has two convolution layers in its call function, the call function should look like the following:

    def call(self, x, **kwargs):
        x = self.conv_1(x)
        x = self.conv_2(x)
        return x
    
  • Subclassed layers are arbitrary Python code, whereas Functional and Sequential models are static graphs of layers. Because subclassed layers lack this static structure, they can cause problems during model preparation. The model preparer traces the layers defined inside a subclassed layer by passing a Keras symbolic tensor through its call function. If this symbolic tensor does not “touch” every layer defined inside, those layers and their weights can be missing from the prepared model. The first call function below runs into this problem: the symbolic tensor x never flows into the positions variable, because positions is computed only from the constant self.static_patch_count. As a result, the weights of self.pos_emb are missing from the prepared model. In the second call function, maxlen is derived from tf.shape(x), so positions depends on the input and the symbolic tensor passes through every layer, which allows the model preparer to pick up all internal layers and weights:

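    # Problematic: positions never depends on x, so the traced symbolic tensor
    # does not reach self.pos_emb and its weights are lost during preparation.
    def call(self, x, **kwargs):
        positions = tf.range(start=0, limit=self.static_patch_count, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        x = x + positions
        return x

    # Preferred: maxlen is derived from tf.shape(x), so positions (and therefore
    # self.pos_emb) lie on the path of the symbolic input tensor.
    def call(self, x, **kwargs):
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        x = x + positions
        return x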
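As an illustration of the recommendation about arithmetic expressions, the TransformerBlock from Example 1 performs its residual connections with the + operator. The sketch below is a variant that uses tf.keras.layers.Add layers instead, so each residual sum is an ordinary Keras layer call rather than an arithmetic expression. The class name TransformerBlockWithAdd is hypothetical, and whether this removes every TFOpLambda layer for a particular AIMET version is an assumption to verify.

```python
import tensorflow as tf


class TransformerBlockWithAdd(tf.keras.layers.Layer):
    """Hypothetical variant of TransformerBlock whose residual sums use Add layers."""

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential(
            [tf.keras.layers.Dense(ff_dim, activation="relu"), tf.keras.layers.Dense(embed_dim)]
        )
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        # Keras Add layers replace the two `+` operators of the original call().
        self.add1 = tf.keras.layers.Add()
        self.add2 = tf.keras.layers.Add()

    def call(self, inputs, training=None, **kwargs):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(self.add1([inputs, attn_output]))
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(self.add2([out1, ffn_output]))
```

The Add layers work here because both residual operands have the same shape (batch, sequence length, embed_dim).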