向 TensorFlow Lite 模型添加元数据

TensorFlow Lite 元数据提供了模型描述标准。元数据是了解模型的功能及其输入 / 输出信息的重要来源。元数据包含

Kaggle 模型上发布的所有图片模型都已填充元数据。

采用元数据格式的模型

model_with_metadata
图 1. 包含元数据和关联文件的 TFLite 模型。

模型元数据在 metadata_schema.fbs(一个 FlatBuffer 文件)中定义。如图 1 所示,它存储在 TFLite 模型架构metadata 字段中,名称为 "TFLITE_METADATA"。某些模型可能附带关联文件,例如分类标签文件。这些文件通过 ZipFile “附加”模式'a' 模式)以 ZIP 的形式串联到原始模型文件的末尾。TFLite 解释器可以像以前一样使用新文件格式。如需了解详情,请参阅打包关联文件

请参阅以下说明,了解如何填充、直观呈现和读取元数据。

设置元数据工具

在将元数据添加到模型之前,您需要具备运行 TensorFlow 的 Python 编程环境设置。有关如何设置此功能的详细指南,请参阅此处

设置 Python 编程环境后,您需要安装其他工具:

pip install tflite-support

TensorFlow Lite 元数据工具支持 Python 3。

使用 Flatbuffers Python API 添加元数据

架构中的模型元数据由三部分组成:

  1. 模型信息 - 模型的整体说明以及许可条款等内容。请参阅 ModelMetadata
    1. 输入信息 - 输入和所需预处理(例如归一化)的说明。请参阅 SubGraphMetadata.input_tensor_metadata
      1. 输出信息 - 所需的输出和后处理(例如映射到标签)的说明。请参阅 SubGraphMetadata.output_tensor_metadata

由于 TensorFlow Lite 目前仅支持单个子图,因此 TensorFlow Lite 代码生成器Android Studio ML 绑定功能将在显示元数据和生成代码时使用 ModelMetadata.nameModelMetadata.description,而不是 SubGraphMetadata.nameSubGraphMetadata.description

支持的输入 / 输出类型

用于输入和输出的 TensorFlow Lite 元数据在设计时未考虑特定的模型类型,而是在设计时考虑了输入和输出类型。无论模型在功能上执行什么操作,只要输入和输出类型包含以下各项或这些项的组合,就受 TensorFlow Lite 元数据支持:

  • 特征 - 无符号整数或 float32 的数字。
  • 图片 - 元数据目前支持 RGB 和灰度图片。
  • 边界框 - 矩形边界框。该架构支持各种编号方案

打包关联文件

TensorFlow Lite 模型可能附带不同的关联文件。例如,自然语言模型通常具有将词块映射到字词 ID 的词汇表文件;分类模型可能具有指示对象类别的标签文件。如果没有关联文件(如果有),模型将无法正常工作。

关联文件现在可以通过元数据 Python 库与模型捆绑在一起。新的 TensorFlow Lite 模型会成为一个同时包含模型和关联文件的 ZIP 文件。可以使用常见的 ZIP 工具解压缩。这种新的模型格式会保持相同的文件扩展名 .tflite。它与现有 TFLite 框架和解释器兼容。如需了解详情,请参阅将元数据和关联文件打包到模型中

关联的文件信息可以记录在元数据中。根据文件类型和附加文件的位置(即 ModelMetadataSubGraphMetadataTensorMetadata),TensorFlow Lite Android 代码生成器可能会自动对对象应用相应的预处理/后处理。如需了解详情,请参阅架构中每个关联文件类型的 <Codegen usage> 部分

归一化和量化参数

归一化是机器学习中常用的数据预处理技术。归一化的目的是将值更改为通用比例,而不改变值范围的差异。

模型量化是一种技术,可降低权重的精度表示,还可选择性地为存储和计算启用激活。

在预处理和后处理方面,归一化和量化是两个独立的步骤。详情如下。

规范化 量化

分别是 MobileNet 中浮点模型和量化模型的输入图片的参数值示例。
浮点模型
- 平均值:127.5
- std:127.5
量化模型
- 平均值:127.5
- 标准值:127.5
浮点模型
-zeroPoint: 0
- scale:1.0
量化模型
-zeroPoint:128.0
- scale:0.0078125f




何时调用?


输入:如果输入数据在训练中进行了归一化,则推理的输入数据也需要相应地归一化。
输出:输出数据通常不会进行标准化。
浮点模型不需要量化。
量化模型不一定需要在预处理/后处理过程中进行量化。这取决于输入/输出张量的数据类型。
- 浮点张量:在处理前/后无需量化。定量运算和去量化运算内置于模型图中。
- int8/uint8 张量:在前/后处理过程中需要量化。


公式


normalized_input = (输入 - 平均值) / std
输入量化
q = f / scale + zeroPoint
输出去量化
f = (q -zeroPoint) * scale

参数在哪里
由模型创建者填充并存储在模型元数据中,如 NormalizationOptions 由 TFLite 转换器自动填充并存储在 tflite 模型文件中。
如何获取参数? 通过 MetadataExtractor API [2] 通过 TFLite Tensor API [1] 或 MetadataExtractor API [2]
浮点模型和量化模型的值是否相同? 是的,浮点模型和量化模型具有相同的归一化参数 不需要,浮点模型不需要量化。
TFLite 代码生成器或 Android Studio 机器学习绑定是否会在数据处理中自动生成它?

[1] TensorFlow Lite Java APITensorFlow Lite C++ API
[2] 元数据提取器库

在处理 uint8 模型的图片数据时,有时会跳过归一化和量化。当像素值在 [0, 255] 范围内时,就允许这样做。但一般来说,您应该始终根据归一化和量化参数(如果适用)处理数据。

示例

您可以点击以下链接,找到有关如何为不同类型的模型填充元数据的示例:

图像分类

此处下载脚本,以便将元数据填充到 mobilenet_v1_0.75_160_quantized.tflite。按如下所示运行脚本:

python ./metadata_writer_for_image_classifier.py \
    --model_file=./model_without_metadata/mobilenet_v1_0.75_160_quantized.tflite \
    --label_file=./model_without_metadata/labels.txt \
    --export_directory=model_with_metadata

如需填充其他图片分类模型的元数据,请将此处所示的模型规范添加到脚本中。本指南的其余部分将重点介绍图片分类示例中的一些关键部分,以说明关键元素。

深入了解图片分类示例

模型信息

元数据首先创建新模型信息:

from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

""" ... """
"""Creates the metadata for an image classifier."""

# Creates model info.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "MobileNetV1 image classifier"
model_meta.description = ("Identify the most prominent object in the "
                          "image from a set of 1,001 categories such as "
                          "trees, animals, food, vehicles, person etc.")
model_meta.version = "v1"
model_meta.author = "TensorFlow"
model_meta.license = ("Apache License. Version 2.0 "
                      "http://www.apache.org/licenses/LICENSE-2.0.")

输入 / 输出信息

本部分介绍如何描述模型的输入和输出签名。自动代码生成器可以使用这些元数据来创建处理前和处理后的代码。如需创建有关张量的输入或输出信息,请执行以下操作:

# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()

# Creates output info.
output_meta = _metadata_fb.TensorMetadataT()

图片输入

图片是机器学习的常见输入类型。TensorFlow Lite 元数据支持色彩空间等信息和预处理信息(如归一化)。图像的尺寸不需要手动指定,因为图像的维度已经由输入张量的形状提供,并且可以自动推断出来。

input_meta.name = "image"
input_meta.description = (
    "Input image to be classified. The expected image is {0} x {1}, with "
    "three channels (red, blue, and green) per pixel. Each value in the "
    "tensor is a single byte between 0 and 255.".format(160, 160))
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [255]
input_stats.min = [0]
input_meta.stats = input_stats

标签输出

可以使用 TENSOR_AXIS_LABELS 通过关联文件将标签映射到输出张量。

# Creates output info.
output_meta = _metadata_fb.TensorMetadataT()
output_meta.name = "probability"
output_meta.description = "Probabilities of the 1001 labels respectively."
output_meta.content = _metadata_fb.ContentT()
output_meta.content.content_properties = _metadata_fb.FeaturePropertiesT()
output_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_stats = _metadata_fb.StatsT()
output_stats.max = [1.0]
output_stats.min = [0.0]
output_meta.stats = output_stats
label_file = _metadata_fb.AssociatedFileT()
label_file.name = os.path.basename("your_path_to_label_file")
label_file.description = "Labels for objects that the model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
output_meta.associatedFiles = [label_file]

创建元数据 Flatbuffers

以下代码将模型信息与输入和输出信息相结合:

# Creates subgraph info.
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [output_meta]
model_meta.subgraphMetadata = [subgraph]

b = flatbuffers.Builder(0)
b.Finish(
    model_meta.Pack(b),
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()

将元数据和关联文件打包到模型中

创建元数据 Flatbuffer 之后,系统会通过 populate 方法将元数据和标签文件写入 TFLite 文件:

populator = _metadata.MetadataPopulator.with_model_file(model_file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(["your_path_to_label_file"])
populator.populate()

您可以通过 load_associated_files 将任意数量的关联文件打包到模型中。不过,您必须至少打包元数据中记录的文件。在此示例中,必须打包标签文件。

直观呈现元数据

您可以使用 Netron 直观呈现元数据,也可以使用 MetadataDisplayer 将 TensorFlow Lite 模型中的元数据读取为 json 格式:

displayer = _metadata.MetadataDisplayer.with_model_file(export_model_path)
export_json_file = os.path.join(FLAGS.export_directory,
                                os.path.splitext(model_basename)[0] + ".json")
json_file = displayer.get_metadata_json()
# Optional: write out the metadata as a json file
with open(export_json_file, "w") as f:
  f.write(json_file)

Android Studio 还支持通过 Android Studio 机器学习绑定功能显示元数据。

元数据版本控制

元数据架构的版本号由语义版本号(用于跟踪架构文件更改)和 Flatbuffers 文件标识(指示真正的版本兼容性)进行版本控制。

语义版本号

元数据架构使用语义版本号进行版本控制,例如 MAJOR.MINOR.PATCH。它根据此处的规则跟踪架构更改。请参阅在 1.0.0 之后添加的字段历史记录

Flatbuffers 文件标识

如果遵循这些规则,语义版本控制可保证兼容性,但并不暗示真正的不兼容性。提高 MAJOR 数并不一定意味着向后兼容性被破坏。因此,我们使用 Flatbuffer 文件标识 (file_identifier) 来表示元数据架构的真正兼容性。文件标识符的长度恰好为 4 个字符。它采用特定的元数据架构,不会受用户更改。如果元数据架构的向后兼容性因某种原因而必须破坏,则 file_identifier 会递增,例如从“M001”更改为“M002”。File_identifier 的更改频率预计会比 metadata_version 低很多。

所需的最低元数据解析器版本

所需的最低元数据解析器版本是可以完整读取元数据 Flatbuffer 的元数据解析器(Flatbuffer 生成的代码)的最低版本。该版本实际上是所有填充字段的版本中的最大版本号,也是文件标识符指示的最小兼容版本。当元数据填充到 TFLite 模型中时,MetadataPopulator 会自动填充所需的最低元数据解析器版本。如需详细了解如何使用所需的最低元数据解析器版本,请参阅元数据提取器

从模型中读取元数据

Metadata Extractor 库是跨多个平台从模型中读取元数据和关联文件的便捷工具(请参阅 Java 版本C++ 版本)。您可以使用 Flatbuffers 库以其他语言构建自己的元数据提取器工具。

读取 Java 中的元数据

如需在 Android 应用中使用 Metadata Extractor 库,我们建议您使用 MavenCentral 托管的 TensorFlow Lite 元数据 AAR。它包含 MetadataExtractor 类,以及元数据架构模型架构的 FlatBuffers Java 绑定。

您可以在 build.gradle 依赖项中进行指定,如下所示:

dependencies {
    implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'
}

如需使用夜间快照,请确保已添加 Sonatype 快照代码库

您可以使用指向模型的 ByteBuffer 初始化 MetadataExtractor 对象:

public MetadataExtractor(ByteBuffer buffer);

ByteBuffer 必须在 MetadataExtractor 对象的整个生命周期内保持不变。如果模型元数据的 Flatbuffers 文件标识符与元数据解析器的标识符不匹配,则初始化可能会失败。如需了解详情,请参阅元数据版本控制

借助匹配的文件标识符,由于 Flatbuffer 的向前和向后兼容性机制,元数据提取器将成功读取从过去和未来的所有架构生成的元数据。但是,较旧的元数据提取器无法提取未来架构中的字段。元数据的必要最低解析器版本表示可以读取完整元数据 Flatbuffer 的元数据解析器的最低版本。您可以使用以下方法验证是否满足所需的最低解析器版本条件:

public final boolean isMinimumParserVersionSatisfied();

允许传入没有元数据的模型。但是,调用从元数据读取的方法会导致运行时错误。您可以通过调用 hasMetadata 方法来检查模型是否包含元数据:

public boolean hasMetadata();

MetadataExtractor 提供了一些便捷函数,用于获取输入/输出张量的元数据。例如,

public int getInputTensorCount();
public TensorMetadata getInputTensorMetadata(int inputIndex);
public QuantizationParams getInputTensorQuantizationParams(int inputIndex);
public int[] getInputTensorShape(int inputIndex);
public int getoutputTensorCount();
public TensorMetadata getoutputTensorMetadata(int inputIndex);
public QuantizationParams getoutputTensorQuantizationParams(int inputIndex);
public int[] getoutputTensorShape(int inputIndex);

虽然 TensorFlow Lite 模型架构支持多个子图,但 TFLite 解释器目前仅支持单个子图。因此,MetadataExtractor 在其方法中省略了子图索引作为输入参数。

从模型中读取关联文件

包含元数据和关联文件的 TensorFlow Lite 模型本质上是一个 ZIP 文件,可以使用常见的 ZIP 工具解压缩以获取关联文件。例如,您可以解压缩 mobilenet_v1_0.75_160_quantized 并提取模型中的标签文件,如下所示:

$ unzip mobilenet_v1_0.75_160_quantized_1_metadata_1.tflite
Archive:  mobilenet_v1_0.75_160_quantized_1_metadata_1.tflite
 extracting: labels.txt

您还可以通过 Metadata Extractor 库读取关联文件。

在 Java 中,将文件名传入 MetadataExtractor.getAssociatedFile 方法:

public InputStream getAssociatedFile(String fileName);

同样,在 C++ 中,也可使用 ModelMetadataExtractor::GetAssociatedFile 方法完成此操作:

tflite::support::StatusOr<absl::string_view> GetAssociatedFile(
      const std::string& filename) const;