將中繼資料新增至 TensorFlow Lite 模型

TensorFlow Lite 中繼資料提供了模型說明的標準,中繼資料是瞭解模型用途及其輸入 / 輸出資訊的重要知識來源。中繼資料包含

Kaggle Models 上發布的所有圖片模型,均已填入中繼資料。

中繼資料格式的模型

model_with_metadata
圖 1:具有中繼資料和相關檔案的 TFLite 模型。

模型中繼資料定義於 metadata_schema.fbs,這是 FlatBuffer 檔案。如圖 1 所示,這項資訊會儲存在 TFLite 模型結構定義metadata 欄位中,名稱下方為 "TFLITE_METADATA"。部分模型可能含有相關聯的檔案,例如分類標籤檔案。這些檔案會透過 ZipFile 的「附加」模式 ('a' 模式) 串連至原始模型檔案的結尾為 ZIP 檔案。TFLite 解譯器可以像往常一樣使用新的檔案格式。詳情請參閱「封裝相關檔案」。

請參閱下方操作說明,瞭解如何填入、視覺化及讀取中繼資料。

設定中繼資料工具

將中繼資料新增至模型之前,您必須先設定 Python 程式設計環境,才能執行 TensorFlow。如要進一步瞭解如何進行設定,請參閱這裡的詳細指南。

設定 Python 程式設計環境後,您需要安裝其他工具:

pip install tflite-support

TensorFlow Lite 中繼資料工具支援 Python 3。

使用 Flatbuffers Python API 新增中繼資料

結構定義中的模型中繼資料包含三個部分:

  1. 模型資訊:模型與授權條款等項目的整體說明。請參閱 ModelMetadata
    1. 輸入資訊:必要輸入內容和預先處理的說明,例如正規化。請參閱 SubGraphMetadata.input_tensor_metadata
      1. 輸出資訊:需要的輸出內容和後續處理說明,例如對應至標籤。請參閱 SubGraphMetadata.output_tensor_metadata

由於 TensorFlow Lite 目前只支援單一子圖表,因此 TensorFlow Lite 程式碼產生器Android Studio 機器學習繫結功能會在顯示中繼資料和產生程式碼時使用 ModelMetadata.nameModelMetadata.description,而非 SubGraphMetadata.nameSubGraphMetadata.description

支援的輸入 / 輸出類型

TensorFlow Lite 輸入和輸出內容的中繼資料並非針對特定模型類型設計,而是輸入和輸出類型。無論模型功能的運作方式為何,只要輸入和輸出類型包含下列或結合的組合,TensorFlow Lite 中繼資料就會支援這個模型:

  • 特徵 - 不是帶正負號整數或 float32 的數字。
  • 圖片 - 中繼資料目前支援 RGB 和灰階圖片。
  • 定界框 - 矩形定界框。結構定義支援多種編號配置

封裝相關聯的檔案

TensorFlow Lite 模型可能包含不同的相關檔案。舉例來說,自然語言模型通常會含有將字詞對應至字詞 ID 的詞彙檔案;分類模型可能會提供指出物件類別的標籤檔案。如果沒有關聯檔案 (如果有的話),模型將無法正常運作。

相關檔案現在可以透過中繼資料 Python 程式庫與模型一起封裝。新的 TensorFlow Lite 模型會成為 ZIP 檔案,當中包含模型和相關檔案。常見的壓縮工具可解壓縮。這個新模型格式會繼續使用相同的副檔名 .tflite。與現有的 TFLite 架構和解譯器相容。詳情請參閱「將中繼資料和關聯檔案封裝至模型」。

相關的檔案資訊可以記錄在中繼資料中。視檔案類型和附加位置 (例如 ModelMetadataSubGraphMetadataTensorMetadata) 而定,TensorFlow Lite Android 程式碼產生器可能會自動為物件套用對應的預先/後處理。詳情請參閱結構定義中每個關聯檔案類型的 <Codegen 用法> 一節

正規化和量化參數

正規化是機器學習的常見資料預先處理技術。正規化的目的是將值變更為共同的量表,而不會扭曲值範圍的差異。

模型量化這項技術可減少權重的精確度,並選擇是否要啟用儲存和運算功能。

在預先處理和後續處理方面,正規化和量化是兩個獨立的步驟。以下是問題的詳細說明:

正規化 量化

在 MobileNet 中,浮點模型和量式模型的輸入圖片參數值範例。
浮點模型
- 平均值:127.5
- std: 127.5
數量模型
- 平均值:127.5
- std: 127.5
浮點模型
- ZeroPoint: 0
- 比例:1.0
數量模型
- ZeroPoint: 128.0
- scale:0.0078125f




何時應叫用?


輸入內容:如果在訓練過程中將輸入資料正規化,就需要據此將推論的輸入資料正規化。
輸出:輸出資料一般不會標準化。
浮點模型不需要量化,
量化模型不一定需要在預先/後處理時量化。取決於輸入/輸出張量的資料類型。
- 浮點張量:不需要預先/後處理進行量化。數量運算和去量運算會烘焙到模型圖形中。
- int8/uint8 張量:需要預先/後處理量化。


公式


正規化輸入 = (輸入 - 平均值) / std
為輸入量化
q = f / scale + 0roPoint
輸出以量化的結果
f = (q - ZPoint) * scale

參數在哪裡?
由模型建立者填入並儲存在模型中繼資料中,做為 NormalizationOptions 由 TFLite 轉換工具自動填入,並儲存在 tflite 模型檔案中。
如何取得參數? 透過 MetadataExtractor API [2] 透過 TFLite Tensor API [1] 或透過 MetadataExtractor API。[2]
浮點數和量化模型是否共用相同的值? 是,浮點和量化模型的正規化參數 否,浮點模型不需要量化。
TFLite 程式碼產生器或 Android Studio 機器學習繫結會在資料處理中自動產生嗎?

[1] TensorFlow Lite Java APITensorFlow Lite C++ API
[2] 中繼資料擷取器資料庫

處理 uint8 模型的圖片資料時,系統有時會略過正規化和量化。當像素值在 [0, 255] 範圍內時即可避免。但一般而言,您應一律根據適用的正規化和量化參數來處理資料。

示例

您可以參考下列示例,瞭解如何為不同類型的模型填入中繼資料:

圖像分類

請到這裡下載指令碼,該指令碼會將中繼資料填入 mobilenet_v1_0.75_160_quantized.tflite。 按照以下方式執行指令碼:

python ./metadata_writer_for_image_classifier.py \
    --model_file=./model_without_metadata/mobilenet_v1_0.75_160_quantized.tflite \
    --label_file=./model_without_metadata/labels.txt \
    --export_directory=model_with_metadata

如要填入其他圖片分類模型的中繼資料,請將等模型規格加入指令碼中。本指南的其餘部分將重點介紹圖片分類範例中的幾個主要部分,以說明重要元素。

深入瞭解圖片分類範例

款式資訊

中繼資料的第一步是建立新的模型資訊:

from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

""" ... """
"""Creates the metadata for an image classifier."""

# Creates model info.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "MobileNetV1 image classifier"
model_meta.description = ("Identify the most prominent object in the "
                          "image from a set of 1,001 categories such as "
                          "trees, animals, food, vehicles, person etc.")
model_meta.version = "v1"
model_meta.author = "TensorFlow"
model_meta.license = ("Apache License. Version 2.0 "
                      "http://www.apache.org/licenses/LICENSE-2.0.")

輸入 / 輸出資訊

本節說明如何說明模型的輸入和輸出簽章。自動程式碼產生器可能會使用這些中繼資料,建立預先和後續處理的程式碼。如何建立張量的相關輸入或輸出資訊:

# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()

# Creates output info.
output_meta = _metadata_fb.TensorMetadataT()

圖片輸入

圖片是機器學習的常見輸入類型。TensorFlow Lite 中繼資料支援色彩空間和預先處理資訊 (例如正規化) 等資訊。影像的尺寸不需要手動規格,因為已經由輸入張量的形狀提供,並且可以自動推論。

input_meta.name = "image"
input_meta.description = (
    "Input image to be classified. The expected image is {0} x {1}, with "
    "three channels (red, blue, and green) per pixel. Each value in the "
    "tensor is a single byte between 0 and 255.".format(160, 160))
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [255]
input_stats.min = [0]
input_meta.stats = input_stats

標籤輸出

您可以使用 TENSOR_AXIS_LABELS,透過關聯檔案將標籤對應至輸出張量。

# Creates output info.
output_meta = _metadata_fb.TensorMetadataT()
output_meta.name = "probability"
output_meta.description = "Probabilities of the 1001 labels respectively."
output_meta.content = _metadata_fb.ContentT()
output_meta.content.content_properties = _metadata_fb.FeaturePropertiesT()
output_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_stats = _metadata_fb.StatsT()
output_stats.max = [1.0]
output_stats.min = [0.0]
output_meta.stats = output_stats
label_file = _metadata_fb.AssociatedFileT()
label_file.name = os.path.basename("your_path_to_label_file")
label_file.description = "Labels for objects that the model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
output_meta.associatedFiles = [label_file]

建立中繼資料 Flatbuffers

下列程式碼會合併模型資訊與輸入和輸出資訊:

# Creates subgraph info.
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [output_meta]
model_meta.subgraphMetadata = [subgraph]

b = flatbuffers.Builder(0)
b.Finish(
    model_meta.Pack(b),
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()

將中繼資料和相關聯的檔案封裝至模型

建立中繼資料 Flatbuffers 後,中繼資料和標籤檔案會透過 populate 方法寫入 TFLite 檔案:

populator = _metadata.MetadataPopulator.with_model_file(model_file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(["your_path_to_label_file"])
populator.populate()

您可以透過 load_associated_files,將任意數量的相關檔案封裝至模型。不過,您至少需要封裝中繼資料中記錄的檔案。在此範例中,請務必封裝標籤檔案。

以視覺化方式呈現中繼資料

您可以使用 Netron 以視覺化方式呈現中繼資料,也可以使用 MetadataDisplayer 將 TensorFlow Lite 模型的中繼資料讀取為 JSON 格式:

displayer = _metadata.MetadataDisplayer.with_model_file(export_model_path)
export_json_file = os.path.join(FLAGS.export_directory,
                                os.path.splitext(model_basename)[0] + ".json")
json_file = displayer.get_metadata_json()
# Optional: write out the metadata as a json file
with open(export_json_file, "w") as f:
  f.write(json_file)

Android Studio 也支援透過 Android Studio 機器學習繫結功能顯示中繼資料。

中繼資料版本管理

中繼資料結構定義由語意版本編號共同管理,該號碼會追蹤結構定義檔案的變更,而 Flatbuffers 檔案識別,則是表示真實版本相容性。

語意版本編號

中繼資料結構定義是依「語意化編號」表示,例如 MAJOR.MINOR.PATCH。會根據這裡的規則追蹤結構定義異動。請參閱版本 1.0.0 後新增的欄位記錄

Flatbuffers 檔案識別

語意版本管理可以在遵循規則的情況下保證相容性,但並不代表真正不相容。提高 MAJOR 數字時,不一定代表回溯相容性損毀。因此,我們使用 Flatbuffers 檔案識別 file_identifier 來代表中繼資料結構定義的真正相容性。檔案 ID 的長度為 4 個字元。此值固定為特定中繼資料結構定義,且不會由使用者變更。如果中繼資料結構定義因某些原因而必須破壞回溯相容性,file_identifier 將會上升,例如從「M001」變更為「M002」。File_identifier 的變更頻率應遠低於 metadata_version。

中繼資料剖析器的最低版本需求

最低必要中繼資料剖析器版本是中繼資料剖析器 (Flatbuffers 產生的程式碼) 的最低版本,能完整讀取中繼資料 Flatbuffers。此版本實際上是在所有已填入欄位版本和檔案 ID 指定的最小相容版本之中最大版本號碼。在 TFLite 模型中填入中繼資料時,MetadataPopulator 會自動填入所需的最低中繼資料剖析器版本。如要進一步瞭解如何使用最低必要的中繼資料剖析器版本,請參閱中繼資料擷取器

讀取模型的中繼資料

中繼資料擷取器程式庫可讓您輕鬆地從不同平台中的模型讀取中繼資料和關聯檔案 (請參閱 Java 版本C++ 版本)。您可以使用 Flatbuffers 程式庫,以其他語言建構自己的中繼資料擷取工具。

以 Java 讀取中繼資料

如要在 Android 應用程式中使用中繼資料擷取器程式庫,建議您使用 MavenCentral 代管的 TensorFlow Lite 中繼資料 AAR。其中包含 MetadataExtractor 類別,以及中繼資料結構定義模型結構定義的 FlatBuffers Java 繫結。

您可以在 build.gradle 依附元件中指定此屬性,如下所示:

dependencies {
    implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'
}

如要使用夜間快照,請確認您已新增 Sonatype 快照存放區

您可以使用指向模型的 ByteBuffer 來初始化 MetadataExtractor 物件:

public MetadataExtractor(ByteBuffer buffer);

MetadataExtractor 物件的整個生命週期中,ByteBuffer 必須保持不變。如果模型中繼資料的 Flatbuffers 檔案 ID 與中繼資料剖析器不符,初始化可能會失敗。詳情請參閱中繼資料版本管理一文。

使用相符的檔案 ID 時,由於 Flatbuffers 的轉送和回溯相容性機制,中繼資料擷取器將能成功讀取從過去和未來結構定義產生的中繼資料。不過,舊式中繼資料擷取器無法擷取未來結構定義中的欄位。中繼資料的最低剖析器版本代表可讀取完整中繼資料 Flatbuffer 的最低版本標準。您可以使用以下方法,確認是否符合最低所需的剖析器版本條件:

public final boolean isMinimumParserVersionSatisfied();

可以在沒有中繼資料的情況下傳入模型。不過,叫用讀取中繼資料的方法會導致執行階段錯誤。您可以叫用 hasMetadata 方法,檢查模型是否具有中繼資料:

public boolean hasMetadata();

MetadataExtractor 提供便利的函式,方便您取得輸入/輸出張量的中繼資料。舉例來說,

public int getInputTensorCount();
public TensorMetadata getInputTensorMetadata(int inputIndex);
public QuantizationParams getInputTensorQuantizationParams(int inputIndex);
public int[] getInputTensorShape(int inputIndex);
public int getoutputTensorCount();
public TensorMetadata getoutputTensorMetadata(int inputIndex);
public QuantizationParams getoutputTensorQuantizationParams(int inputIndex);
public int[] getoutputTensorShape(int inputIndex);

雖然 TensorFlow Lite 模型結構定義支援多個子圖表,但 TFLite 解譯器目前僅支援單一子圖表。因此,MetadataExtractor 會在其方法中將子圖表索引省略做為輸入引數。

讀取模型中的關聯檔案

含中繼資料和相關檔案的 TensorFlow Lite 模型本質上是可以與常見 ZIP 工具解壓縮的 ZIP 檔案,以便取得相關聯的檔案。舉例來說,您可以解壓縮 mobilenet_v1_0.75_160_quantized,並在模型中擷取標籤檔案,如下所示:

$ unzip mobilenet_v1_0.75_160_quantized_1_metadata_1.tflite
Archive:  mobilenet_v1_0.75_160_quantized_1_metadata_1.tflite
 extracting: labels.txt

您也可以透過中繼資料擷取器程式庫讀取相關聯的檔案。

在 Java 中,將檔案名稱傳入 MetadataExtractor.getAssociatedFile 方法:

public InputStream getAssociatedFile(String fileName);

同樣地,在 C++ 中,您可以利用 ModelMetadataExtractor::GetAssociatedFile 方法完成這項操作:

tflite::support::StatusOr<absl::string_view> GetAssociatedFile(
      const std::string& filename) const;