Convert SavedModel to TFLite Model with Signatures
When using the TFLite Interpreter on Android, it is not obvious how to generate a TFLite model with the desired input/output names (i.e., signatures) before calling the runSignature API.
In this article, we quickly go through the steps for setting TFLite model signatures and point to the documentation that explains how these concepts are related.
Update (2021/10/10)
Google now provides a guide page on signatures:
https://www.tensorflow.org/lite/guide/signatures
The preferred way to run inference
The preferred way to run inference on a model is to use signatures, which are available for models converted with TensorFlow 2.5 or later.
The runSignature method takes three arguments:
- Inputs : map for inputs from input name in the signature to an input object.
- Outputs : map for output mapping from output name in signature to output data.
- Signature Name [optional]: Signature name (Can be left empty if the model has single signature).
https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_java
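The Python interpreter exposes the same three pieces (inputs by name, outputs by name, and a signature name) through get_signature_runner, which makes it convenient to try the idea before moving to Android. The following is a minimal sketch, not the Java API itself: the Doubler module, the 'double' signature name, and the 'doubled' output name are hypothetical and exist only for this illustration.

```python
import tempfile

import numpy as np
import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical toy module used only for this illustration."""

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def __call__(self, x_input):
        # The argument name becomes the signature's input name,
        # the dictionary key becomes the signature's output name.
        return {'doubled': x_input * 2.0}


module = Doubler()

with tempfile.TemporaryDirectory() as export_dir:
    # 'double' is the signature name stored in the SavedModel.
    tf.saved_model.save(
        module,
        export_dir,
        signatures={'double': module.__call__.get_concrete_function()},
    )
    converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
    tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)

# Signature name -> runner; inputs are passed by name as keyword arguments.
runner = interpreter.get_signature_runner('double')
result = runner(x_input=np.ones((1, 3), dtype=np.float32))

# The result is a dict keyed by the signature's output names.
print(result['doubled'])
```

On Android, the input name and output name used here would be the keys of the two maps passed to runSignature, and 'double' would be the signature name argument.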
Understand SignatureDef Structure
The arguments of runSignature correspond to the fields of a SignatureDef.
A SignatureDef requires specification of:
- inputs as a map of string to TensorInfo.
- outputs as a map of string to TensorInfo.
- method_name (which corresponds to a supported method name in the loading tool/system).
Note that TensorInfo itself requires specification of name, dtype and tensor shape.
https://www.tensorflow.org/tfx/serving/signature_defs
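To make the structure concrete, here is a plain-Python sketch of the shape of a SignatureDef. It is only an illustrative model, not the real protobuf classes: the class names TensorInfoSketch and SignatureDefSketch, the tensor names, and the output shapes are assumptions made up for this example.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass
class TensorInfoSketch:
    """Illustrative stand-in for TensorInfo: name, dtype and shape."""
    name: str                          # name of the tensor in the graph
    dtype: str                         # element type, e.g. 'float32'
    shape: Tuple[Optional[int], ...]   # None marks an unknown dimension


@dataclass
class SignatureDefSketch:
    """Illustrative stand-in for SignatureDef."""
    inputs: Dict[str, TensorInfoSketch]    # signature input name -> tensor
    outputs: Dict[str, TensorInfoSketch]   # signature output name -> tensor
    method_name: str


# Hypothetical values: tensor names and output shapes are invented here.
sample_signature = SignatureDefSketch(
    inputs={
        'x_input': TensorInfoSketch('x_input:0', 'float32', (None, 128, 3)),
    },
    outputs={
        'my_output_a': TensorInfoSketch('output_a:0', 'float32', (None, 10)),
        'my_output_b': TensorInfoSketch('output_b:0', 'float32', (None, 10)),
    },
    method_name='tensorflow/serving/predict',
)

# The map keys are the names an application uses to feed and fetch data.
print(sorted(sample_signature.inputs))
print(sorted(sample_signature.outputs))
```

The keys of the two maps are what end up as the input and output names of the signature; the TensorInfo entries only describe the underlying graph tensors.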
How to Specify the SignatureDef for a Model?
https://www.tensorflow.org/api_docs/python/tf/lite/TFLiteConverter#from_saved_model
signature_keys
List of keys identifying SignatureDef containing inputs and outputs. Elements should not be duplicated. By default the signatures attribute of the MetaGraphdef is used. (default saved_model.signatures)
In practice, you only have to specify the signatures when saving the model with tf.saved_model.save(...), and the corresponding SignatureDef will be generated by default.
You can also specify the signatures in the Keras model.save(...) call. However, you then have to subclass the Model class and override or create the method used for graph tracing, which is not desirable.
The recommended way is the following.
Save Model/Module with Signatures
Simple sample code can be found at:
https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python
Generally, you can wrap your SavedModels inside a tf.Module.
- Note: if your SavedModel is a Keras Model, you should load it with keras.models.load_model(...).
- Otherwise, if you use saved_model.load(...), you will run into trouble with the output shapes.
More general sample code:
import tensorflow as tf
from tensorflow import keras


class ModelWrapperModule(tf.Module):
    def __init__(self, my_model_a, my_model_b):
        super().__init__()
        # Wrap your models
        self.my_model_a = my_model_a
        self.my_model_b = my_model_b

    @tf.function
    def __call__(self, x_input):
        # x_input will be the input name of the SignatureDefs.
        return {
            # The dictionary's keys will be the output names of the SignatureDefs.
            'my_output_a': self.my_model_a(x_input),
            'my_output_b': self.my_model_b(x_input)
        }


# Assume these models are Keras models.
model_a = keras.models.load_model(str(model_a_path))
model_b = keras.models.load_model(str(model_b_path))
module = ModelWrapperModule(model_a, model_b)

# Trace the graph to determine the shapes and types.
# If your model has batch size None, don't change it to another number.
# In my experience, it will crash while loading the TFLite model on Android.
call_output = module.__call__.get_concrete_function(tf.TensorSpec((None, 128, 3), tf.float32))

# The dictionary keys are the method names/signature names of the SignatureDefs.
tf.saved_model.save(module, 'module_filename', signatures={'sample_method_name': call_output})
Convert into TFLite
# saved_model_path is a pathlib.Path pointing to the SavedModel directory.
converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_path))
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # optional
converter.target_spec.supported_types = [tf.float16]  # optional
tflite_file_path = saved_model_path.parent / f'{saved_model_path.stem}.tflite'
tflite_file_path.write_bytes(converter.convert())
Verify by Interpreter
interpreter = tf.lite.Interpreter(str(tflite_file_path))
print('Signature List:')
print(interpreter.get_signature_list())
print('Input Details:')
print(interpreter.get_input_details())
print('Output Details:')
print(interpreter.get_output_details())
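Instead of reading the printed signature list by eye, the check can be automated. get_signature_list() returns a dictionary that maps each signature name to its 'inputs' and 'outputs' name lists, so a small helper can compare it against what you expect. This is a minimal sketch: check_signature is a hypothetical helper, and example_signature_list is a hand-written stand-in for the interpreter's return value so that the snippet runs without a model file.

```python
from typing import Dict, List


def check_signature(
    signature_list: Dict[str, Dict[str, List[str]]],
    key: str,
    expected_inputs: List[str],
    expected_outputs: List[str],
) -> None:
    """Raise if the signature is missing or its names differ from the expected ones."""
    if key not in signature_list:
        raise KeyError(f'signature {key!r} not found, got {sorted(signature_list)}')
    entry = signature_list[key]
    # Compare as sorted lists so that the order of names does not matter.
    if sorted(entry['inputs']) != sorted(expected_inputs):
        raise ValueError(f'unexpected inputs for {key!r}: {entry["inputs"]}')
    if sorted(entry['outputs']) != sorted(expected_outputs):
        raise ValueError(f'unexpected outputs for {key!r}: {entry["outputs"]}')


# Stand-in for interpreter.get_signature_list(), using the names from this article.
example_signature_list = {
    'sample_method_name': {
        'inputs': ['x_input'],
        'outputs': ['my_output_a', 'my_output_b'],
    },
}

check_signature(
    example_signature_list,
    'sample_method_name',
    expected_inputs=['x_input'],
    expected_outputs=['my_output_a', 'my_output_b'],
)
print('Signature check passed.')
```

With a real model, pass interpreter.get_signature_list() as the first argument; a failure here is much easier to debug than a missing-name error at runSignature time on the device.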
Other References
https://www.tensorflow.org/guide/saved_model#specifying_signatures_during_export
https://www.tensorflow.org/api_docs/python/tf/saved_model/save