在 Android 上使用 Kotlin 运行 LiteRT 编译模型 API

LiteRT 编译模型 API 以 Kotlin 语言提供,可为 Android 开发者提供无缝的加速器优先体验,并支持高级 API。

如需查看 Kotlin 中的 LiteRT 应用示例,请参阅使用 Kotlin 进行图片分割演示

开始使用

按照以下步骤将 LiteRT Compiled Model API 添加到您的 Android 应用。

添加 Maven 软件包

将包含已编译模型 API 的 LiteRT 依赖项添加到您的应用:

dependencies {
  ...
  implementation `com.google.ai.edge.litert:litert:2.0.0-alpha`
}

创建已编译的模型

使用 CompiledModel API,通过模型和您选择的硬件加速来初始化运行时:

val  model =
  CompiledModel.create(
    context.assets,
    "mymodel.tflite",
    CompiledModel.Options(Accelerator.CPU),
    env,
  )

创建输入和输出缓冲区

创建必要的数据结构(缓冲区),以保存您将馈送到模型中进行推理的输入数据,以及模型在运行推理后生成的输出数据。

val inputBuffers = model.createInputBuffers()
val outputBuffers = model.createOutputBuffers()

如果您使用的是 CPU 内存,请通过将数据直接写入第一个输入缓冲区来填充输入。

inputBuffers[0].writeFloat(FloatArray(data_size) { data_value /* your data */ })

调用模型

提供输入和输出缓冲区,运行编译后的模型。

model.run(inputBuffers, outputBuffers)

检索输出

通过直接从内存中读取模型输出来检索输出。

val outputFloatArray = outputBuffers[0].readFloat()

主要概念和组件

如需了解 LiteRT Kotlin 编译模型 API 的主要概念和组件,请参阅以下部分。

基本推理(CPU)

以下是使用 LiteRT Next 进行推理的精简版实现。

// Load model and initialize runtime
val  model =
    CompiledModel.create(
        context.assets,
        "mymodel.tflite"
    )

// Preallocate input/output buffers
val inputBuffers = model.createInputBuffers()
val outputBuffers = model.createOutputBuffers()

// Fill the first input
inputBuffers[0].writeFloat(FloatArray(data_size) { data_value /* your data */ })

// Invoke
model.run(inputBuffers, outputBuffers)

// Read the output
val outputFloatArray = outputBuffers[0].readFloat()

// Clean up buffers and model
inputBuffers.forEach { it.close() }
outputBuffers.forEach { it.close() }
model.close()

已编译的模型 (CompiledModel)

编译后的模型 API (CompiledModel) 负责加载模型、应用硬件加速、实例化运行时、创建输入和输出缓冲区,以及运行推理。

以下简化版代码段演示了编译模型 API 如何获取 LiteRT 模型 (.tflite) 并创建可用于运行推理的编译模型。

val  model =
  CompiledModel.create(
    context.assets,
    "mymodel.tflite"
  )

以下简化的代码段演示了 CompiledModel API 如何获取输入和输出缓冲区,以及如何使用已编译的模型运行推理。

// Preallocate input/output buffers
val inputBuffers = model.createInputBuffers()
val outputBuffers = model.createOutputBuffers()

// Fill the first input
inputBuffers[0].writeFloat(FloatArray(data_size) { data_value /* your data */ })
// Invoke
model.run(inputBuffers, outputBuffers)
// Read the output
val outputFloatArray = outputBuffers[0].readFloat()

// Clean up buffers and model
inputBuffers.forEach { it.close() }
outputBuffers.forEach { it.close() }
model.close()

如需更全面地了解 CompiledModel API 的实现方式,请参阅 Model.kt 中的源代码。

Tensor 缓冲区 (TensorBuffer)

LiteRT 使用 Tensor Buffer API (TensorBuffer) 处理进出 CompiledModel 的数据流,从而为 I/O 缓冲区互操作性提供内置支持。Tensor Buffer API 提供写入 (Write<T>())、读取 (Read<T>()) 和锁定缓冲区的功能。

如需更全面地了解 Tensor Buffer API 的实现方式,请参阅 TensorBuffer.kt 中的源代码。