This guide explains how to run a LiteRT (short for Lite Runtime) model on-device to make predictions from input data. Inference is performed by the LiteRT interpreter, which uses a static graph ordering and a custom (less dynamic) memory allocator to minimize load, initialization, and execution latency.
LiteRT inference typically involves the following steps:
- Loading a model: Load the .tflite model, which contains the model's execution graph, into memory.
- Transforming data: Transform input data into the expected format and dimensions. Raw input data generally doesn't match the input format expected by the model. For example, you might need to resize an image or change the image format to be compatible with the model.
- Running inference: Execute the LiteRT model to make predictions. This step uses the LiteRT API and involves a few sub-steps, such as building the interpreter and allocating tensors.
- Interpreting output: Interpret the output tensors in a way that's meaningful for your application. For example, a model might return only a list of probabilities. It's up to you to map the probabilities to relevant categories and format the output.
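The data-transformation and output-interpretation steps are ordinary application code that runs around the interpreter. The Python sketch below shows one possible version of each, assuming a model that takes pixel values scaled to [0.0, 1.0] and returns one probability per class. `normalize_pixels`, `top_k`, and the example labels are hypothetical; check your own model's documentation for its actual preprocessing and output layout.

```python
from typing import List, Sequence, Tuple


def normalize_pixels(pixels: Sequence[int]) -> List[float]:
    """Scale 8-bit pixel values (0-255) to floats in [0.0, 1.0].

    Assumption: the model expects float inputs scaled to [0, 1].
    """
    return [p / 255.0 for p in pixels]


def top_k(
    probabilities: Sequence[float], labels: Sequence[str], k: int = 3
) -> List[Tuple[str, float]]:
    """Pair each output probability with its label and return the k most likely."""
    if len(probabilities) != len(labels):
        raise ValueError("expected one label per output probability")
    ranked = sorted(zip(labels, probabilities), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    print(normalize_pixels([0, 128, 255]))
    # Hypothetical model output and label list, for illustration only.
    print(top_k([0.1, 0.7, 0.2], ["cat", "dog", "bird"], k=2))  # [('dog', 0.7), ('bird', 0.2)]
```

Keeping this logic outside the interpreter calls makes it easy to test without loading a model.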
This guide describes how to access the LiteRT interpreter and perform inference using Java, Swift, Objective-C, C, C++, and Python.
Supported platforms
LiteRT inference APIs are provided for most common mobile and embedded platforms, such as Android, iOS, and Linux, in multiple programming languages.
In most cases, the API design reflects a preference for performance over ease of use. LiteRT is designed for fast inference on small devices, so the APIs avoid unnecessary copies at the expense of convenience.
Across all libraries, the LiteRT API lets you load models, feed inputs, and retrieve inference outputs.
Android Platform
On Android, LiteRT inference can be performed using either Java or C++ APIs. The Java APIs provide convenience and can be used directly within your Android Activity classes. The C++ APIs offer more flexibility and speed, but may require writing JNI wrappers to move data between Java and C++ layers.
See the C++ and Java sections for more information, or follow the Android quickstart.
iOS Platform
On iOS, LiteRT is available through the Swift and Objective-C iOS libraries. You can also use the C API directly in Objective-C code.
See the Swift, Objective-C, and C API sections, or follow the iOS quickstart.
Linux Platform
On Linux platforms, you can run inferences using LiteRT APIs available in C++.
Load and run a model
Loading and running a LiteRT model involves the following steps:
- Loading the model into memory.
- Building an Interpreter based on the loaded model.
- Setting input tensor values.
- Running inference.
- Reading output tensor values.
Android (Java)
The Java API for running inferences with LiteRT is primarily designed for use with Android, so it's available as an Android library dependency: com.google.ai.edge.litert.
In Java, you'll use the Interpreter class to load a model and drive model inference. In many cases, this may be the only API you need.
You can initialize an Interpreter using a FlatBuffers (.tflite) file:
public Interpreter(@NotNull File modelFile);
Or with a MappedByteBuffer:
public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer);
In both cases, you must provide a valid LiteRT model or the API throws IllegalArgumentException. If you use MappedByteBuffer to initialize an Interpreter, it must remain unchanged for the whole lifetime of the Interpreter.
The preferred way to run inference on a model is to use signatures, which are available for models converted starting with TensorFlow 2.5:
try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) {
  Map<String, Object> inputs = new HashMap<>();
  inputs.put("input_1", input1);
  inputs.put("input_2", input2);
  Map<String, Object> outputs = new HashMap<>();
  outputs.put("output_1", output1);
  interpreter.runSignature(inputs, outputs, "mySignature");
}
The runSignature method takes three arguments:
- Inputs: a map from each input name in the signature to an input object.
- Outputs: a map from each output name in the signature to the output data.
- Signature name (optional): the name of the signature to run. It can be left empty if the model has a single signature.
If the model doesn't have defined signatures, call Interpreter.run() instead. For example:
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
  interpreter.run(input, output);
}
The run() method takes only one input and fills only one output. If your model has multiple inputs or multiple outputs, use this instead:
interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
In this case, each entry in inputs corresponds to an input tensor, and map_of_indices_to_outputs maps indices of output tensors to the corresponding output data.
In both cases, the tensor indices should correspond to the values you gave to the LiteRT Converter when you created the model. Be aware that the order of tensors in inputs must match the order given to the LiteRT Converter.
The Interpreter class also provides convenient functions for you to get the index of any model input or output using an operation name:
public int getInputIndex(String opName);
public int getOutputIndex(String opName);
If opName is not a valid operation in the model, it throws an IllegalArgumentException.
Also note that Interpreter owns resources. To avoid memory leaks, release them after use by calling:
interpreter.close();
The try-with-resources blocks in the examples above close the interpreter automatically.
For an example project with Java, see the Android object detection example app.
Supported data types
To use LiteRT, the data types of the input and output tensors must be one of the following primitive types:
float
int
long
byte
String types are also supported, but they are encoded differently than the primitive types. In particular, the shape of a string tensor dictates the number and arrangement of strings in the tensor, with each element itself being a variable-length string. In this sense, the (byte) size of the tensor cannot be computed from the shape and type alone, and consequently strings cannot be provided as a single, flat ByteBuffer argument.
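To see why a string tensor's byte size can't be derived from its shape, the Python sketch below packs a list of strings using the layout that TensorFlow Lite's string utilities are understood to use: an int32 count, then count + 1 int32 offsets measured from the start of the buffer, then the concatenated bytes. Treat the exact layout and the little-endian byte order as assumptions; `encode_string_tensor` is a hypothetical helper, not part of the LiteRT API.

```python
import struct
from typing import List


def encode_string_tensor(strings: List[bytes]) -> bytes:
    """Pack strings as: int32 count, (count + 1) int32 offsets, then the bytes.

    Assumption: little-endian int32 fields; offsets are measured from the start
    of the buffer, and the final offset marks the end of the last string.
    """
    count = len(strings)
    header_size = 4 * (count + 2)  # one count field + (count + 1) offsets
    offsets = [header_size]
    for s in strings:
        offsets.append(offsets[-1] + len(s))
    header = struct.pack(f"<{count + 2}i", count, *offsets)
    return header + b"".join(strings)


if __name__ == "__main__":
    # Two tensors with the same shape ([2]) but different byte sizes.
    print(len(encode_string_tensor([b"hi", b"there"])))  # 23
    print(len(encode_string_tensor([b"a", b"b"])))  # 18
```

Both inputs have shape [2], yet the encoded sizes differ, which is why a string tensor can't be supplied as a fixed-size flat buffer.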
If other data types, including boxed types like Integer and Float, are used, an IllegalArgumentException will be thrown.
Inputs
Each input should be an array or multi-dimensional array of the supported primitive types, or a raw ByteBuffer of the appropriate size. If the input is an array or multi-dimensional array, the associated input tensor will be implicitly resized to the array's dimensions at inference time. If the input is a ByteBuffer, the caller should first manually resize the associated input tensor (via Interpreter.resizeInput()) before running inference.
When using ByteBuffer, prefer direct byte buffers, as this allows the Interpreter to avoid unnecessary copies. If the ByteBuffer is a direct byte buffer, its order must be ByteOrder.nativeOrder(). After it is used for a model inference, it must remain unchanged until the model inference is finished.
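The byte-order requirement matters because multi-byte values such as float32 are laid out differently on big-endian and little-endian hosts. As a language-neutral illustration (not LiteRT code), the Python sketch below packs float32 values in the host's native byte order and decodes them again; in Java, the corresponding setup is allocating the buffer with ByteBuffer.allocateDirect() and calling order(ByteOrder.nativeOrder()) on it. The helper names are hypothetical.

```python
import array
import struct
import sys
from typing import List, Sequence


def pack_floats_native(values: Sequence[float]) -> bytes:
    """Pack values as 32-bit floats in the host's native byte order.

    array('f') stores C floats in native order, the Python analogue of a
    ByteBuffer whose order is ByteOrder.nativeOrder().
    """
    return array.array("f", values).tobytes()


def unpack_floats_native(data: bytes) -> List[float]:
    """Decode 32-bit floats; '=' selects native byte order with standard sizes."""
    count = len(data) // 4
    return list(struct.unpack(f"={count}f", data))


if __name__ == "__main__":
    print(sys.byteorder)  # Host byte order, e.g. 'little' on x86-64 and Android/iOS ARM devices.
    packed = pack_floats_native([1.0, 2.5, -3.0])
    print(len(packed), unpack_floats_native(packed))
```

Writing and reading with the same (native) order is what lets the interpreter consume the buffer in place without byte swapping.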
Outputs
Each output should be an array or multi-dimensional array of the supported primitive types, or a ByteBuffer of the appropriate size. Note that some models have dynamic outputs, where the shape of output tensors can vary depending on the input. There's no straightforward way of handling this with the existing Java inference API, but planned extensions will make this possible.
iOS (Swift)
The Swift API is available in the TensorFlowLiteSwift Pod from CocoaPods.
First, import the TensorFlowLite module:
import TensorFlowLite

// Get the model path.
guard
  let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite")
else {
  // Error handling...
  fatalError("model.tflite not found in the app bundle")
}

do {
  // Initialize an interpreter with the model.
  let interpreter = try Interpreter(modelPath: modelPath)

  // Allocate memory for the model's input `Tensor`s.
  try interpreter.allocateTensors()

  let inputData = Data()  // Placeholder: replace with your prepared input bytes.
  // input data preparation...

  // Copy the input data to the input `Tensor`.
  try interpreter.copy(inputData, toInputAt: 0)

  // Run inference by invoking the `Interpreter`.
  try interpreter.invoke()

  // Get the output `Tensor`
  let outputTensor = try interpreter.output(at: 0)

  // Copy output to a Float32 buffer to process the inference results.
  let outputSize = outputTensor.shape.dimensions.reduce(1, { x, y in x * y })
  let outputData =
    UnsafeMutableBufferPointer<Float32>.allocate(capacity: outputSize)
  defer { outputData.deallocate() }
  outputTensor.data.copyBytes(to: outputData)
} catch {
  // Error handling... (`error` is bound implicitly in this block)
}
iOS (Objective-C)
The Objective-C API is available in the LiteRTObjC Pod from CocoaPods.
First, import the TensorFlowLiteObjC module:
@import TensorFlowLite;
NSString *modelPath = [[NSBundle mainBundle] pathForResource:@"model"
ofType:@"tflite"];
NSError *error;
// Initialize an interpreter with the model.
TFLInterpreter *interpreter = [[TFLInterpreter alloc] initWithModelPath:modelPath
error:&error];
if (error != nil) { /* Error handling... */ }
// Allocate memory for the model's input `TFLTensor`s.
[interpreter allocateTensorsWithError:&error];
if (error != nil) { /* Error handling... */ }
NSMutableData *inputData; // Should be initialized
// input data preparation...
// Get the input `TFLTensor`
TFLTensor *inputTensor = [interpreter inputTensorAtIndex:0 error:&error];
if (error != nil) { /* Error handling... */ }
// Copy the input data to the input `TFLTensor`.
[inputTensor copyData:inputData error:&error];
if (error != nil) { /* Error handling... */ }
// Run inference by invoking the `TFLInterpreter`.
[interpreter invokeWithError:&error];
if (error != nil) { /* Error handling... */ }
// Get the output `TFLTensor`
TFLTensor *outputTensor = [interpreter outputTensorAtIndex:0 error:&error];
if (error != nil) { /* Error handling... */ }
// Copy output to `NSData` to process the inference results.
NSData *outputData = [outputTensor dataWithError:&error];
if (error != nil) { /* Error handling... */ }
C API in Objective-C code
The Objective-C API does not support delegates. To use delegates with Objective-C code, you need to call the underlying C API directly.
#include "tensorflow/lite/c/c_api.h"
TfLiteModel* model = TfLiteModelCreateFromFile([modelPath UTF8String]);
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
// Create the interpreter.
TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
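// Allocate tensors and populate the input tensor data.
// `input` and `output` are assumed to be std::vector<float> buffers sized to
// match the model's tensors, which requires compiling as Objective-C++.
TfLiteInterpreterAllocateTensors(interpreter);
TfLiteTensor* input_tensor =
    TfLiteInterpreterGetInputTensor(interpreter, 0);
TfLiteTensorCopyFromBuffer(input_tensor, input.data(),
                           input.size() * sizeof(float));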
// Execute inference.
TfLiteInterpreterInvoke(interpreter);
// Extract the output tensor data.
const TfLiteTensor* output_tensor =
TfLiteInterpreterGetOutputTensor(interpreter, 0);
TfLiteTensorCopyToBuffer(output_tensor, output.data(),
output.size() * sizeof(float));
// Dispose of the model and interpreter objects.
TfLiteInterpreterDelete(interpreter);
TfLiteInterpreterOptionsDelete(options);
TfLiteModelDelete(model);
C++
The C++ API for running inference with LiteRT is compatible with Android, iOS, and Linux platforms. On iOS, the C++ API is only available when building with Bazel.
In C++, the model is stored in the FlatBufferModel class. It encapsulates a LiteRT model, and you can build it in a couple of different ways, depending on where the model is stored:
class FlatBufferModel {
 public:
  // Build a model based on a file. Return a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromFile(
      const char* filename,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Build a model based on a pre-loaded flatbuffer. The caller retains
  // ownership of the buffer and should keep it alive until the returned object
  // is destroyed. Return a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
      const char* buffer,
      size_t buffer_size,
      ErrorReporter* error_reporter = DefaultErrorReporter());
};
Now that you have the model as a FlatBufferModel object, you can execute it with an Interpreter. A single FlatBufferModel can be used simultaneously by more than one Interpreter.
The important parts of the Interpreter API are shown in the code snippet below. Note that:
- Tensors are represented by integers, in order to avoid string comparisons (and any fixed dependency on string libraries).
- An interpreter must not be accessed from concurrent threads.
- Memory allocation for input and output tensors must be triggered by calling AllocateTensors() right after resizing tensors.
The simplest usage of LiteRT with C++ looks like this:
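#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Load the model. `filename` is the path to your .tflite file.
std::unique_ptr<tflite::FlatBufferModel> model =
    tflite::FlatBufferModel::BuildFromFile(filename);
// BuildFromFile returns nullptr on failure; check before continuing.

// Build the interpreter
tflite::ops::builtin::BuiltinOpResolver resolver;
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);

// Resize input tensors, if needed.
interpreter->AllocateTensors();  // Returns kTfLiteOk on success.

float* input = interpreter->typed_input_tensor<float>(0);
// Fill `input`.

interpreter->Invoke();

float* output = interpreter->typed_output_tensor<float>(0);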
For more example code, see minimal.cc and label_image.cc.
Python
The Python API for running inferences uses the Interpreter class to load a model and run inferences.
Install the LiteRT package:
$ python3 -m pip install ai-edge-litert
Import the LiteRT Interpreter and create an instance from your model file:
from ai_edge_litert.interpreter import Interpreter
interpreter = Interpreter(model_path="model.tflite")  # Path to your .tflite model
The following example shows how to use the Python interpreter to load a FlatBuffers (.tflite) file and run inference through a signature. This approach is recommended if you're converting from a SavedModel with a defined SignatureDef:
import tensorflow as tf


class TestModel(tf.Module):
    def __init__(self):
        super(TestModel, self).__init__()

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 10], dtype=tf.float32)])
    def add(self, x):
        '''
        Simple method that accepts single input 'x' and returns 'x' + 4.
        '''
        # Name the output 'result' for convenience.
        return {'result' : x + 4}


SAVED_MODEL_PATH = 'content/saved_models/test_variable'
TFLITE_FILE_PATH = 'content/test_variable.tflite'

# Save the model
module = TestModel()
# You can omit the signatures argument and a default signature name will be
# created with name 'serving_default'.
tf.saved_model.save(
    module, SAVED_MODEL_PATH,
    signatures={'my_signature':module.add.get_concrete_function()})

# Convert the model using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
tflite_model = converter.convert()
with open(TFLITE_FILE_PATH, 'wb') as f:
    f.write(tflite_model)

# Load the LiteRT model in LiteRT Interpreter
from ai_edge_litert.interpreter import Interpreter
interpreter = Interpreter(TFLITE_FILE_PATH)
# There is only 1 signature defined in the model,
# so it will return it by default.
# If there are multiple signatures then we can pass the name.
my_signature = interpreter.get_signature_runner()

# my_signature is callable with input as arguments.
output = my_signature(x=tf.constant([1.0], shape=(1,10), dtype=tf.float32))
# 'output' is dictionary with all outputs from the inference.
# In this case we have single output 'result'.
print(output['result'])
The following example runs inference with random input data on a model that doesn't have SignatureDefs defined:
import numpy as np
import tensorflow as tf
# Load the LiteRT model and allocate tensors.
from ai_edge_litert.interpreter import Interpreter
interpreter = Interpreter(TFLITE_FILE_PATH)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test the model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
As an alternative to loading the model as a pre-converted .tflite file, you can combine your code with the LiteRT Converter, which converts your Keras model into the LiteRT format so that you can then run inference:
import numpy as np
import tensorflow as tf
img = tf.keras.Input(shape=(64, 64, 3), name="img")
const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
val = img + const
out = tf.identity(val, name="out")
# Convert to LiteRT format
converter = tf.lite.TFLiteConverter.from_keras_model(tf.keras.models.Model(inputs=[img], outputs=[out]))
tflite_model = converter.convert()
# Load the LiteRT model and allocate tensors.
from ai_edge_litert.interpreter import Interpreter
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
# Continue to get tensors and so forth, as shown above...
For more Python sample code, see label_image.py.
Run inference with dynamic shape model
If you want to run a model with a dynamic input shape, resize the input shape before running inference. Otherwise, the None shape in TensorFlow models will be replaced by a placeholder of 1 in LiteRT models.
The following examples show how to resize the input shape before running inference in different languages. All the examples assume that the input shape is defined as [1/None, 10] and needs to be resized to [3, 10].
C++ example:
// Resize input tensors before allocating tensors
interpreter->ResizeInputTensor(/*tensor_index=*/0, std::vector<int>{3,10});
interpreter->AllocateTensors();
Python example:
# Load the LiteRT model in LiteRT Interpreter
from ai_edge_litert.interpreter import Interpreter
interpreter = Interpreter(model_path=TFLITE_FILE_PATH)
# Resize input shape for dynamic shape model and allocate tensor
interpreter.resize_tensor_input(interpreter.get_input_details()[0]['index'], [3, 10])
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()