Overview
This page describes the design and steps needed to convert composite operations in TensorFlow to fused operations in LiteRT. This infrastructure is general purpose and supports conversion of any composite operation in TensorFlow to a corresponding fused operation in LiteRT.
One example use of this infrastructure is the fusion of TensorFlow RNN operations into LiteRT fused RNN operations.
What are fused operations
TensorFlow operations can either be primitive ops, e.g. tf.add, or they can be composed from other primitive operations, e.g. tf.einsum. A primitive operation shows up as a single node in the TensorFlow graph, while a composite operation is a collection of nodes. Executing a composite operation is equivalent to executing each of its constituent primitive operations.
A fused operation corresponds to a single operation that subsumes all the computation performed by each primitive operation within the corresponding composite operation.
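For instance, here is a minimal illustration of the difference, using the two ops named above (the tensor values are arbitrary):

import tensorflow as tf

a = tf.constant([[1.0, 2.0]])
b = tf.constant([[3.0], [4.0]])

# Primitive op: tf.add appears as a single node when traced into a graph.
c = tf.add(a, tf.transpose(b))

# Composite op: per the description above, tf.einsum is composed from
# other primitive operations rather than being a single kernel.
d = tf.einsum('ij,jk->ik', a, b)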
Benefits of fused operations
Fused operations exist to maximize the performance of their underlying kernel implementations by optimizing the overall computation and reducing memory footprint. This is very valuable, especially for low-latency inference workloads and resource-constrained mobile platforms.
Fused operations also provide a higher level interface to define complex transformations like quantization, which would otherwise be infeasible or very hard to do at a more granular level.
LiteRT has many instances of fused operations for the reasons articulated above. These fused operations typically correspond to composite operations in the source TensorFlow program. Examples of composite operations in TensorFlow that are implemented as a single fused operation in LiteRT include various RNN operations like Unidirectional and Bidirectional sequence LSTM, convolution (conv2d, bias add, relu), fully connected (matmul, bias add, relu) and more. In LiteRT, LSTM quantization is currently only implemented in the fused LSTM operations.
Challenges with fused operations
Converting composite operations from TensorFlow to fused operations in LiteRT is a hard problem. This is because:
- Composite operations are represented in the TensorFlow graph as a set of primitive operations without a well-defined boundary. It can be very challenging to identify (e.g. via pattern matching) the sub-graph corresponding to such a composite operation.
- There may be more than one TensorFlow implementation targeting a fused LiteRT operation. For example, there are many LSTM implementations in TensorFlow (Keras, Babelfish/lingvo, etc.), and each of these is composed of different primitive operations, but they all could still be converted to the same fused LSTM operation in LiteRT.
As such, conversion to fused operations has proven quite challenging.
Converting from composite op to a TFLite custom operation (recommended)
Wrap the composite operation in a tf.function
In many cases, some part of the model can be mapped to a single operation in TFLite. This can help with performance when writing an optimized implementation for specific operations. To be able to create a fused operation in TFLite, identify the part of the graph that represents a fused operation and wrap it in a tf.function whose "experimental_implements" attribute carries the attribute tfl_fusable_op with value true. If the custom operation takes attributes, pass them as part of the same "experimental_implements".
Example:
import tensorflow as tf

def get_implements_signature():
  implements_signature = [
      # 'name' will be used as a name for the operation.
      'name: "my_custom_fused_op"',
      # attr "tfl_fusable_op" is required to be set with true value.
      'attr {key: "tfl_fusable_op" value { b: true } }',
      # Example attribute "example_option" that the op accepts.
      'attr {key: "example_option" value { i: %d } }' % 10
  ]
  return ' '.join(implements_signature)

@tf.function(experimental_implements=get_implements_signature())
def my_custom_fused_op(input_1, input_2):
  # An empty function that represents pre/post processing that is not
  # represented as part of the TensorFlow graph.
  output_1 = tf.constant(0.0, dtype=tf.float32, name='first_output')
  output_2 = tf.constant(0.0, dtype=tf.float32, name='second_output')
  return output_1, output_2

class TestModel(tf.Module):
  def __init__(self):
    super(TestModel, self).__init__()
    self.conv_1 = tf.keras.layers.Conv2D(filters=1, kernel_size=(3, 3))
    self.conv_2 = tf.keras.layers.Conv2D(filters=1, kernel_size=(3, 3))

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[1, 28, 28, 3], dtype=tf.float32),
      tf.TensorSpec(shape=[1, 28, 28, 3], dtype=tf.float32),
  ])
  def simple_eval(self, input_a, input_b):
    return my_custom_fused_op(self.conv_1(input_a), self.conv_2(input_b))
Note that you don't need to set allow_custom_ops on the converter, since the tfl_fusable_op attribute already implies this.
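For example, a minimal conversion sketch for the TestModel above (the from_concrete_functions path is one option; from_saved_model works as well):

model = TestModel()
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [model.simple_eval.get_concrete_function()], model)
# No converter.allow_custom_ops = True needed: the tfl_fusable_op
# attribute already implies it.
tflite_model = converter.convert()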
Implement custom op and register with TFLite Interpreter
Implement your fused operation as a TFLite custom operation - see the instructions for custom operators.
Note that the name you register the op with must match the name specified in the name attribute in the implements signature.
For the op in the example above, the registration is:
TfLiteRegistration reg = {};
// This name must match the name specified in the implements signature.
static constexpr char kOpName[] = "my_custom_fused_op";
reg.custom_name = kOpName;
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
  // Add your code.
  return kTfLiteOk;
};
reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
  // Add your code.
  return kTfLiteOk;
};
reg.builtin_code = kTfLiteCustom;
resolver->AddCustom(kOpName, &reg);
Converting from composite to fused operation (Advanced)
The overall workflow for converting TensorFlow composite operations to LiteRT fused operations is as follows:
Wrap the composite operation in a tf.function
In the TensorFlow model source code, identify and abstract out the composite operation into a tf.function with the experimental_implements function annotation. See the embedding lookup example below. The function defines the interface, and its arguments should be used to implement the conversion logic.
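As a minimal sketch, an annotated composite implementation could look like the following; the tf.gather body here is a simplified stand-in for illustration, not the actual implementation (a complete lingvo version appears in the Under the hood section below):

import tensorflow as tf

@tf.function(experimental_implements="embedding_lookup")
def embedding_lookup(embs, ids_vec):
  # Composite implementation built from primitive ops. The converter
  # keys on the "embedding_lookup" interface name and replaces this
  # body with the fused LiteRT op.
  return tf.gather(embs, ids_vec)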
Write conversion code
The conversion code is written per the interface of the function with the implements annotation. See the example fusion for embedding lookup below. Conceptually, the conversion code replaces the composite implementation of this interface with the fused one.
Plug your conversion code into the prepare-composite-functions pass.
In more advanced usages, it is possible to implement complex transformations of the composite operation's operands in order to derive the operands of the fused operation. See the Keras LSTM conversion code as an example.
Convert to LiteRT
Use the TFLiteConverter.from_saved_model API to convert to LiteRT.
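A minimal sketch, assuming the annotated model has been exported with tf.saved_model.save to saved_model_dir:

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)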
Under the hood
We now describe high-level details of the overall design for converting to fused operations in LiteRT.
Composing operations in TensorFlow
The use of tf.function with the experimental_implements function attribute allows users to explicitly compose new operations using TensorFlow primitive operations and specify the interface that the resultant composite operation implements. This is very useful as it provides:
- A well-defined boundary for the composite operation in the underlying TensorFlow graph.
- An explicit specification of the interface that this operation implements. The arguments of the tf.function correspond to the arguments of this interface.
As an example, let’s consider a composite operation defined to implement embedding lookup. This maps to a fused operation in LiteRT.
# Excerpt from lingvo; helpers such as inplace_ops, py_utils and
# functional_ops come from the surrounding lingvo module.
@tf.function(
    experimental_implements="embedding_lookup")
def EmbFprop(embs, ids_vec):
  """Embedding forward prop.

  Effectively, it computes:
    num = size of ids_vec
    rets = zeros([num, embedding dim])
    for i in range(num):
      rets[i, :] = embs[ids_vec[i], :]
    return rets

  Args:
    embs: The embedding matrix.
    ids_vec: A vector of int32 embedding ids.

  Returns:
    The result of embedding lookups. A matrix of shape
    [num ids in ids_vec, embedding dims].
  """
  num = tf.shape(ids_vec)[0]
  rets = inplace_ops.empty([num] + emb_shape_suf, py_utils.FPropDtype(p))

  def EmbFpropLoop(i, embs, ids_vec, rets):
    # row_id = ids_vec[i]
    row_id = tf.gather(ids_vec, i)
    # row = embs[row_id]
    row = tf.reshape(tf.gather(embs, row_id), [1] + emb_shape_suf)
    # rets[i] = row
    rets = inplace_ops.alias_inplace_update(rets, [i], row)
    return embs, ids_vec, rets

  _, _, rets = functional_ops.For(
      start=0,
      limit=num,
      delta=1,
      inputs=[embs, ids_vec, rets],
      body=EmbFpropLoop,
      rewrite_with_while=compiled)
  if len(weight_shape) > 2:
    rets = tf.reshape(rets, [num, symbolic.ToStatic(p.embedding_dim)])
  return rets
By making models use composite operations via tf.function as illustrated above, it becomes possible to build a general infrastructure to identify and convert such operations to fused LiteRT operations.
Extending the LiteRT converter
Early versions of the LiteRT converter only supported importing TensorFlow models as a graph with all variables replaced with their corresponding constant values. This does not work for operation fusion, since such graphs have all functions inlined so that the variables can be turned into constants.
In order to leverage tf.function with the experimental_implements feature during conversion, the functions need to be preserved until later in the conversion process.
As such, we implemented a new workflow of importing and converting TensorFlow models in the converter to support the composite operation fusion use case. Specifically, the new features added are:
- Import TensorFlow saved models into MLIR
- Fuse composite operations
- Analyze variable mutability
- Freeze all read-only variables
This allows us to perform operation fusion using the functions representing the composite operations prior to function inlining and variable freezing.
Implementing operation fusion
Let’s look at the operation fusion pass in more detail. This pass does the following:
- Loop through all functions in the MLIR module.
- If a function has the tf._implements attribute, call the appropriate operation fusion utility based on the attribute value.
- The operation fusion utility operates on the function's operands and attributes (which serve as the interface for the conversion) and replaces the body of the function with an equivalent function body containing the fused operation.
- In many cases, the replaced body will contain operations other than the fused operation. These correspond to some static transforms on the function's operands needed to obtain the operands of the fused operation. Since these computations can all be constant-folded away, they will not be present in the exported flatbuffer, where only the fused operation exists.
Here is a code snippet from the pass showing the main workflow:
void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
                                                        StringAttr attr) {
  if (attr.getValue() == "embedding_lookup") {
    func.eraseBody();
    func.addEntryBlock();
    // Convert the composite embedding_lookup function body to a
    // TFLite fused embedding_lookup op.
    ConvertEmbeddedLookupFunc convert_embedded_lookup(func);
    if (failed(convert_embedded_lookup.VerifySignature())) {
      return signalPassFailure();
    }
    convert_embedded_lookup.RewriteFunc();
  } else if (attr.getValue() == mlir::TFL::kKerasLstm) {
    func.eraseBody();
    func.addEntryBlock();
    OpBuilder builder(func.getBody());
    if (failed(ConvertKerasLSTMLayer(func, &builder))) {
      return signalPassFailure();
    }
  } else if (.....) /* Other fusions can plug in here */
}
Here is a code snippet showing how this composite operation is mapped to a fused operation in LiteRT, leveraging the function as a conversion interface.
void RewriteFunc() {
  Value lookup = func_.getArgument(1);
  Value value = func_.getArgument(0);
  auto output_type = func_.getType().getResult(0);
  OpBuilder builder(func_.getBody());
  auto op = builder.create<mlir::TFL::EmbeddingLookupOp>(
      func_.getLoc(), output_type, lookup, value);
  builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResult());
}