Interpreter

public final class Translationer

使用 TensorFlow Lite 驱动模型推断的驱动程序类。

注意:如果您不需要使用下面的任何“实验性”API 功能,可优先使用 ExplainerApi 和 ExplainerFactory,而不是直接使用 Explainer。

Interpreter 封装了预训练的 TensorFlow Lite 模型,在该模型中执行操作以进行模型推断。

例如,如果模型仅接受一个输入并且仅返回一个输出:

try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
   interpreter.run(input, output);
 }
 

如果模型接受多个输入或输出:

Object[] inputs = {input0, input1, ...};
 Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
 FloatBuffer ith_output = FloatBuffer.allocateDirect(3 * 2 * 4);  // Float tensor, shape 3x2x4.
 ith_output.order(ByteOrder.nativeOrder());
 map_of_indices_to_outputs.put(i, ith_output);
 try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
 }
 

如果模型获取或生成字符串张量:

String[] input = {"foo", "bar"};  // Input tensor shape is [2].
 String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
 try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
   interpreter.runForMultipleInputsOutputs(input, output);
 }
 

请注意,形状 [] 和形状 [1] 是有区别的。对于标量字符串张量输出:

String[] input = {"foo"};  // Input tensor shape is [1].
 ByteBuffer outputBuffer = ByteBuffer.allocate(OUTPUT_BYTES_SIZE);  // Output tensor shape is [].
 try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
   interpreter.runForMultipleInputsOutputs(input, outputBuffer);
 }
 byte[] outputBytes = new byte[outputBuffer.remaining()];
 outputBuffer.get(outputBytes);
 // Below, the `charset` can be StandardCharsets.UTF_8.
 String output = new String(outputBytes, charset);
 

使用 Toco 将 TensorFlow 模型转换为 TensorFlowLite 模型时,系统会确定输入和输出的顺序,以及输入的默认形状。

以(多维)数组形式提供输入时,相应的输入张量将根据该数组的形状隐式调整大小。以 Buffer 类型提供输入时,不会隐式调整大小;调用方必须确保 Buffer 字节大小与相应张量的大小一致,或者首先通过 resizeInput(int, int[]) 调整张量的大小。您可以通过 Tensor 类(通过 getInputTensor(int)getOutputTensor(int) 获得)获取张量形状和类型信息。

警告Interpreter 实例不是线程安全的。Interpreter 拥有的资源必须通过调用 close() 来明确释放。

TFLite 库是基于 NDK API 19 构建的。它可能适用于低于 19 的 Android API 级别,但不能保证。

嵌套类

类别 Interpreter.Options 用于控制运行时解释器行为的选项类。

公共构造函数

ExplainerFile modelFile)
初始化 Interpreter
ExplainerFile modelFile、Interpreter.Options 选项)
初始化 Interpreter 并指定用于自定义解释器行为的选项。
Explainer(ByteBuffer byteBuffer)
使用模型文件的 ByteBuffer 初始化 Interpreter
ExplainerByteBuffer byteBuffer、Interpreter.Options 选项)
使用模型文件的 ByteBuffer 和一组自定义 Interpreter.Options 初始化 Interpreter

公共方法

void
allocateTensors()
根据需要显式更新所有张量的分配。
void
close()
释放与 InterpreterApi 实例关联的资源。
整型
getInputIndex(String opName)
根据输入的操作名称获取输入的索引。
张量
getInputTensor(int inputIndex)
获取与提供的输入索引相关联的张量。
整型
getInputTensorCount()
获取输入张量的数量。
张量
getInputTensorFromSignature(String inputName, String signatureKey)
获取与提供的输入名称和签名方法名称相关联的 Tensor。
长整型
getLastNativeInferenceDurationNanoseconds()
返回原生推断时间。
整型
getOutputIndex(String opName)
根据输出的操作名称获取输出的索引。
张量
getOutputTensor(int outputIndex)
获取与提供的输出索引相关联的张量。
整型
getOutputTensorCount()
获取输出张量的数量。
张量
getOutputTensorFromSignature(String outputName, String signatureKey)
获取与特定签名方法中提供的输出名称相关联的 Tensor。
字符串 []
getSignatureInputs(String signatureKey)
获取方法 signatureKey 的 SignatureDefs 输入列表。
字符串 []
getSignatureKeys()
获取模型中可用的 SignatureDef 导出方法名称的列表。
字符串 []
getSignatureOutputs(String signatureKey)
获取 signatureKey 方法的 SignatureDefs 输出列表。
void
resetVariableTensors()
高级:将所有变量张量重置为默认值。
void
resizeInput(int idx, int[] dims, boolean strict)
将原生模型的第 idx 输入的大小调整为指定的维度。
void
resizeInput(int idx, int[] dims)
将原生模型的第 idx 输入的大小调整为指定的维度。
void
runObject 输入、Object 输出)
如果模型只接受一个输入且仅提供一个输出,则运行模型推断。
void
runForMultipleInputsOutputs(Object[] 输入、Map<Integer, Object> 输出)
如果模型接受多个输入或返回多个输出,则运行模型推断。
void
runSignature(Map<StringObject> 输入, Map<StringObject> 输出)
runSignature(Map, Map, String) 相同,但不需要传递 signatureKey,假设模型有一个 SignatureDef。
void
runSignature(Map<StringObject> 输入, Map<StringObject> 输出, String signatureKey)
根据通过 signatureKey 提供的 SignatureDef 运行模型推断。
void
setCancelled(布尔值已取消)
高级:在调用 run(Object, Object) 时中断推断。

继承的方法

公共构造函数

public Explainer (File modelFile)

初始化 Interpreter

参数
modelFile 预训练 TF Lite 模型的文件。
抛出
IllegalArgumentException 如果 modelFile 未对有效的 TensorFlow Lite 模型进行编码,则会发生该错误。

public Explainer File modelFile、Interpreter.Options options)

初始化 Interpreter 并指定用于自定义解释器行为的选项。

参数
modelFile 这是一个预训练 TF Lite 模型的文件,
选项 一组用于自定义解释器行为的选项
抛出
IllegalArgumentException 如果 modelFile 未对有效的 TensorFlow Lite 模型进行编码,则会发生该错误。

public Explainer (ByteBuffer byteBuffer)

使用模型文件的 ByteBuffer 初始化 Interpreter

构建 Interpreter 之后,不应修改 ByteBuffer。ByteBuffer 可以是对模型文件进行内存映射的 MappedByteBuffer,也可以是包含模型字节内容的 nativeOrder() 的直接 ByteBuffer

参数
byteBuffer
抛出
IllegalArgumentException 如果 byteBuffer 不是 MappedByteBuffer,也不是 nativeOrder 的直接 ByteBuffer

public Explainer ByteBuffer byteBuffer、Interpreter.Options 选项)

使用模型文件的 ByteBuffer 和一组自定义 Interpreter.Options 初始化 Interpreter

构建 Interpreter 之后,不应修改 ByteBufferByteBuffer 可以是对模型文件进行内存映射的 MappedByteBuffer,也可以是包含模型字节内容的 nativeOrder() 的直接 ByteBuffer

参数
byteBuffer
选项
抛出
IllegalArgumentException 如果 byteBuffer 不是 MappedByteBuffer,也不是 nativeOrder 的直接 ByteBuffer

公共方法

public void allocateTensors ()

根据需要显式更新所有张量的分配。

这将使用给定的输入张量形状传播依赖张量的形状和内存分配。

注意:此调用 *完全可选*。如果调整了任何输入张量的大小,则在执行期间将自动分配张量。在执行图之前,此调用最有助于确定任何输出张量的形状,例如,

 interpreter.resizeInput(0, new int[]{1, 4, 4, 3}));
 interpreter.allocateTensors();
 FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements());
 // Populate inputs...
 FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
 interpreter.run(input, output)
 // Process outputs...

注意:某些图具有动态形状的输出,在这种情况下,在执行推断之前,输出形状可能不会完全传播。

public void close ()

释放与 InterpreterApi 实例关联的资源。

public int getInputIndex (String opName)

根据输入的操作名称获取输入的索引。

参数
opName

public Tensor getInputTensor (int inputIndex)

获取与提供的输入索引相关联的张量。

参数
inputIndex

public int getInputTensorCount ()

获取输入张量的数量。

public Tensor getInputTensorFromSignature (String inputName, String signatureKey)

获取与提供的输入名称和签名方法名称相关联的 Tensor。

警告:这是一个实验性 API,可能会发生变化。

参数
inputName 在签名中输入名称。
signatureKey 标识 SignatureDef 的签名密钥,如果模型有一个签名,可以为 null。
抛出
IllegalArgumentException 如果 inputNamesignatureKey 为 null 或为空,或提供的名称无效。

public Long getLastNativeInferenceDurationNanoseconds ()

返回原生推断时间。

public int getOutputIndex (String opName)

根据输出的操作名称获取输出的索引。

参数
opName

public Tensor getOutputTensor (int outputIndex)

获取与所提供的输出索引相关联的张量。

注意:在执行推理之前,输出张量详细信息(例如形状)可能不会被完全填充。如果您在运行推理 *之前* 需要更新的详细信息(例如,在调整输入张量的大小之后,这可能会使输出张量形状失效),请使用 allocateTensors() 显式触发分配和形状传播。请注意,对于输出形状依赖于输入 *值*的图表,只有在运行推理之前,可能无法完全确定输出形状。

参数
outputIndex

public int getOutputTensorCount ()

获取输出张量的数量。

public Tensor getOutputTensorFromSignature (String outputName, String signatureKey)

获取与特定签名方法中提供的输出名称相关联的张量。

注意:在执行推理之前,输出张量详细信息(例如形状)可能不会被完全填充。如果您在运行推理 *之前* 需要更新的详细信息(例如,在调整输入张量的大小之后,这可能会使输出张量形状失效),请使用 allocateTensors() 显式触发分配和形状传播。请注意,对于输出形状依赖于输入 *值*的图表,只有在运行推理之前,可能无法完全确定输出形状。

警告:这是一个实验性 API,可能会发生变化。

参数
outputName 签名中的输出名称。
signatureKey 标识 SignatureDef 的签名密钥,如果模型有一个签名,可以为 null。
抛出
IllegalArgumentException 如果 outputNamesignatureKey 为 null 或为空,或者提供的名称无效。

public String[] getSignatureInputs (String signatureKey)

获取方法 signatureKey 的 SignatureDefs 输入列表。

警告:这是一个实验性 API,可能会发生变化。

参数
signatureKey

public String[] getSignatureKeys ()

获取模型中可用的 SignatureDef 导出方法名称的列表。

警告:这是一个实验性 API,可能会发生变化。

public String[] getSignatureOutputs (String signatureKey)

获取方法 signatureKey 的 SignatureDefs 输出列表。

警告:这是一个实验性 API,可能会发生变化。

参数
signatureKey

public void resetVariableTensors ()

高级:将所有变量张量重置为默认值。

如果可变张量没有关联的缓冲区,则会重置为零。

警告:这是一个实验性 API,可能会发生变化。

public void resizeInput (int idx, int[] dims, boolean strict)

将原生模型的第 idx 输入的大小调整为指定的维度。

当 `strict` 为 True 时,只能调整未知尺寸的大小。在“Tensor.shapeSignature()”返回的数组中,未知维度用“-1”表示。

参数
idx
dims
严格

public void resizeInput (int idx, int[] dims)

将原生模型的第 idx 输入的大小调整为指定的维度。

参数
idx
dims

public void run Object 输入、Object 输出)

如果模型只接受一个输入且仅提供一个输出,则运行模型推断。

警告:如果将 Buffer(最好是直接,但不是必需的)用作输入/输出数据类型,则 API 的效率会更高。请考虑使用 Buffer 来馈送和提取原始数据,以获得更好的效果。系统支持以下具体的 Buffer 类型:

  • ByteBuffer - 与任何底层基元张量类型都兼容。
  • FloatBuffer - 与浮点张量兼容。
  • IntBuffer - 与 int32 Tensor 兼容。
  • LongBuffer - 与 int64 Tensor 兼容。
请注意,布尔值类型仅支持作为数组(而不是 Buffer)或标量输入。

参数
输入 数组、多维数组或基元类型(包括 int、float、long 和 byte)的 BufferBuffer 是为基元类型传递大型输入数据的首选方式,而字符串类型需要使用(多维)数组输入路径。使用 Buffer 时,其内容应保持不变,直到模型推断完成,并且调用方必须确保 Buffer 位于适当的读取位置。仅当调用方使用允许缓冲区句柄互操作的 Delegate 且此类缓冲区已绑定到输入 Tensor 时,才允许 null 值。
output 输出数据的多维数组或基元类型(包括 int、float、long 和 byte)的 Buffer。使用 Buffer 时,调用方必须确保设置了适当的写入位置。允许使用 null 值,但在某些情况下很有用,例如,如果调用方使用允许缓冲区句柄互操作的 Delegate,并且此类缓冲区已绑定到输出 Tensor(另请参阅 Interpreter.Options#setAllowBufferHandleOutput(boolean)),;或者如果图表具有动态形状的输出,并且调用方必须在调用输出后查询输出 Tensor 形状,则直接从 Tensor.asReadOnlyBuffer() 提取数据。

public void runForMultipleInputsOutputs (Object[] input, Map<IntegerObject> 输出)

如果模型接受多个输入或返回多个输出,则运行模型推断。

警告:如果将 Buffer(最好是直接,但不是必需的)用作输入/输出数据类型,则 API 的效率会更高。请考虑使用 Buffer 来馈送和提取原始数据,以获得更好的效果。系统支持以下具体的 Buffer 类型:

  • ByteBuffer - 与任何底层基元张量类型都兼容。
  • FloatBuffer - 与浮点张量兼容。
  • IntBuffer - 与 int32 Tensor 兼容。
  • LongBuffer - 与 int64 Tensor 兼容。
请注意,布尔值类型仅支持作为数组(而不是 Buffer)或标量输入。

注意:仅当调用方使用允许缓冲区句柄互操作操作的 Delegate,并且此类缓冲区已绑定到相应的输入或输出 Tensor 时,才允许 inputsoutputs 的各个元素的 null 值。

参数
输入 输入数据的数组。输入的顺序应与模型输入的顺序相同。每个输入可以是数组或多维数组,也可以是基元类型(包括 int、float、long 和 byte)的 BufferBuffer 是传递大量输入数据的首选方式,而字符串类型需要使用(多维)数组输入路径。使用 Buffer 时,其内容应保持不变,直到模型推断完成,并且调用方必须确保 Buffer 位于适当的读取位置。
输出 将输出索引映射到输出数据的多维数组或基元类型(包括 int、float、long 和 byte)的 Buffer 的映射。它只需保留相应条目即可使用输出。使用 Buffer 时,调用方必须确保设置了适当的写入位置。在以下情况下,映射可能为空:缓冲区句柄用于输出张量数据;输出为动态形状,且调用方必须在调用推断后查询输出 Tensor 形状,从而直接从输出张量提取数据(通过 Tensor.asReadOnlyBuffer())。

public void runSignature (Map<StringObject> 输入, Map<StringObject> 输出)

runSignature(Map, Map, String) 相同,但不需要传递 signatureKey,假设模型有一个 SignatureDef。如果模型有多个 SignatureDef,它将抛出异常。

警告:这是一个实验性 API,可能会发生变化。

参数
输入
输出

public void runSignature (Map<StringObject> 输入, Map<StringObject> 输出, String 键)

根据通过 signatureKey 提供的 SignatureDef 运行模型推断。

如需详细了解允许的输入和输出数据类型,请参阅 run(Object, Object)

警告:这是一个实验性 API,可能会发生变化。

参数
输入 从 SignatureDef 中的输入名称到输入对象的映射。
输出 从 SignatureDef 中的输出名称到输出数据的映射。如果调用方希望在推理后直接查询 Tensor 数据(例如,如果输出形状是动态的,或者使用了输出缓冲区句柄),则此字段可能为空。
signatureKey 标识 SignatureDef 的签名密钥。
抛出
IllegalArgumentException 如果 inputs 为 null 或为空,如果 outputssignatureKey 为 null,或者运行推断时发生错误。

public void setCancelled (boolean cancelled)

高级:在调用 run(Object, Object) 时中断推断。

调用此函数时,取消标记将设置为 true。解释器将在 Op 调用之间检查该标记;如果该值为 true,则解释器将停止执行。在 setCancelled(false) 明确“取消取消”之前,解释器将保持取消状态。

警告:这是一个实验性 API,可能会发生变化。

参数
已取消 true 表示尽最大努力取消推断;false 表示继续。
抛出
IllegalStateException 如果未使用可取消选项(默认情况下处于关闭状态)初始化解释器。