Driver class to drive model inference with TensorFlow Lite.
Note: If you don't need access to any of the "experimental" API features below, prefer to use InterpreterApi and InterpreterFactory rather than using Interpreter directly.
An Interpreter encapsulates a pre-trained TensorFlow Lite model, in which operations are executed for model inference.
For example, if a model takes only one input and returns only one output:
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.run(input, output);
}
If a model takes multiple inputs or outputs:
Object[] inputs = {input0, input1, ...};
Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
ByteBuffer ith_output = ByteBuffer.allocateDirect(3 * 2 * 4 * 4); // Float tensor, shape 3x2x4: 24 floats of 4 bytes each.
ith_output.order(ByteOrder.nativeOrder());
map_of_indices_to_outputs.put(i, ith_output);
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
}
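A float32 tensor of shape 3x2x4 holds 24 elements and therefore needs 24 * 4 = 96 bytes. The following JDK-only sketch allocates a direct, native-order ByteBuffer of that size for an output map and reads it back as floats. `OutputBuffers`, `floatOutput`, and `readFloats` are hypothetical helper names, not TensorFlow Lite API, and the `putFloat` call merely stands in for the interpreter writing its result.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical helpers for float32 output buffers; not part of the TensorFlow Lite API. */
final class OutputBuffers {
    /** Allocates a direct, native-order buffer sized for a float32 tensor of the given shape. */
    static ByteBuffer floatOutput(int... shape) {
        int elements = 1;
        for (int dim : shape) {
            elements = Math.multiplyExact(elements, dim);
        }
        ByteBuffer buffer = ByteBuffer.allocateDirect(Math.multiplyExact(elements, Float.BYTES));
        buffer.order(ByteOrder.nativeOrder());
        return buffer;
    }

    /** Copies the entire buffer out as floats, regardless of its current position. */
    static float[] readFloats(ByteBuffer buffer) {
        // duplicate() resets the byte order to big-endian, so restore the original order.
        ByteBuffer view = buffer.duplicate().order(buffer.order());
        view.clear(); // position = 0, limit = capacity
        FloatBuffer floats = view.asFloatBuffer();
        float[] values = new float[floats.remaining()];
        floats.get(values);
        return values;
    }

    public static void main(String[] args) {
        Map<Integer, Object> outputs = new HashMap<>();
        ByteBuffer output0 = floatOutput(3, 2, 4); // 24 floats -> 96 bytes
        outputs.put(0, output0);
        // interpreter.runForMultipleInputsOutputs(inputs, outputs) would fill output0 here;
        // this absolute write only stands in for that.
        output0.putFloat(0, 1.5f);
        float[] values = readFloats(output0);
        System.out.println(output0.capacity() + " bytes, first value " + values[0]);
    }
}
```

Reading through a duplicate leaves the original buffer's position untouched, so the same buffer can be reused for the next inference call.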
If a model takes or produces string tensors:
String[] input = {"foo", "bar"}; // Input tensor shape is [2].
String[][] output = new String[3][2]; // Output tensor shape is [3, 2].
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.run(input, output);
}
Note that there's a distinction between shape [] and shape [1]. For scalar string tensor outputs:
String[] input = {"foo"}; // Input tensor shape is [1].
ByteBuffer outputBuffer = ByteBuffer.allocate(OUTPUT_BYTES_SIZE); // Output tensor shape is [].
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.run(input, outputBuffer);
}
byte[] outputBytes = new byte[outputBuffer.remaining()];
outputBuffer.get(outputBytes);
// Below, the `charset` can be StandardCharsets.UTF_8.
String output = new String(outputBytes, charset);
The order of inputs and outputs, as well as the default input shapes, is determined when the TensorFlow model is converted to a TensorFlow Lite model with Toco.
When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) are implicitly resized to match each array's shape. When inputs are provided as Buffer types, no implicit resizing is done; the caller must either ensure that the Buffer byte size matches that of the corresponding tensor or first resize the tensor via resizeInput(int, int[]). Tensor shape and type information can be obtained via the Tensor class, available via getInputTensor(int) and getOutputTensor(int).
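The two rules above can be illustrated with plain JDK code. `shapeOf` infers the shape that a rectangular Java array implies (the shape an array input would be resized to), and `requiredBytes`/`fitsExactly` compute the byte size a Buffer input must have. All three are hypothetical helpers, not TensorFlow Lite API; comparing the buffer's capacity, rather than its remaining bytes, is an assumption of this sketch.

```java
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helpers contrasting array-driven shapes with fixed-size buffers. */
final class TensorShapes {
    /** Infers the shape of a rectangular Java array by following the first element at each level. */
    static int[] shapeOf(Object array) {
        List<Integer> dims = new ArrayList<>();
        Object current = array;
        while (current != null && current.getClass().isArray()) {
            int length = Array.getLength(current);
            dims.add(length);
            current = length > 0 ? Array.get(current, 0) : null;
        }
        return dims.stream().mapToInt(Integer::intValue).toArray();
    }

    /** Bytes needed by a dense tensor with the given shape and element size. */
    static long requiredBytes(int[] shape, int bytesPerElement) {
        long bytes = bytesPerElement;
        for (int dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException("Unknown dimension in " + Arrays.toString(shape));
            }
            bytes = Math.multiplyExact(bytes, (long) dim);
        }
        return bytes;
    }

    /** True if the buffer's capacity equals the tensor's byte size (assumed comparison). */
    static boolean fitsExactly(ByteBuffer buffer, int[] shape, int bytesPerElement) {
        return buffer.capacity() == requiredBytes(shape, bytesPerElement);
    }

    public static void main(String[] args) {
        float[][][][] image = new float[1][224][224][3];
        int[] shape = shapeOf(image);
        System.out.println(Arrays.toString(shape));            // [1, 224, 224, 3]
        System.out.println(requiredBytes(shape, Float.BYTES)); // 602112
        ByteBuffer tooSmall = ByteBuffer.allocateDirect(1024);
        System.out.println(fitsExactly(tooSmall, shape, Float.BYTES)); // false
    }
}
```

A buffer that fails such a check would need either a different size or a prior call to resizeInput(int, int[]) with the shape it actually holds.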
WARNING: Interpreter instances are not thread-safe. An Interpreter owns resources that must be explicitly freed by invoking close().
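Because an Interpreter is not thread-safe, one option is to give each thread its own instance; another is to serialize access to a shared one and close it exactly once. A minimal JDK-only sketch of the second option follows, using a hypothetical `UnsafeRunner` interface as a stand-in for the interpreter so that it runs without the TensorFlow Lite library; `SerializedRunner` is likewise a hypothetical name.

```java
/** Hypothetical stand-in for a non-thread-safe resource with run() and close(). */
interface UnsafeRunner extends AutoCloseable {
    void run(float[] input, float[] output);

    @Override
    void close(); // narrowed: no checked exception
}

/** Serializes every call to one underlying runner and makes close() idempotent. */
final class SerializedRunner implements AutoCloseable {
    private final UnsafeRunner delegate;
    private boolean closed;

    SerializedRunner(UnsafeRunner delegate) {
        this.delegate = delegate;
    }

    /** Only one thread at a time can be inside the delegate's run(). */
    synchronized void run(float[] input, float[] output) {
        if (closed) {
            throw new IllegalStateException("runner already closed");
        }
        delegate.run(input, output);
    }

    /** Frees the delegate's resources once; later calls are no-ops. */
    @Override
    public synchronized void close() {
        if (!closed) {
            closed = true;
            delegate.close();
        }
    }
}
```

Serializing calls keeps a single model in memory at the cost of throughput; a per-thread instance trades memory for parallelism. Either way, close() still has to be called, for example from a try-with-resources block.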
The TFLite library is built against NDK API 19. It may work for Android API levels below 19, but this is not guaranteed.
Nested Classes
class | Interpreter.Options | An options class for controlling runtime interpreter behavior. |
Public Constructors
Interpreter(File modelFile, Interpreter.Options options) | Initializes an Interpreter and specifies options for customizing interpreter behavior. |
Interpreter(ByteBuffer byteBuffer, Interpreter.Options options) | Initializes an Interpreter with a ByteBuffer of a model file and a set of custom Interpreter.Options. |
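The ByteBuffer constructor takes a model that is already in memory. One common way to obtain such a buffer is to memory-map the model file; TensorFlow Lite's documentation describes the expected buffer as either a MappedByteBuffer of the model file or a direct, native-order ByteBuffer holding its bytes, and says it should not be modified after the Interpreter is constructed. A minimal JDK-only sketch, where `ModelLoader` and `mapModel` are hypothetical names:

```java
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

/** Hypothetical helper that memory-maps a model file; not part of the TensorFlow Lite API. */
final class ModelLoader {
    /**
     * Maps the whole file read-only. The mapping stays valid after the channel
     * is closed, so try-with-resources can release the file handle immediately.
     */
    static MappedByteBuffer mapModel(Path modelPath) throws IOException {
        try (FileChannel channel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
            return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
        }
    }
}
```

With the TensorFlow Lite library on the classpath, the result could then be passed as `new Interpreter(ModelLoader.mapModel(modelPath), new Interpreter.Options())`, where `modelPath` is a hypothetical java.nio.file.Path. Mapping avoids copying the model onto the Java heap, and the read-only mode guards against accidental modification.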
Public Methods
void | allocateTensors() | Explicitly updates allocations for all tensors, if necessary. |
void | close() | Releases resources associated with the InterpreterApi instance. |
int | |
Tensor | getInputTensor(int inputIndex) | Gets the Tensor associated with the provided input index. |
int | getInputTensorCount() | Gets the number of input tensors. |
Tensor | getInputTensorFromSignature(String inputName, String signatureKey) | Gets the Tensor associated with the provided input name and signature method name. |
Long | getLastNativeInferenceDurationNanoseconds() | Returns native inference timing. |
int | |
Tensor | getOutputTensor(int outputIndex) | Gets the Tensor associated with the provided output index. |
int | getOutputTensorCount() | Gets the number of output tensors. |
Tensor | getOutputTensorFromSignature(String outputName, String signatureKey) | Gets the Tensor associated with the provided output name in a specific signature method. |
String[] | getSignatureInputs(String signatureKey) | Gets the list of SignatureDef inputs for method signatureKey. |
String[] | getSignatureKeys() | Gets the list of SignatureDef exported method names available in the model. |
String[] |