适用于 Android 的文本分类指南

借助 MediaPipe 文本分类器任务,您可以将文本归类到一组已定义的类别中,例如正面或负面的情感。类别决定了您使用的模型以及模型的训练方式。以下说明介绍了如何在 Android 应用中使用文本分类器。

您可以观看演示,了解此任务的实际运用。如需详细了解此任务的功能、模型和配置选项,请参阅概览

代码示例

文本分类器的示例代码提供了此任务的简单实现,供您参考。此代码可帮助您测试此任务并开始构建自己的文本分类应用。您可以在 GitHub 上浏览文本分类器示例代码

下载代码

以下说明介绍了如何使用 git 版本控制命令行工具创建示例代码的本地副本。

如需下载示例代码,请执行以下操作:

  1. 使用以下命令克隆 Git 代码库:
    git clone https://github.com/google-ai-edge/mediapipe-samples
    
  2. (可选)将您的 Git 实例配置为使用稀疏检出,以便只有文本分类器示例应用的文件:
    cd mediapipe
    git sparse-checkout init --cone
    git sparse-checkout set examples/text_classification/android
    

如需了解如何使用 Android Studio 设置和运行示例,请参阅 Android 设置指南中的示例代码设置说明。

关键组件

以下文件包含文本分类示例应用的关键代码:

初始设置

本部分介绍了专门为使用文本分类器而设置开发环境和代码项目的关键步骤。如需了解如何为使用 MediaPipe Tasks 设置开发环境的常规信息(包括平台版本要求),请参阅 Android 设置指南

依赖项

文本分类器使用 com.google.mediapipe:tasks-text 库。将此依赖项添加到 Android 应用开发项目的 build.gradle 文件中。您可以使用以下代码导入所需的依赖项:

dependencies {
    implementation 'com.google.mediapipe:tasks-text:latest.release'
}

模型

MediaPipe 文本分类器任务需要使用与此任务兼容的经过训练的模型。如需详细了解文本分类器可用的经过训练的模型,请参阅任务概览“模型”部分

选择并下载模型,然后将其存储在项目的 assets 目录中:

<dev-project-root>/src/main/assets

使用 BaseOptions.Builder.setModelAssetPath() 方法指定要使用的模型的路径。如需查看代码示例,请参阅下一部分。

创建任务

使用某个文本分类器 TextClassifier.createFrom...() 函数来准备运行推断的任务。您可以将 createFromFile() 函数与训练后的模型文件的相对或绝对路径搭配使用。以下代码示例演示了如何使用 TextClassifier.createFromOptions() 函数。如需详细了解可用的配置选项,请参阅配置选项

以下代码演示了如何构建和配置此任务。

// no directory path required if model file is in src/main/assets:
String currentModel = "text_classifier_model.tflite";

fun initClassifier() {
    val baseOptionsBuilder = BaseOptions.builder()
        .setModelAssetPath(currentModel)
    try {
        val baseOptions = baseOptionsBuilder.build()
        val optionsBuilder = TextClassifier.TextClassifierOptions.builder()
            .setBaseOptions(baseOptions)
        val options = optionsBuilder.build()
        textClassifier = TextClassifier.createFromOptions(context, options)
    } catch (e: IllegalStateException) { // exception handling
    }
}

您可以在代码示例 TextClassifierHelperinitClassifier() 函数中查看创建任务的示例。

配置选项

此任务具有以下 Android 应用的配置选项:

选项名称 说明 值范围 默认值
displayNamesLocale 设置任务模型元数据中提供的显示名(如果有)要使用的标签语言。英语的默认值为 en。您可以使用 TensorFlow Lite Metadata Writer API 向自定义模型的元数据添加本地化标签。语言区域代码 en
maxResults 设置要返回的得分最高的分类结果的数量上限(可选)。如果小于 0,将返回所有可用的结果。 任何正数 -1
scoreThreshold 设置预测分数阈值,以替换模型元数据中提供的阈值(如果有)。低于此值的结果会被拒绝。 任意浮点数 未设置
categoryAllowlist 设置允许的类别名称的可选列表。如果为非空,则类别名称不在此集合中的分类结果将被滤除。系统会忽略重复或未知的类别名称。 此选项与 categoryDenylist 互斥,如果同时使用这两者,就会引发错误。 任何字符串 未设置
categoryDenylist 设置不允许使用的类别名称的可选列表。如果非空,则类别名称在此集合中的分类结果将被滤除。系统会忽略重复或未知的类别名称。此选项与 categoryAllowlist 互斥,同时使用这两者会导致错误。 任何字符串 未设置

准备数据

文本分类器可处理文本 (String) 数据。该任务会处理数据输入预处理,包括标记化和张量预处理。

所有预处理都在 classify() 函数中处理。无需事先对输入文本进行额外的预处理。

String inputText = "The input text to be classified.";

运行任务

文本分类器使用 TextClassifier.classify() 函数运行推断。请使用单独的执行线程来执行分类,以免您的应用阻塞 Android 界面线程。

以下代码演示了如何使用单独的执行线程通过任务模型执行处理。

    fun classify(text: String) {
        executor = ScheduledThreadPoolExecutor(1)

        executor.execute {
            val results = textClassifier.classify(text)
            listener.onResult(results)
        }
    }

您可以在代码示例 TextClassifierHelperclassify() 函数中查看有关如何运行任务的示例。

处理和显示结果

文本分类器会输出 TextClassifierResult,其中包含输入文本的可能类别列表。类别由您使用的模型定义,因此,如果您想要不同的类别,请选择其他模型或重新训练现有模型。

下面显示了此任务的输出数据示例:

TextClassificationResult:
  Classification #0 (single classification head):
    ClassificationEntry #0:
      Category #0:
        category name: "positive"
        score: 0.8904
        index: 0
      Category #1:
        category name: "negative"
        score: 0.1096
        index: 1

通过对输入文本运行 BERT-classifier 来获取此结果:"an imperfect but overall entertaining mystery"

您可以查看代码示例 ResultsAdapter 类和 ViewHolder 内部类,查看如何显示结果的示例。