使用 TensorFlow Lite 驱动模型推断的驱动程序类。
注意:如果您不需要使用下面的任何“实验性”API 功能,可优先使用 ExplainerApi 和 ExplainerFactory,而不是直接使用 Explainer。
Interpreter
封装了预训练的 TensorFlow Lite 模型,在该模型中执行操作以进行模型推断。
例如,如果模型仅接受一个输入并且仅返回一个输出:
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.run(input, output);
}
如果模型接受多个输入或输出:
Object[] inputs = {input0, input1, ...};
Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
FloatBuffer ith_output = FloatBuffer.allocateDirect(3 * 2 * 4); // Float tensor, shape 3x2x4.
ith_output.order(ByteOrder.nativeOrder());
map_of_indices_to_outputs.put(i, ith_output);
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
}
如果模型获取或生成字符串张量:
String[] input = {"foo", "bar"}; // Input tensor shape is [2].
String[][] output = new String[3][2]; // Output tensor shape is [3, 2].
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.runForMultipleInputsOutputs(input, output);
}
请注意,形状 [] 和形状 [1] 是有区别的。对于标量字符串张量输出:
String[] input = {"foo"}; // Input tensor shape is [1].
ByteBuffer outputBuffer = ByteBuffer.allocate(OUTPUT_BYTES_SIZE); // Output tensor shape is [].
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.runForMultipleInputsOutputs(input, outputBuffer);
}
byte[] outputBytes = new byte[outputBuffer.remaining()];
outputBuffer.get(outputBytes);
// Below, the `charset` can be StandardCharsets.UTF_8.
String output = new String(outputBytes, charset);
使用 Toco 将 TensorFlow 模型转换为 TensorFlowLite 模型时,系统会确定输入和输出的顺序,以及输入的默认形状。
以(多维)数组形式提供输入时,相应的输入张量将根据该数组的形状隐式调整大小。以 Buffer
类型提供输入时,不会隐式调整大小;调用方必须确保 Buffer
字节大小与相应张量的大小一致,或者首先通过 resizeInput(int, int[])
调整张量的大小。您可以通过 Tensor
类(通过 getInputTensor(int)
和 getOutputTensor(int)
获得)获取张量形状和类型信息。
警告:Interpreter
实例不是线程安全的。Interpreter
拥有的资源必须通过调用 close()
来明确释放。
TFLite 库是基于 NDK API 19 构建的。它可能适用于低于 19 的 Android API 级别,但不能保证。
嵌套类
类别 | Interpreter.Options | 用于控制运行时解释器行为的选项类。 |
公共构造函数
公共方法
void |
allocateTensors()
根据需要显式更新所有张量的分配。
|
void |
close()
释放与
InterpreterApi 实例关联的资源。 |
整型 | |
张量 |
getInputTensor(int inputIndex)
获取与提供的输入索引相关联的张量。
|
整型 |
getInputTensorCount()
获取输入张量的数量。
|
张量 | |
长整型 |
getLastNativeInferenceDurationNanoseconds()
返回原生推断时间。
|
整型 | |
张量 |
getOutputTensor(int outputIndex)
获取与提供的输出索引相关联的张量。
|
整型 |
getOutputTensorCount()
获取输出张量的数量。
|
张量 | |
字符串 [] | |
字符串 [] |
getSignatureKeys()
获取模型中可用的 SignatureDef 导出方法名称的列表。
|
字符串 [] | |
void |
resetVariableTensors()
高级:将所有变量张量重置为默认值。
|
void |
resizeInput(int idx, int[] dims, boolean strict)
将原生模型的第 idx 输入的大小调整为指定的维度。
|
void |
resizeInput(int idx, int[] dims)
将原生模型的第 idx 输入的大小调整为指定的维度。
|
void | |
void | |
void |
runSignature(Map<String, Object> 输入, Map<String, Object> 输出)
与
runSignature(Map, Map, String) 相同,但不需要传递 signatureKey,假设模型有一个 SignatureDef。 |
void | |
void |
setCancelled(布尔值已取消)
高级:在调用
run(Object, Object) 时中断推断。 |
继承的方法
公共构造函数
public Explainer (File modelFile)
初始化 Interpreter
。
参数
modelFile | 预训练 TF Lite 模型的文件。 |
---|
抛出
IllegalArgumentException | 如果 modelFile 未对有效的 TensorFlow Lite 模型进行编码,则会发生该错误。
|
---|
public Explainer (File modelFile、Interpreter.Options options)
初始化 Interpreter
并指定用于自定义解释器行为的选项。
参数
modelFile | 这是一个预训练 TF Lite 模型的文件, |
---|---|
选项 | 一组用于自定义解释器行为的选项 |
抛出
IllegalArgumentException | 如果 modelFile 未对有效的 TensorFlow Lite 模型进行编码,则会发生该错误。
|
---|
public Explainer (ByteBuffer byteBuffer)
使用模型文件的 ByteBuffer
初始化 Interpreter
。
构建 Interpreter
之后,不应修改 ByteBuffer。ByteBuffer
可以是对模型文件进行内存映射的 MappedByteBuffer
,也可以是包含模型字节内容的 nativeOrder() 的直接 ByteBuffer
。
参数
byteBuffer |
---|
抛出
IllegalArgumentException | 如果 byteBuffer 不是 MappedByteBuffer ,也不是 nativeOrder 的直接 ByteBuffer 。
|
---|
public Explainer (ByteBuffer byteBuffer、Interpreter.Options 选项)
使用模型文件的 ByteBuffer
和一组自定义 Interpreter.Options
初始化 Interpreter
。
构建 Interpreter
之后,不应修改 ByteBuffer
。ByteBuffer
可以是对模型文件进行内存映射的 MappedByteBuffer
,也可以是包含模型字节内容的 nativeOrder() 的直接 ByteBuffer
。
参数
byteBuffer | |
---|---|
选项 |
抛出
IllegalArgumentException | 如果 byteBuffer 不是 MappedByteBuffer ,也不是 nativeOrder 的直接 ByteBuffer 。
|
---|
公共方法
public void allocateTensors ()
根据需要显式更新所有张量的分配。
这将使用给定的输入张量形状传播依赖张量的形状和内存分配。
注意:此调用 *完全可选*。如果调整了任何输入张量的大小,则在执行期间将自动分配张量。在执行图之前,此调用最有助于确定任何输出张量的形状,例如,
interpreter.resizeInput(0, new int[]{1, 4, 4, 3}));
interpreter.allocateTensors();
FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements());
// Populate inputs...
FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
interpreter.run(input, output)
// Process outputs...
注意:某些图具有动态形状的输出,在这种情况下,在执行推断之前,输出形状可能不会完全传播。
public void close ()
释放与 InterpreterApi
实例关联的资源。
public int getInputTensorCount ()
获取输入张量的数量。
public Tensor getInputTensorFromSignature (String inputName, String signatureKey)
获取与提供的输入名称和签名方法名称相关联的 Tensor。
警告:这是一个实验性 API,可能会发生变化。
参数
inputName | 在签名中输入名称。 |
---|---|
signatureKey | 标识 SignatureDef 的签名密钥,如果模型有一个签名,可以为 null。 |
抛出
IllegalArgumentException | 如果 inputName 或 signatureKey 为 null 或为空,或提供的名称无效。
|
---|
public Tensor getOutputTensor (int outputIndex)
获取与所提供的输出索引相关联的张量。
注意:在执行推理之前,输出张量详细信息(例如形状)可能不会被完全填充。如果您在运行推理 *之前* 需要更新的详细信息(例如,在调整输入张量的大小之后,这可能会使输出张量形状失效),请使用 allocateTensors()
显式触发分配和形状传播。请注意,对于输出形状依赖于输入 *值*的图表,只有在运行推理之前,可能无法完全确定输出形状。
参数
outputIndex |
---|
public int getOutputTensorCount ()
获取输出张量的数量。
public Tensor getOutputTensorFromSignature (String outputName, String signatureKey)
获取与特定签名方法中提供的输出名称相关联的张量。
注意:在执行推理之前,输出张量详细信息(例如形状)可能不会被完全填充。如果您在运行推理 *之前* 需要更新的详细信息(例如,在调整输入张量的大小之后,这可能会使输出张量形状失效),请使用 allocateTensors()
显式触发分配和形状传播。请注意,对于输出形状依赖于输入 *值*的图表,只有在运行推理之前,可能无法完全确定输出形状。
警告:这是一个实验性 API,可能会发生变化。
参数
outputName | 签名中的输出名称。 |
---|---|
signatureKey | 标识 SignatureDef 的签名密钥,如果模型有一个签名,可以为 null。 |
抛出
IllegalArgumentException | 如果 outputName 或 signatureKey 为 null 或为空,或者提供的名称无效。
|
---|
public String[] getSignatureInputs (String signatureKey)
获取方法 signatureKey
的 SignatureDefs 输入列表。
警告:这是一个实验性 API,可能会发生变化。
参数
signatureKey |
---|
public String[] getSignatureOutputs (String signatureKey)
获取方法 signatureKey
的 SignatureDefs 输出列表。
警告:这是一个实验性 API,可能会发生变化。
参数
signatureKey |
---|
public void resetVariableTensors ()
高级:将所有变量张量重置为默认值。
如果可变张量没有关联的缓冲区,则会重置为零。
警告:这是一个实验性 API,可能会发生变化。
public void resizeInput (int idx, int[] dims, boolean strict)
将原生模型的第 idx 输入的大小调整为指定的维度。
当 `strict` 为 True 时,只能调整未知尺寸的大小。在“Tensor.shapeSignature()”返回的数组中,未知维度用“-1”表示。
参数
idx | |
---|---|
dims | |
严格 |
public void resizeInput (int idx, int[] dims)
将原生模型的第 idx 输入的大小调整为指定的维度。
参数
idx | |
---|---|
dims |
public void run (Object 输入、Object 输出)
如果模型只接受一个输入且仅提供一个输出,则运行模型推断。
警告:如果将 Buffer
(最好是直接,但不是必需的)用作输入/输出数据类型,则 API 的效率会更高。请考虑使用 Buffer
来馈送和提取原始数据,以获得更好的效果。系统支持以下具体的 Buffer
类型:
ByteBuffer
- 与任何底层基元张量类型都兼容。FloatBuffer
- 与浮点张量兼容。IntBuffer
- 与 int32 Tensor 兼容。LongBuffer
- 与 int64 Tensor 兼容。
Buffer
)或标量输入。参数
输入 | 数组、多维数组或基元类型(包括 int、float、long 和 byte)的 Buffer 。Buffer 是为基元类型传递大型输入数据的首选方式,而字符串类型需要使用(多维)数组输入路径。使用 Buffer 时,其内容应保持不变,直到模型推断完成,并且调用方必须确保 Buffer 位于适当的读取位置。仅当调用方使用允许缓冲区句柄互操作的 Delegate 且此类缓冲区已绑定到输入 Tensor 时,才允许 null 值。 |
---|---|
output | 输出数据的多维数组或基元类型(包括 int、float、long 和 byte)的 Buffer 。使用 Buffer 时,调用方必须确保设置了适当的写入位置。允许使用 null 值,但在某些情况下很有用,例如,如果调用方使用允许缓冲区句柄互操作的 Delegate ,并且此类缓冲区已绑定到输出 Tensor (另请参阅 Interpreter.Options#setAllowBufferHandleOutput(boolean)),;或者如果图表具有动态形状的输出,并且调用方必须在调用输出后查询输出 Tensor 形状,则直接从 Tensor.asReadOnlyBuffer() 提取数据。 |
public void runForMultipleInputsOutputs (Object[] input, Map<Integer, Object> 输出)
如果模型接受多个输入或返回多个输出,则运行模型推断。
警告:如果将 Buffer
(最好是直接,但不是必需的)用作输入/输出数据类型,则 API 的效率会更高。请考虑使用 Buffer
来馈送和提取原始数据,以获得更好的效果。系统支持以下具体的 Buffer
类型:
ByteBuffer
- 与任何底层基元张量类型都兼容。FloatBuffer
- 与浮点张量兼容。IntBuffer
- 与 int32 Tensor 兼容。LongBuffer
- 与 int64 Tensor 兼容。
Buffer
)或标量输入。
注意:仅当调用方使用允许缓冲区句柄互操作操作的 Delegate
,并且此类缓冲区已绑定到相应的输入或输出 Tensor
时,才允许 inputs
和 outputs
的各个元素的 null
值。
参数
输入 | 输入数据的数组。输入的顺序应与模型输入的顺序相同。每个输入可以是数组或多维数组,也可以是基元类型(包括 int、float、long 和 byte)的 Buffer 。Buffer 是传递大量输入数据的首选方式,而字符串类型需要使用(多维)数组输入路径。使用 Buffer 时,其内容应保持不变,直到模型推断完成,并且调用方必须确保 Buffer 位于适当的读取位置。 |
---|---|
输出 | 将输出索引映射到输出数据的多维数组或基元类型(包括 int、float、long 和 byte)的 Buffer 的映射。它只需保留相应条目即可使用输出。使用 Buffer 时,调用方必须确保设置了适当的写入位置。在以下情况下,映射可能为空:缓冲区句柄用于输出张量数据;输出为动态形状,且调用方必须在调用推断后查询输出 Tensor 形状,从而直接从输出张量提取数据(通过 Tensor.asReadOnlyBuffer() )。 |
public void runSignature (Map<String, Object> 输入, Map<String, Object> 输出)
与 runSignature(Map, Map, String)
相同,但不需要传递 signatureKey,假设模型有一个 SignatureDef。如果模型有多个 SignatureDef,它将抛出异常。
警告:这是一个实验性 API,可能会发生变化。
参数
输入 | |
---|---|
输出 |
public void runSignature (Map<String, Object> 输入, Map<String, Object> 输出, String 键)
根据通过 signatureKey
提供的 SignatureDef 运行模型推断。
如需详细了解允许的输入和输出数据类型,请参阅 run(Object, Object)
。
警告:这是一个实验性 API,可能会发生变化。
参数
输入 | 从 SignatureDef 中的输入名称到输入对象的映射。 |
---|---|
输出 | 从 SignatureDef 中的输出名称到输出数据的映射。如果调用方希望在推理后直接查询 Tensor 数据(例如,如果输出形状是动态的,或者使用了输出缓冲区句柄),则此字段可能为空。 |
signatureKey | 标识 SignatureDef 的签名密钥。 |
抛出
IllegalArgumentException | 如果 inputs 为 null 或为空,如果 outputs 或 signatureKey 为 null,或者运行推断时发生错误。 |
---|
public void setCancelled (boolean cancelled)
高级:在调用 run(Object, Object)
时中断推断。
调用此函数时,取消标记将设置为 true。解释器将在 Op 调用之间检查该标记;如果该值为 true
,则解释器将停止执行。在 setCancelled(false)
明确“取消取消”之前,解释器将保持取消状态。
警告:这是一个实验性 API,可能会发生变化。
参数
已取消 | true 表示尽最大努力取消推断;false 表示继续。 |
---|
抛出
IllegalStateException | 如果未使用可取消选项(默认情况下处于关闭状态)初始化解释器。 |
---|