Interpreter

公共最终类 口译

驱动程序类,通过 TensorFlow Lite 推动模型推断。

注意:如果您不需要访问任何“实验性”以下 API 功能,建议使用 InterpreterApi 和 InterpreterFactory,而不是直接使用 Interpreter。

Interpreter 封装了预训练的 TensorFlow Lite 模型,其中操作 以进行模型推断。

例如,如果模型仅接受一个输入并仅返回一个输出:

try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
   interpreter.run(input, output);
 }
 

如果模型接受多个输入或输出:

Object[] inputs = {input0, input1, ...};
 Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
 FloatBuffer ith_output = FloatBuffer.allocateDirect(3 * 2 * 4);  // Float tensor, shape 3x2x4.
 ith_output.order(ByteOrder.nativeOrder());
 map_of_indices_to_outputs.put(i, ith_output);
 try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
 }
 

如果模型采用或生成字符串张量:

String[] input = {"foo", "bar"};  // Input tensor shape is [2].
 String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
 try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
   interpreter.runForMultipleInputsOutputs(input, output);
 }
 

请注意,形状 [] 和形状 [1] 是有区别的。对于标量字符串张量 输出:

String[] input = {"foo"};  // Input tensor shape is [1].
 ByteBuffer outputBuffer = ByteBuffer.allocate(OUTPUT_BYTES_SIZE);  // Output tensor shape is [].
 try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
   interpreter.runForMultipleInputsOutputs(input, outputBuffer);
 }
 byte[] outputBytes = new byte[outputBuffer.remaining()];
 outputBuffer.get(outputBytes);
 // Below, the `charset` can be StandardCharsets.UTF_8.
 String output = new String(outputBytes, charset);
 

将 TensorFlow 模型转换为 TensorFlowLite 时确定输入和输出的顺序 以及输入的默认形状。

当输入作为(多维)数组提供时,对应的输入张量将 根据该数组的形状隐式调整大小。当输入作为 Buffer 提供时 因此不会进行隐式大小调整调用方必须确保 Buffer 字节大小 要么匹配相应张量的大小,要么先通过 resizeInput(int, int[]) 调整张量的大小。张量形状和类型信息可通过 Tensor 类获取,可通过 getInputTensor(int)getOutputTensor(int) 获取。

警告Interpreter 实例不是线程安全的。Interpreter 拥有所有必须通过调用 close() 明确释放的资源。

TFLite 库是针对 NDK API 19 构建的。它可能适用于低于 Android API 级别 19、 但不保证一定如此

嵌套类

类别 Interpreter.Options 用于控制运行时解释器行为的选项类。

公共构造函数

Interpreter(文件 modelFile)
初始化 Interpreter
InterpreterFile modelFile、Interpreter.Options 选项)
初始化 Interpreter 并指定用于自定义解释器行为的选项。
Interpreter(ByteBuffer byteBuffer)
使用模型文件的 ByteBuffer 初始化 Interpreter
InterpreterByteBuffer byteBuffer、Interpreter.Options 选项)
使用模型文件的 ByteBuffer 和一组Interpreter 自定义 Interpreter.Options

公共方法

无效
allocateTensors()
如有必要,显式更新所有张量的分配。
无效
close()
释放与 InterpreterApi 实例关联的资源。
整数
getInputIndex(String opName)
根据输入的操作名称获取该输入的索引。
Tensor
getInputTensor(int inputIndex)
获取与提供的输入索引关联的张量。
整数
getInputTensorCount()
获取输入张量的数量。
Tensor
getInputTensorFromSignature(String inputName, String特定签名 Key)
获取与提供的输入名称和签名方法名称相关联的张量。
getLastNativeInferenceDurationNanoseconds()
返回原生推理时间。
整数
getOutputIndex(String opName)
根据输出的操作名称,获取该输出的索引。
Tensor
getOutputTensor(int outputIndex)
获取与提供的输出索引关联的张量。
整数
getOutputTensorCount()
获取输出张量的数量。
Tensor
getOutputTensorFromSignature(String outputName, String signatureKey)
获取与特定签名方法中提供的输出名称相关联的张量。
String[]
getSignatureInputs(String signedKey)
获取方法 signatureKey 的 SignatureDefs 输入列表。