使用 TensorFlow Lite 支援資料庫處理輸入和輸出資料

行動應用程式開發人員通常會與輸入的物件互動,例如點陣圖或原始物件 (例如整數)。不過,在裝置端執行機器學習模型的 TensorFlow Lite 解譯器 API 使用張量 (ByteBuffer) 形式,因此難以偵錯及操控。TensorFlow Lite Android 支援資料庫旨在協助處理 TensorFlow Lite 模型的輸入和輸出內容,並讓 TensorFlow Lite 解譯器更容易使用。

開始使用

匯入 Gradle 依附元件和其他設定

.tflite 模型檔案複製到要執行模型的 Android 模組資產目錄。指定不應壓縮的檔案,並將 TensorFlow Lite 程式庫新增至模組的 build.gradle 檔案:

android {
    // Other settings

    // Specify tflite file should not be compressed for the app apk
    aaptOptions {
        noCompress "tflite"
    }

}

dependencies {
    // Other dependencies

    // Import tflite dependencies
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'
    // The GPU delegate library is optional. Depend on it as needed.
    implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly-SNAPSHOT'
    implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly-SNAPSHOT'
}

探索在 MavenCentral 託管的 TensorFlow Lite 支援資料庫 AAR,瞭解不同版本的支援資料庫。

基本圖片操作與轉換

TensorFlow Lite 支援資料庫提供一組基本的圖片操控方法,例如裁剪和調整大小。如要使用,請建立 ImagePreprocessor 並新增必要的作業。如要將圖片轉換為 TensorFlow Lite 解譯器所需的 Tensor 格式,請建立 TensorImage 做為輸入使用:

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;

// Initialization code
// Create an ImageProcessor with all ops required. For more ops, please
// refer to the ImageProcessor Architecture section in this README.
ImageProcessor imageProcessor =
    new ImageProcessor.Builder()
        .add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
        .build();

// Create a TensorImage object. This creates the tensor of the corresponding
// tensor type (uint8 in this case) that the TensorFlow Lite interpreter needs.
TensorImage tensorImage = new TensorImage(DataType.UINT8);

// Analysis code for every frame
// Preprocess the image
tensorImage.load(bitmap);
tensorImage = imageProcessor.process(tensorImage);

張量的 DataType 可透過中繼資料擷取器程式庫,以及其他模型資訊讀取。

基本音訊資料處理

TensorFlow Lite 支援資料庫也會定義 TensorAudio 類別,用於包裝部分基本音訊資料處理方法。且大多與 AudioRecord 搭配使用,並擷取環形緩衝區的音訊樣本。

import android.media.AudioRecord;
import org.tensorflow.lite.support.audio.TensorAudio;

// Create an `AudioRecord` instance.
AudioRecord record = AudioRecord(...)

// Create a `TensorAudio` object from Android AudioFormat.
TensorAudio tensorAudio = new TensorAudio(record.getFormat(), size)

// Load all audio samples available in the AudioRecord without blocking.
tensorAudio.load(record)

// Get the `TensorBuffer` for inference.
TensorBuffer buffer = tensorAudio.getTensorBuffer()

建立輸出物件並執行模型

執行模型前,需要建立用來儲存結果的容器物件:

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

// Create a container for the result and specify that this is a quantized model.
// Hence, the 'DataType' is defined as UINT8 (8-bit unsigned integer)
TensorBuffer probabilityBuffer =
    TensorBuffer.createFixedSize(new int[]{1, 1001}, DataType.UINT8);

載入模型並執行推論:

import java.nio.MappedByteBuffer;
import org.tensorflow.lite.InterpreterFactory;
import org.tensorflow.lite.InterpreterApi;

// Initialise the model
try{
    MappedByteBuffer tfliteModel
        = FileUtil.loadMappedFile(activity,
            "mobilenet_v1_1.0_224_quant.tflite");
    InterpreterApi tflite = new InterpreterFactory().create(
        tfliteModel, new InterpreterApi.Options());
} catch (IOException e){
    Log.e("tfliteSupport", "Error reading model", e);
}

// Running inference
if(null != tflite) {
    tflite.run(tImage.getBuffer(), probabilityBuffer.getBuffer());
}

存取結果

開發人員可以直接透過 probabilityBuffer.getFloatArray() 存取輸出內容。如果模型產生量化輸出,請記得轉換結果。針對 MobileNet 量化模型,開發人員需要將每個輸出值除以 255,才能取得每個類別 0 (最可能) 到 1 (最可能) 的機率範圍。

選用:將結果對應至標籤

開發人員也可以選擇將結果對應至標籤。首先,請將包含標籤的文字檔案複製到模組的資產目錄中。接著,使用下列程式碼載入標籤檔案:

import org.tensorflow.lite.support.common.FileUtil;

final String ASSOCIATED_AXIS_LABELS = "labels.txt";
List<String> associatedAxisLabels = null;

try {
    associatedAxisLabels = FileUtil.loadLabels(this, ASSOCIATED_AXIS_LABELS);
} catch (IOException e) {
    Log.e("tfliteSupport", "Error reading label file", e);
}

下列程式碼片段示範如何將機率與類別標籤建立關聯:

import java.util.Map;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.label.TensorLabel;

// Post-processor which dequantize the result
TensorProcessor probabilityProcessor =
    new TensorProcessor.Builder().add(new NormalizeOp(0, 255)).build();

if (null != associatedAxisLabels) {
    // Map of labels and their corresponding probability
    TensorLabel labels = new TensorLabel(associatedAxisLabels,
        probabilityProcessor.process(probabilityBuffer));

    // Create a map to access the result based on label
    Map<String, Float> floatMap = labels.getMapWithFloatValue();
}

目前應用實例

目前版本的 TensorFlow Lite 支援資料庫涵蓋:

  • 常見資料類型 (浮點值、uint8、圖片、音訊和陣列),做為 tflite 模型的輸入和輸出。
  • 基本的圖片操作 (裁剪圖片、調整大小及旋轉)。
  • 正規化與量化
  • 檔案公用程式

日後推出的版本將改善文字相關應用程式的支援。

影像處理器架構

ImageProcessor 的設計允許在建構程序期間預先定義及最佳化圖片操縱作業。ImageProcessor 目前支援三種基本預先處理作業,如下方程式碼片段中的三個註解所示:

import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.common.ops.QuantizeOp;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
import org.tensorflow.lite.support.image.ops.Rot90Op;

int width = bitmap.getWidth();
int height = bitmap.getHeight();

int size = height > width ? width : height;

ImageProcessor imageProcessor =
    new ImageProcessor.Builder()
        // Center crop the image to the largest square possible
        .add(new ResizeWithCropOrPadOp(size, size))
        // Resize using Bilinear or Nearest neighbour
        .add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR));
        // Rotation counter-clockwise in 90 degree increments
        .add(new Rot90Op(rotateDegrees / 90))
        .add(new NormalizeOp(127.5, 127.5))
        .add(new QuantizeOp(128.0, 1/128.0))
        .build();

如要進一步瞭解正規化和量化,請參閱這篇文章

支援資料庫的最終目標是支援所有 tf.image 轉換。這意味著轉換作業和 TensorFlow 相同,實作作業也將獨立於作業系統之外。

開發人員也可以建立自訂處理器。在這些情況下,請務必與訓練程序保持一致。也就是說,相同的預先處理作業應同時套用至訓練和推論,才能增加可重現性。

量化

啟動輸入或輸出物件 (例如 TensorImageTensorBuffer) 時,您需要將物件類型指定為 DataType.UINT8DataType.FLOAT32

TensorImage tensorImage = new TensorImage(DataType.UINT8);
TensorBuffer probabilityBuffer =
    TensorBuffer.createFixedSize(new int[]{1, 1001}, DataType.UINT8);

TensorProcessor 可用來量化輸入張量,或將輸出張量反化。舉例來說,在處理量化輸出 TensorBuffer 時,開發人員可以使用 DequantizeOp 將結果量化為介於 0 至 1 之間的浮點機率:

import org.tensorflow.lite.support.common.TensorProcessor;

// Post-processor which dequantize the result
TensorProcessor probabilityProcessor =
    new TensorProcessor.Builder().add(new DequantizeOp(0, 1/255.0)).build();
TensorBuffer dequantizedBuffer = probabilityProcessor.process(probabilityBuffer);

Tensor 的量化參數可透過中繼資料擷取器程式庫讀取。