Using AIMET Tensorflow APIs with Keras Models
Introduction
Currently, AIMET APIs support TensorFlow sessions. This example code shows how to use AIMET with a Keras model: invoke AIMET on the model's back-end session, then convert the session that AIMET returns back into a Keras model.
APIs
The method consists of four steps:
Step 1: Save the session returned by AIMET
- aimet_tensorflow.utils.convert_tf_sess_to_keras.save_tf_session_single_gpu(sess, path, input_tensor, output_tensor)
Saves the TF session, meta graph and variables to the provided path.
- Parameters
  - sess (Session) – Input TF session (tf.compat.v1.Session)
  - path (str) – Path to save the session
  - input_tensor (str) – Name of the starting (input) op of the graph
  - output_tensor (str) – Name of the output op of the graph
- Returns
None
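The parameter descriptions above call input_tensor and output_tensor op names, yet the code example on this page passes tensor names such as "input_1:0" for them and bare op names such as 'input_1' to SpatialSvdParameters. In TensorFlow, a tensor is named after the op that produces it, followed by ":" and the output index. The helper split_tensor_name below is hypothetical (it is not part of AIMET); it only sketches how the two naming forms relate and makes no claim about which form save_tf_session_single_gpu accepts beyond what the example shows.

```python
from typing import Tuple


def split_tensor_name(name: str) -> Tuple[str, int]:
    """Split a TF tensor name "<op_name>:<output_index>" into its parts.

    A name without a ":<index>" suffix is treated as an op name, output 0.
    """
    op_name, sep, index = name.rpartition(":")
    if not sep:
        # No ":<index>" suffix, so the whole string is an op name.
        return name, 0
    return op_name, int(index)


# Tensor names used for MobileNet in the code example on this page.
print(split_tensor_name("input_1:0"))              # ('input_1', 0)
print(split_tensor_name("act_softmax/Softmax:0"))  # ('act_softmax/Softmax', 0)
print(split_tensor_name("input_1"))                # ('input_1', 0)
```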
Step 2: Load the saved session into a subclassed Keras model
- aimet_tensorflow.utils.convert_tf_sess_to_keras.load_tf_sess_variables_to_keras_single_gpu(path, compressed_ops)
Creates a Keras model subclass and loads the saved session, meta graph and variables into the Keras model.
- Parameters
  - path (str) – Path from which to load the TF session saved in Step 1 with save_tf_session_single_gpu
  - compressed_ops (List[str]) – List of op names to skip when creating the Keras model. These are the ops that AIMET compressed, which are isolated from the rest of the graph.
- Return type
Model
- Returns
Subclassed Keras Model
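The compressed_ops description says the ops AIMET compressed end up isolated from the rest of the graph, which is why the Keras model must skip them. The sketch below models that situation with a deliberately tiny toy graph: a plain dict mapping each op name to the ops it consumes (an assumption for illustration, not an AIMET or TensorFlow data structure). A backward breadth-first walk from the output finds every op that still feeds it; whatever is left over is isolated. The replacement op names conv1_a/Conv2D and conv1_b/Conv2D are hypothetical.

```python
from collections import deque
from typing import Dict, List, Set


def isolated_ops(graph: Dict[str, List[str]], output_op: str) -> Set[str]:
    """Return the ops in `graph` that do not feed `output_op`.

    `graph` maps each op name to the names of the ops whose outputs it
    consumes (a toy stand-in for a TF graph).
    """
    reachable = {output_op}
    queue = deque([output_op])
    # Breadth-first walk backwards from the output through each op's inputs.
    while queue:
        op = queue.popleft()
        for producer in graph.get(op, []):
            if producer not in reachable:
                reachable.add(producer)
                queue.append(producer)
    # Anything never reached cannot influence the output.
    return set(graph) - reachable


# Toy graph after compressing conv1; the replacement op names are hypothetical.
toy_graph = {
    "input_1": [],
    "conv1/Conv2D": ["input_1"],            # original op, no longer consumed
    "conv1_a/Conv2D": ["input_1"],          # hypothetical k x 1 replacement
    "conv1_b/Conv2D": ["conv1_a/Conv2D"],   # hypothetical 1 x k replacement
    "act_softmax/Softmax": ["conv1_b/Conv2D"],
}
print(isolated_ops(toy_graph, "act_softmax/Softmax"))  # {'conv1/Conv2D'}
```

In practice you do not need to discover these ops yourself: the code example on this page already knows which op it compressed and passes it explicitly as compressed_ops.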
After these two steps, the model can be used for single-GPU training. For multi-GPU training, follow the next two steps:
Step 3: Re-save the Keras model from Step 2 to make it compatible with a distribution strategy
- aimet_tensorflow.utils.convert_tf_sess_to_keras.save_as_tf_module_multi_gpu(loading_path, saving_path, compressed_ops, input_shape)
Loads a Keras model and re-saves the loaded object as a tf.Module.
- Parameters
  - loading_path (str) – Path from which to load the Keras model
  - saving_path (str) – Path to save the object
  - compressed_ops (List[str]) – List of op names to skip when creating the Keras model. These are the ops that AIMET compressed, which are isolated from the rest of the graph.
  - input_shape (Tuple) – Shape of the input to the model
- Returns
None
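The input_shape taken by Steps 3 and 4 is a per-sample shape without the batch dimension: the code example passes (224, 224, 3) to Step 3 as a tuple and [224, 224, 3] to Step 4 as a list. A quick sanity check before training is that each training batch has shape (N, *input_shape). The helper matches_input_shape below is a hypothetical convenience, not an AIMET function; it drops the leading batch dimension and compares both sides as tuples so the tuple and list forms behave the same.

```python
from typing import Sequence


def matches_input_shape(batch_shape: Sequence[int], input_shape: Sequence[int]) -> bool:
    """True if a batch of shape (N, *input_shape) fits a model built with input_shape.

    input_shape is per-sample (no batch dimension), so the leading batch
    dimension is dropped before comparing. Converting both sides to tuples
    lets a list such as [224, 224, 3] match a tuple such as (224, 224, 3).
    """
    return tuple(batch_shape[1:]) == tuple(input_shape)


# Shape of the fake training batch built by the train() utility on this page.
x_train_shape = (32, 224, 224, 3)
print(matches_input_shape(x_train_shape, (224, 224, 3)))      # True (Step 3 style, tuple)
print(matches_input_shape(x_train_shape, [224, 224, 3]))      # True (Step 4 style, list)
print(matches_input_shape((32, 3, 224, 224), (224, 224, 3)))  # False (channels-first batch)
```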
Step 4: Load the re-saved model as a subclassed Keras model
- aimet_tensorflow.utils.convert_tf_sess_to_keras.load_keras_model_multi_gpu(loading_path, input_shape)
Loads the Keras model back; the returned model can be used for fine-tuning within a distribution strategy scope.
- Parameters
  - loading_path (str) – Path from which to load the Keras model
  - input_shape (List) – Shape of the starting tensor in the graph; for instance, (224, 224, 3) for ResNet50 and MobileNetV1
- Returns
subclassed Keras model
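The four steps hand paths from one call to the next: Step 2 and Step 3 both read from the path written in Step 1, and Step 4 reads from the path written in Step 3. The sketch below wires the steps together in the order used by the code example on this page. It is a hypothetical wrapper, not an AIMET API: the four AIMET functions are injected through an `api` object, and clear_session, the strategy scope and the training function are passed in as callables, so the sketch runs without AIMET or TensorFlow installed. It omits the tf.keras.backend.set_learning_phase(1) calls from the example. With real libraries you would pass, for instance, the aimet_tensorflow.utils.convert_tf_sess_to_keras module as `api`, clear_session=tf.keras.backend.clear_session and strategy_scope=tf.distribute.MirroredStrategy().scope.

```python
from contextlib import nullcontext
from types import SimpleNamespace
from typing import Any, Callable, ContextManager, List, Sequence


def convert_compressed_session(
    sess: Any,
    compressed_ops: List[str],
    input_name: str,
    output_name: str,
    input_shape: Sequence[int],
    api: Any,
    train: Callable[[Any], Any] = lambda model: model,
    multi_gpu: bool = False,
    strategy_scope: Callable[[], ContextManager] = nullcontext,
    clear_session: Callable[[], None] = lambda: None,
    single_gpu_path: str = "./saved_model_single_gpu",
    multi_gpu_path: str = "./saved_model_multi_gpu",
) -> Any:
    """Run Steps 1-4 in the order of the code example (hypothetical wrapper)."""
    # Step 1: save the compressed session, meta graph and variables.
    api.save_tf_session_single_gpu(sess, single_gpu_path, input_name, output_name)
    clear_session()
    # Step 2: load the saved session into a subclassed Keras model.
    model = api.load_tf_sess_variables_to_keras_single_gpu(single_gpu_path, compressed_ops)
    if not multi_gpu:
        return train(model)
    # Step 3: re-save from the Step 1 path as a tf.Module (tuple shape).
    api.save_as_tf_module_multi_gpu(single_gpu_path, multi_gpu_path,
                                    compressed_ops, tuple(input_shape))
    clear_session()
    # Step 4: load (list shape) and train inside the distribution strategy scope.
    with strategy_scope():
        model = api.load_keras_model_multi_gpu(multi_gpu_path, list(input_shape))
        return train(model)


# Usage with stub functions that only record how they were called.
calls = []


def recorder(name, result=None):
    def fn(*args):
        calls.append((name, args))
        return result
    return fn


stub_api = SimpleNamespace(
    save_tf_session_single_gpu=recorder("step1"),
    load_tf_sess_variables_to_keras_single_gpu=recorder("step2", "single-gpu model"),
    save_as_tf_module_multi_gpu=recorder("step3"),
    load_keras_model_multi_gpu=recorder("step4", "multi-gpu model"),
)
model = convert_compressed_session("sess", ["conv1/Conv2D"], "input_1:0",
                                   "act_softmax/Softmax:0", (224, 224, 3),
                                   stub_api, multi_gpu=True)
print(model)                        # multi-gpu model
print([name for name, _ in calls])  # ['step1', 'step2', 'step3', 'step4']
```

Injecting the functions keeps the data flow between the steps visible and testable on its own; the actual compression and training still happen only in the real AIMET and TensorFlow calls.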
Code Example
Required imports
import tensorflow as tf
from aimet_tensorflow.utils.convert_tf_sess_to_keras import save_tf_session_single_gpu, save_as_tf_module_multi_gpu, \
load_tf_sess_variables_to_keras_single_gpu, load_keras_model_multi_gpu
Steps to convert a TF session obtained after compression to a Keras model
def convert_tf_session_to_keras_model():
    """
    Convert an AIMET spatial SVD compressed session to a Keras model and train the Keras model with MirroredStrategy
    """
    sess = get_sess_from_keras_model()
    # For instance, if the first conv layer in the MobileNetV1 graph is compressed, then:
    compressed_ops = ['conv1/Conv2D']
    compressed_sess = compress_session(sess, compressed_ops)
    # Define the input and output tensors of the session for the MobileNet model
    input_op_name, output_op_name = "input_1:0", "act_softmax/Softmax:0"
    # Step 1: Save the compressed session (single GPU)
    path = './saved_model_single_gpu'
    save_tf_session_single_gpu(compressed_sess, path, input_op_name, output_op_name)
    tf.keras.backend.clear_session()
    # Step 2: Load the corresponding Keras model
    tf.keras.backend.set_learning_phase(1)
    model = load_tf_sess_variables_to_keras_single_gpu(path, compressed_ops)
    # Single-GPU training of the loaded Keras model
    train(model)
    # For multi-GPU training, the next two steps need to be followed:
    # Step 3: Re-save the Keras model to make it compatible with a distribution strategy
    saving_path = './saved_model_multi_gpu'
    save_as_tf_module_multi_gpu(path, saving_path, compressed_ops, input_shape=(224, 224, 3))
    tf.keras.backend.clear_session()
    with tf.distribute.MirroredStrategy().scope():
        tf.keras.backend.set_learning_phase(1)
        # Step 4: Load the Keras model and train it on the given dataset with multiple GPUs
        model = load_keras_model_multi_gpu(saving_path, input_shape=[224, 224, 3])
        # Train the model on multiple GPUs
        train(model)
Utility Functions
Required imports
import tensorflow as tf
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.vgg16 import preprocess_input
import numpy as np
from aimet_common.defs import CompressionScheme, CostMetric
from aimet_tensorflow.defs import SpatialSvdParameters
from aimet_tensorflow.compress import ModelCompressor
from aimet_tensorflow.defs import ModuleCompRatioPair
Utility function to get a session from a Keras model
def get_sess_from_keras_model():
    """
    Gets TF session from keras model
    :return: TF session
    """
    tf.keras.backend.clear_session()
    tf.keras.backend.set_learning_phase(1)
    _ = MobileNet(weights=None, input_shape=(224, 224, 3))
    sess = tf.compat.v1.keras.backend.get_session()
    return sess
Utility function to get a compressed session
def compress_session(sess, compressible_ops):
    """
    Compresses a TF session with spatial SVD
    :param sess: TF session
    :param compressible_ops: names of the ops to compress (only the first one is used here)
    :return: compressed session
    """
    layer_a = sess.graph.get_operation_by_name(compressible_ops[0])
    list_of_module_comp_ratio_pairs = [ModuleCompRatioPair(layer_a, 0.5)]
    manual_params = SpatialSvdParameters.ManualModeParams(
        list_of_module_comp_ratio_pairs=list_of_module_comp_ratio_pairs)
    params = SpatialSvdParameters(input_op_names=['input_1'], output_op_names=['act_softmax/Softmax'],
                                  mode=SpatialSvdParameters.Mode.manual, params=manual_params)
    scheme = CompressionScheme.spatial_svd
    metric = CostMetric.mac
    # Dummy evaluation callback: returns a constant score
    # pylint: disable=unused-argument
    def evaluate(sess, iterations, use_cuda):
        return 1
    sess, _ = ModelCompressor.compress_model(sess=sess,
                                             working_dir="./",
                                             eval_callback=evaluate,
                                             eval_iterations=None,
                                             input_shape=(1, 3, 224, 224),
                                             compress_scheme=scheme,
                                             cost_metric=metric,
                                             parameters=params)
    return sess
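compress_session asks AIMET to compress conv1/Conv2D with a compression ratio of 0.5 under CostMetric.mac. As a back-of-envelope illustration of what that ratio means, assume the usual spatial SVD factorization: a k×k conv with C_in inputs and C_out outputs becomes a k×1 conv with r outputs followed by a 1×k conv, so the MACs per output position drop from C_in·C_out·k² to C_in·r·k + r·C_out·k. The arithmetic below deliberately ignores stride and output spatial size, and AIMET's own rank selection may round differently; the helper names are hypothetical. For MobileNet's conv1 (3×3 kernel, 3 RGB input channels, 32 filters at the default width), it finds the largest rank that stays within half the original MACs.

```python
import math


def spatial_svd_macs(c_in: int, c_out: int, k: int, rank: int) -> int:
    """MACs per output position of a k x 1 conv (rank outputs) followed by a 1 x k conv."""
    return c_in * rank * k + rank * c_out * k


def max_rank_for_ratio(c_in: int, c_out: int, k: int, ratio: float) -> int:
    """Largest rank whose factored MACs are at most ratio * the original k x k MACs."""
    original = c_in * c_out * k * k
    per_rank = (c_in + c_out) * k  # MACs added by each unit of rank
    return math.floor(ratio * original / per_rank)


# MobileNet conv1: 3x3 kernel, 3 input channels, 32 filters.
rank = max_rank_for_ratio(c_in=3, c_out=32, k=3, ratio=0.5)
print(rank)                              # 4
print(spatial_svd_macs(3, 32, 3, rank))  # 420
print(3 * 32 * 3 * 3)                    # 864
```

Under these assumptions, rank 4 keeps the factored layer at 420 of the original 864 MACs per output position, just under the requested 0.5.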
Utility function for training
def train(model):
    """
    Trains the model for one epoch on a small fake dataset
    :param model: Keras model
    :return: trained model
    """
    # Create a fake dataset: 32 random images with random integer class labels in [0, 1000)
    x_train = np.random.rand(32, 224, 224, 3)
    y_train = np.random.randint(0, 1000, size=(32,))
    x_train = preprocess_input(x_train)
    y_train = tf.keras.utils.to_categorical(y_train, 1000)
    model.compile('rmsprop', 'mse')
    model.fit(x_train, y_train, epochs=1, batch_size=1, shuffle=False)
    return model