Using AIMET Tensorflow APIs with Keras Models
Introduction
Currently, AIMET APIs support TensorFlow sessions. This example shows how to use AIMET with a Keras model. You invoke AIMET on the model's back-end session, then convert the session that AIMET returns back into a Keras model.
APIs
The method consists of four steps:
Step 1: Save the session returned by AIMET
aimet_tensorflow.utils.convert_tf_sess_to_keras.save_tf_session_single_gpu(sess, path, input_tensor, output_tensor)
Saves the TF session, meta graph and variables in the provided path.
- Parameters
sess (tf.compat.v1.Session) – Input session
path (str) – Path to save the session
input_tensor (str) – Name of the starting op of the given graph
output_tensor (str) – Name of the output op of the graph
- Returns
None
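The code example on this page passes tensor names such as "input_1:0" and "act_softmax/Softmax:0" to save_tf_session_single_gpu. In contrast, compress_session passes op names such as 'input_1' to SpatialSvdParameters. In a TensorFlow graph, a tensor name is the producing op's name followed by ":" and the output index.

The helpers below are hypothetical and not part of AIMET. They are a minimal sketch that derives one form from the other, so that both call sites can be fed from a single string.

```python
from typing import Tuple


def split_tensor_name(name: str) -> Tuple[str, int]:
    """Split a TF tensor name such as 'input_1:0' into ('input_1', 0).

    Hypothetical helper, not part of AIMET. A bare op name such as
    'input_1' is treated as output 0 of that op.
    """
    op_name, sep, index = name.rpartition(":")
    if not sep:
        # No ':' in the string, so the whole string is an op name
        return name, 0
    if not op_name or not index.isdigit():
        raise ValueError(f"not a valid tensor name: {name!r}")
    return op_name, int(index)


def to_tensor_name(name: str) -> str:
    """Return the 'op_name:index' form of an op or tensor name."""
    op_name, index = split_tensor_name(name)
    return f"{op_name}:{index}"


# Tensor names as used for save_tf_session_single_gpu in the code example
input_tensor, output_tensor = "input_1:0", "act_softmax/Softmax:0"
# Op names as used for SpatialSvdParameters in compress_session
input_op, _ = split_tensor_name(input_tensor)
output_op, _ = split_tensor_name(output_tensor)
print(input_op, output_op)  # input_1 act_softmax/Softmax
```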
Step 2: Model subclassing to load the corresponding session into a Keras model
aimet_tensorflow.utils.convert_tf_sess_to_keras.load_tf_sess_variables_to_keras_single_gpu(path, compressed_ops)
Creates a Keras model subclass and loads the saved session, meta graph and variables into the Keras model.
- Parameters
path (str) – Path to load the TF session saved using save_tf_session_single_gpu
compressed_ops (List[str]) – List of op names skipped during Keras model creation. These are the ops that AIMET compressed and that are isolated from the rest of the graph.
- Return type
Model
- Returns
Subclassed Keras model
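Every entry in compressed_ops should name an op that exists in the compressed graph. The code example uses ['conv1/Conv2D'] for MobileNetV1. This assumes, as the parameter description suggests, that the compressed ops remain in the graph but are isolated from the rest of it.

The hypothetical pre-flight check below (not an AIMET API) compares the list against the graph's op names. With a TF1 session, those names can be collected with [op.name for op in sess.graph.get_operations()]. The sketch takes plain lists of strings, so it runs without TensorFlow.

```python
from typing import Iterable, List


def check_compressed_ops(compressed_ops: List[str], graph_op_names: Iterable[str]) -> None:
    """Raise ValueError if any name in compressed_ops is not an op in the graph.

    Hypothetical pre-flight check, not part of AIMET.
    """
    known = set(graph_op_names)
    missing = [name for name in compressed_ops if name not in known]
    if missing:
        raise ValueError(f"ops not found in graph: {missing}")


# Toy op list standing in for [op.name for op in sess.graph.get_operations()]
ops_in_graph = ["input_1", "conv1/Conv2D", "act_softmax/Softmax"]
check_compressed_ops(["conv1/Conv2D"], ops_in_graph)  # passes silently
```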
After these two steps, the model can be used for single-GPU training. For multi-GPU training, the next two steps need to be followed:
Step 3: Save the Keras model from Step 2 to make it compatible with a distribution strategy
aimet_tensorflow.utils.convert_tf_sess_to_keras.save_as_tf_module_multi_gpu(loading_path, saving_path, compressed_ops, input_shape)
Loads a Keras model and re-saves the loaded object in the form of a tf.Module.
- Parameters
loading_path (str) – Path to load the Keras model
saving_path (str) – Path to save the object
compressed_ops – List of op names to skip during Keras model creation. These are the ops that AIMET compressed and that are isolated from the rest of the graph.
input_shape (Tuple) – Shape of the input to the model
- Returns
None
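save_as_tf_module_multi_gpu documents input_shape as a Tuple, yet the code example passes the list [224, 224, 3] to load_keras_model_multi_gpu. The (224, 224, 3) examples for MobileNetV1 suggest that the shape describes a single sample without the batch dimension.

Under that assumption, a small hypothetical helper (not part of AIMET) can normalize and validate the shape once. The same value can then be passed to both Step 3 and Step 4.

```python
from typing import Sequence, Tuple


def normalize_input_shape(shape: Sequence[int]) -> Tuple[int, ...]:
    """Return shape as a tuple of positive ints (hypothetical helper, not AIMET).

    Assumes a per-sample shape without a batch dimension, e.g. (224, 224, 3).
    """
    result = tuple(int(dim) for dim in shape)
    if not result or any(dim <= 0 for dim in result):
        raise ValueError(f"invalid input shape: {shape!r}")
    return result


# One normalized value to pass as input_shape to both Step 3 and Step 4
input_shape = normalize_input_shape([224, 224, 3])
print(input_shape)  # (224, 224, 3)
```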
Step 4: Model subclassing to load the corresponding Keras model
aimet_tensorflow.utils.convert_tf_sess_to_keras.load_keras_model_multi_gpu(loading_path, input_shape)
Loads the Keras model back. The returned model can be used for fine-tuning within a distribution strategy.
- Parameters
loading_path (str) – Path to load the Keras model
input_shape – Shape of the starting tensor in the graph; for instance, (224, 224, 3) for ResNet50 and MobileNetV1
- Returns
Subclassed Keras model
Code Example
Required imports
import tensorflow as tf
from aimet_tensorflow.utils.convert_tf_sess_to_keras import save_tf_session_single_gpu, save_as_tf_module_multi_gpu, \
load_tf_sess_variables_to_keras_single_gpu, load_keras_model_multi_gpu
Steps to convert a TF session obtained after compression to a Keras model
def convert_tf_session_to_keras_model():
    """
    Convert an AIMET spatial SVD compressed session to a Keras model and train the Keras model with MirroredStrategy
    """
    sess = get_sess_from_keras_model()

    # For instance, if the first conv layer in the MobileNetV1 graph is compressed, then:
    compressed_ops = ['conv1/Conv2D']
    compressed_sess = compress_session(sess, compressed_ops)

    # Define the input and output tensors of the session for the MobileNet model
    input_op_name, output_op_name = "input_1:0", "act_softmax/Softmax:0"

    # Step 1: Save the compressed session
    path = './saved_model_single_gpu'
    save_tf_session_single_gpu(compressed_sess, path, input_op_name, output_op_name)
    tf.keras.backend.clear_session()

    # Step 2: Load the corresponding Keras model
    tf.keras.backend.set_learning_phase(1)
    model = load_tf_sess_variables_to_keras_single_gpu(path, compressed_ops)

    # Single-GPU training of the loaded Keras model
    train(model)

    # Multi-GPU training requires the next two steps:
    # Step 3: Re-save the Keras model to make it compatible with a distribution strategy
    saving_path = './saved_model_multi_gpu'
    save_as_tf_module_multi_gpu(path, saving_path, compressed_ops, input_shape=(224, 224, 3))
    tf.keras.backend.clear_session()

    with tf.distribute.MirroredStrategy().scope():
        tf.keras.backend.set_learning_phase(1)
        # Step 4: Load the Keras model for multi-GPU training on the given dataset
        model = load_keras_model_multi_gpu(saving_path, input_shape=[224, 224, 3])
        # Train the model on multiple GPUs
        train(model)
Utility Functions
Required imports
import tensorflow as tf
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.vgg16 import preprocess_input
import numpy as np
from aimet_common.defs import CompressionScheme, CostMetric
from aimet_tensorflow.defs import SpatialSvdParameters
from aimet_tensorflow.compress import ModelCompressor
from aimet_tensorflow.defs import ModuleCompRatioPair
Utility function to get session from Keras model
def get_sess_from_keras_model():
    """
    Gets TF session from Keras model
    :return: TF session
    """
    tf.keras.backend.clear_session()
    tf.keras.backend.set_learning_phase(1)
    _ = MobileNet(weights=None, input_shape=(224, 224, 3))
    sess = tf.compat.v1.keras.backend.get_session()
    return sess
Utility function to get a compressed session
def compress_session(sess, compressible_ops):
    """
    Compresses a TF session
    :param sess: TF session
    :param compressible_ops: layers to compress
    :return: compressed session
    """
    layer_a = sess.graph.get_operation_by_name(compressible_ops[0])
    list_of_module_comp_ratio_pairs = [ModuleCompRatioPair(layer_a, 0.5)]
    manual_params = SpatialSvdParameters.ManualModeParams(
        list_of_module_comp_ratio_pairs=list_of_module_comp_ratio_pairs)
    params = SpatialSvdParameters(input_op_names=['input_1'], output_op_names=['act_softmax/Softmax'],
                                  mode=SpatialSvdParameters.Mode.manual, params=manual_params)
    scheme = CompressionScheme.spatial_svd
    metric = CostMetric.mac

    # pylint: disable=unused-argument
    def evaluate(sess, iterations, use_cuda):
        # Dummy evaluation callback: always returns a score of 1
        return 1

    sess, _ = ModelCompressor.compress_model(sess=sess,
                                             working_dir="./",
                                             eval_callback=evaluate,
                                             eval_iterations=None,
                                             input_shape=(1, 3, 224, 224),
                                             compress_scheme=scheme,
                                             cost_metric=metric,
                                             parameters=params)
    return sess
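In manual mode, ModuleCompRatioPair(layer_a, 0.5) asks for layer_a to be compressed to roughly half of its original cost. Spatial SVD factors a k_h x k_w convolution into a k_h x 1 convolution with R output channels, followed by a 1 x k_w convolution. The rank R therefore controls how much of the original cost is retained.

The arithmetic below is an illustrative sketch, not AIMET's rank-selection code, and spatial_svd_rank is a hypothetical name. It assumes that weight count is a fair proxy for cost.

```python
import math


def spatial_svd_rank(k_h: int, k_w: int, c_in: int, c_out: int, ratio: float) -> int:
    """Largest rank R whose factored weights stay within `ratio` of the original.

    Illustrative arithmetic only, not AIMET's rank-selection code. A k_h x k_w
    convolution (c_in -> c_out) is replaced by a k_h x 1 convolution (c_in -> R)
    followed by a 1 x k_w convolution (R -> c_out).
    """
    original = k_h * k_w * c_in * c_out   # weights of the original convolution
    per_rank = k_h * c_in + k_w * c_out   # weights added per unit of rank
    budget_rank = math.floor(ratio * original / per_rank)
    # R cannot exceed the rank of the (k_h * c_in) x (k_w * c_out) weight matrix
    return min(budget_rank, k_h * c_in, k_w * c_out)


# MobileNetV1 'conv1': 3x3 kernel, 3 input channels, 32 filters, target ratio 0.5
rank = spatial_svd_rank(3, 3, 3, 32, 0.5)
print(rank)  # 4: 4 * (9 + 96) = 420 weights vs. 864 originally
```

CostMetric.mac counts multiply-accumulates, which also depend on strides and output sizes. MobileNetV1's conv1 has stride 2, so AIMET may pick a different rank than this weight-count estimate.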
Utility function for training
def train(model):
    """
    Trains using a fake dataset
    :param model: Keras model
    :return: trained model
    """
    # Create a fake dataset: 32 random images with random integer labels in [0, 1000)
    x_train = np.random.rand(32, 224, 224, 3)
    y_train = np.random.randint(0, 1000, size=(32,))
    x_train = preprocess_input(x_train)
    # One-hot encode the labels to match the 1000-way softmax output
    y_train = tf.keras.utils.to_categorical(y_train, 1000)
    model.compile('rmsprop', 'mse')
    model.fit(x_train, y_train, epochs=1, batch_size=1, shuffle=False)
    return model