Using AIMET Tensorflow APIs with Keras Models

Introduction

Currently, AIMET APIs support TensorFlow sessions. This example shows how to use AIMET with a Keras model: AIMET is invoked on the model's back-end session, and the session it returns is converted back into a Keras model.

APIs

The method involves four steps:

Step 1: Save the session returned by AIMET

aimet_tensorflow.utils.convert_tf_sess_to_keras.save_tf_session_single_gpu(sess, path, input_tensor, output_tensor)

Saves the TF session, its meta graph, and its variables to the provided path

Parameters:
  • sess (Session) – Input: tf.compat.v1.Session

  • path (str) – Path to save the session

  • input_tensor (str) – Name of the starting op of the given graph

  • output_tensor (str) – Name of output op of the graph

Returns:

None
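The parameters above are documented as op names, but the code example below passes tensor names ("input_1:0", "act_softmax/Softmax:0") to this function, while compress_session uses the bare op names ('input_1', 'act_softmax/Softmax'). In TensorFlow a tensor name is the producing op's name followed by ":<output index>", so converting between the two forms is a plain string operation. A minimal sketch; both helper names are hypothetical and not part of AIMET, and which form each AIMET function expects is only confirmed here by the example code:

```python
def tensor_to_op_name(tensor_name: str) -> str:
    """Hypothetical helper: strip the ':<output index>' suffix from a TF tensor name."""
    op_name, sep, index = tensor_name.rpartition(":")
    # Only strip when the suffix is a numeric output index, e.g. "input_1:0".
    if sep and index.isdigit():
        return op_name
    # Already an op name, e.g. "input_1".
    return tensor_name


def op_to_tensor_name(op_name: str, output_index: int = 0) -> str:
    """Hypothetical helper: name of the given output tensor of an op."""
    return f"{op_name}:{output_index}"


print(tensor_to_op_name("act_softmax/Softmax:0"))  # act_softmax/Softmax
print(op_to_tensor_name("input_1"))                # input_1:0
```

Keeping the names in one form and deriving the other avoids mismatches between the compression call and the save call.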


Step 2: Load the saved session into a subclassed Keras model

aimet_tensorflow.utils.convert_tf_sess_to_keras.load_tf_sess_variables_to_keras_single_gpu(path, compressed_ops)

Creates a Keras model subclass and loads the saved session, meta graph, and variables into it

Parameters:
  • path (str) – Path from which to load the TF session saved in Step 1 with save_tf_session_single_gpu

  • compressed_ops (List[str]) – List of op names to skip when creating the Keras model. These are the ops that AIMET compressed and that are isolated from the rest of the graph.

Return type:

Model

Returns:

Subclassed Keras Model
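An incorrect entry in compressed_ops, such as a tensor name like 'conv1/Conv2D:0' instead of the op name 'conv1/Conv2D', is an easy mistake to make. Assuming, as the parameter description suggests, that the compressed ops are isolated but still present in the compressed session's graph, a quick pre-check can compare the list against the graph's op names before calling the loader. missing_compressed_ops is a hypothetical helper, not part of AIMET:

```python
from typing import Iterable, List, Sequence


def missing_compressed_ops(graph_op_names: Iterable[str],
                           compressed_ops: Sequence[str]) -> List[str]:
    """Hypothetical pre-check: return entries of compressed_ops absent from the graph."""
    available = set(graph_op_names)
    return [name for name in compressed_ops if name not in available]


# Toy op names standing in for a real graph's operations.
ops_in_graph = ["input_1", "conv1/Conv2D", "act_softmax/Softmax"]
print(missing_compressed_ops(ops_in_graph, ["conv1/Conv2D"]))    # []
print(missing_compressed_ops(ops_in_graph, ["conv1/Conv2D:0"]))  # ['conv1/Conv2D:0']
```

With a real session, the names could come from TensorFlow's graph API, for example `[op.name for op in compressed_sess.graph.get_operations()]`.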


After these two steps, the model can be used for single-GPU training. For multi-GPU training, the next two steps must also be followed:

Step 3: Save the Keras model from Step 2 in a form compatible with a distribution strategy

aimet_tensorflow.utils.convert_tf_sess_to_keras.save_as_tf_module_multi_gpu(loading_path, saving_path, compressed_ops, input_shape)

Loads a Keras model and re-saves the loaded object as a tf.Module

Parameters:
  • loading_path (str) – Path to load the Keras Model

  • saving_path (str) – Path to save the object

  • compressed_ops (List[str]) – List of op names to skip when creating the Keras model. These are the ops that AIMET compressed and that are isolated from the rest of the graph.

  • input_shape (Tuple) – Shape of the model input; for instance, (224, 224, 3)

Returns:

None
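Step 3 documents input_shape as a Tuple and Step 4 as a List, and the code example below passes (224, 224, 3) and [224, 224, 3] respectively: a per-example, channels-last shape without a batch dimension. A minimal sketch of keeping both calls consistent by validating the shape once and deriving both forms from it; normalize_input_shape is a hypothetical helper, and the three-dimension check is an assumption based only on the example:

```python
from typing import Sequence, Tuple


def normalize_input_shape(shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: validate a per-example (height, width, channels) shape."""
    dims = tuple(shape)
    # Reject shapes that still carry a batch dimension, e.g. (1, 224, 224, 3).
    if len(dims) != 3:
        raise ValueError(
            f"expected (height, width, channels) without a batch dimension, got {dims}")
    if not all(isinstance(d, int) and d > 0 for d in dims):
        raise ValueError(f"dimensions must be positive integers, got {dims}")
    return dims


shape = normalize_input_shape([224, 224, 3])
print(shape)        # (224, 224, 3), e.g. input_shape for Step 3
print(list(shape))  # [224, 224, 3], e.g. input_shape for Step 4
```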


Step 4: Load the re-saved model as a subclassed Keras model

aimet_tensorflow.utils.convert_tf_sess_to_keras.load_keras_model_multi_gpu(loading_path, input_shape)

Loads the Keras model back; the returned model can be used for fine-tuning within a distribution strategy scope

Parameters:
  • loading_path (str) – Path to load the Keras Model

  • input_shape (List) – Shape of the starting tensor in the graph; for instance, (224, 224, 3) for ResNet50 and MobileNetV1

Returns:

Subclassed Keras model
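The four steps can be wired together in a small driver: Steps 1 and 2 always run, and Steps 3 and 4 run only for multi-GPU training, with Step 4 inside the strategy's scope. The sketch below is hypothetical: run_conversion is not an AIMET function, the step functions are passed in so the ordering can be checked with stand-ins, and it omits the `tf.keras.backend.set_learning_phase(1)` calls from the example. Unlike the example, it trains only once (single-GPU or multi-GPU, not both):

```python
import contextlib
from typing import Any, Callable, ContextManager, Sequence


def run_conversion(compressed_sess: Any,
                   compressed_ops: Sequence[str],
                   save_single: Callable[..., None],
                   load_single: Callable[..., Any],
                   save_multi: Callable[..., None],
                   load_multi: Callable[..., Any],
                   train: Callable[[Any], Any],
                   multi_gpu: bool,
                   strategy_scope: Callable[[], ContextManager] = contextlib.nullcontext,
                   reset_backend: Callable[[], None] = lambda: None,
                   input_shape: Sequence[int] = (224, 224, 3)) -> Any:
    """Hypothetical driver for the four steps; the step functions are injected."""
    single_path = "./saved_model_single_gpu"
    multi_path = "./saved_model_multi_gpu"

    # Step 1: save the compressed session (tensor names as in the MobileNet example).
    save_single(compressed_sess, single_path, "input_1:0", "act_softmax/Softmax:0")
    reset_backend()  # mirrors tf.keras.backend.clear_session() in the example
    # Step 2: load it back as a subclassed Keras model.
    model = load_single(single_path, list(compressed_ops))
    if not multi_gpu:
        return train(model)

    # Step 3: re-save from the Step 1 path in a strategy-compatible form.
    save_multi(single_path, multi_path, list(compressed_ops), tuple(input_shape))
    reset_backend()
    # Step 4: load and train inside the distribution strategy's scope.
    with strategy_scope():
        model = load_multi(multi_path, list(input_shape))
        return train(model)
```

In real use, the AIMET functions from the required imports would be passed in order, along with `multi_gpu=True`, `strategy_scope=tf.distribute.MirroredStrategy().scope`, and `reset_backend=tf.keras.backend.clear_session`. Injecting the steps keeps the ordering testable without GPUs or a saved model on disk.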


Code Example

Required imports

import tensorflow as tf
from aimet_tensorflow.utils.convert_tf_sess_to_keras import save_tf_session_single_gpu, save_as_tf_module_multi_gpu, \
    load_tf_sess_variables_to_keras_single_gpu, load_keras_model_multi_gpu

Steps to convert a TF session obtained after compression to a Keras model

def convert_tf_session_to_keras_model():
    """
    Convert an AIMET spatial SVD compressed session to a Keras model and train the Keras model with MirroredStrategy
    """
    sess = get_sess_from_keras_model()

    # For instance, if the first conv layer in the MobileNetV1 graph is compressed, then:
    compressed_ops = ['conv1/Conv2D']
    compressed_sess = compress_session(sess, compressed_ops)

    # Define the input and output tensor names of the session for the MobileNet model
    input_op_name, output_op_name = "input_1:0", "act_softmax/Softmax:0"

    # Step 1: Save the compressed session (single GPU)
    path = './saved_model_single_gpu'
    save_tf_session_single_gpu(compressed_sess, path, input_op_name, output_op_name)
    tf.keras.backend.clear_session()

    # Step 2: Load the corresponding Keras model
    tf.keras.backend.set_learning_phase(1)
    model = load_tf_sess_variables_to_keras_single_gpu(path, compressed_ops)

    # Single GPU training of the loaded Keras Model
    train(model)

    # For multi-GPU training, the next two steps need to be followed:
    # Step 3: Re-Saving the Keras model to make it compatible with distribution strategy
    saving_path = './saved_model_multi_gpu'
    save_as_tf_module_multi_gpu(path, saving_path, compressed_ops, input_shape=(224, 224, 3))

    tf.keras.backend.clear_session()

    with tf.distribute.MirroredStrategy().scope():
        tf.keras.backend.set_learning_phase(1)
        # Step 4: Load the Keras model and train it on multiple GPUs with the given dataset
        model = load_keras_model_multi_gpu(saving_path, input_shape=[224, 224, 3])
        # Train model on Multi-GPU
        train(model)

Utility Functions

Required imports

import tensorflow as tf
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.vgg16 import preprocess_input

import numpy as np

from aimet_common.defs import CompressionScheme, CostMetric
from aimet_tensorflow.defs import SpatialSvdParameters
from aimet_tensorflow.compress import ModelCompressor
from aimet_tensorflow.defs import ModuleCompRatioPair

Utility function to get session from Keras model

def get_sess_from_keras_model():
    """
    Gets TF session from keras model
    :return: TF session
    """
    tf.keras.backend.clear_session()
    tf.keras.backend.set_learning_phase(1)
    _ = MobileNet(weights=None, input_shape=(224, 224, 3))
    sess = tf.compat.v1.keras.backend.get_session()
    return sess

Utility function to get a compressed session

def compress_session(sess, compressible_ops):
    """
    Compresses a TF session using spatial SVD
    :param sess: TF session
    :param compressible_ops: layers to compress
    :return: compressed session
    """
    layer_a = sess.graph.get_operation_by_name(compressible_ops[0])
    list_of_module_comp_ratio_pairs = [ModuleCompRatioPair(layer_a, 0.5)]
    manual_params = SpatialSvdParameters.ManualModeParams(
        list_of_module_comp_ratio_pairs=list_of_module_comp_ratio_pairs)
    params = SpatialSvdParameters(input_op_names=['input_1'], output_op_names=['act_softmax/Softmax'],
                                  mode=SpatialSvdParameters.Mode.manual, params=manual_params)
    scheme = CompressionScheme.spatial_svd
    metric = CostMetric.mac

    # pylint: disable=unused-argument
    def evaluate(sess, iterations, use_cuda):
        return 1

    sess, _ = ModelCompressor.compress_model(sess=sess,
                                             working_dir="./",
                                             eval_callback=evaluate,
                                             eval_iterations=None,
                                             input_shape=(1, 3, 224, 224),
                                             compress_scheme=scheme,
                                             cost_metric=metric,
                                             parameters=params)
    return sess
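compress_session requests a compression ratio of 0.5 for conv1 under CostMetric.mac. Assuming the ratio means the compressed layer's cost divided by the original cost, and that spatial SVD splits a k x k convolution into a k x 1 convolution with r output channels followed by a 1 x k convolution, the MACs per output position go from k·k·Cin·Cout to k·r·(Cin + Cout), giving a ratio of r·(Cin + Cout) / (k·Cin·Cout). The back-of-envelope sketch below ignores stride and padding (conv1 in MobileNetV1 actually has stride 2), and AIMET's own rank selection may account for these differently; the function names are hypothetical:

```python
import math


def spatial_svd_mac_ratio(k: int, c_in: int, c_out: int, rank: int) -> float:
    """MACs of (k x 1 -> rank, then 1 x k -> c_out) over MACs of the original k x k conv,
    per output position, ignoring stride and padding (simplifying assumption)."""
    original = k * k * c_in * c_out
    split = k * c_in * rank + k * rank * c_out
    return split / original


def max_rank_for_ratio(k: int, c_in: int, c_out: int, target_ratio: float) -> int:
    """Largest rank whose MAC ratio stays within target_ratio, clamped to at least 1
    (the clamp may then exceed the target for very small ratios)."""
    rank = math.floor(target_ratio * k * c_in * c_out / (c_in + c_out))
    return max(1, rank)


# MobileNetV1 'conv1': 3x3 kernel, 3 input channels, 32 output channels.
r = max_rank_for_ratio(3, 3, 32, 0.5)
print(r)                                             # 4
print(round(spatial_svd_mac_ratio(3, 3, 32, r), 3))  # 0.486
```

Because the rank is an integer, the achieved ratio (420 / 864 here) lands slightly below the requested 0.5.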

Utility function for training

def train(model):
    """
    Trains the model on a fake dataset
    :param model: Keras model
    :return: trained model
    """
    # Create a fake dataset
    x_train = np.random.rand(32, 224, 224, 3)
    y_train = np.random.randint(0, 1000, size=(32,))  # random class indices in [0, 1000)
    x_train = preprocess_input(x_train)
    y_train = tf.keras.utils.to_categorical(y_train, 1000)

    model.compile('rmsprop', 'mse')
    model.fit(x_train, y_train, epochs=1, batch_size=1, shuffle=False)
    return model