マイクロコントローラのスタートガイド

このドキュメントでは、モデルをトレーニングし、マイクロコントローラを使用して推論を実行する方法について説明します。

Hello World の例

Hello World サンプルは、TensorFlow Lite for Microcontrollers の基本的な使い方を説明することを目的としています。正弦関数を複製するモデルをトレーニングして実行します。つまり、1 つの数値を入力として受け取り、数値の正弦を出力します。マイクロコントローラにデプロイすると、その予測を使用して LED の点滅やアニメーションの制御が行われます。

エンドツーエンドのワークフローには次の手順が含まれます。

  1. モデルをトレーニングする(Python): デバイスで使用するためにモデルをトレーニング、変換、最適化するための Python ファイル。
  2. 推論の実行(C++ 17): C++ ライブラリを使用してモデルに対して推論を実行するエンドツーエンドの単体テスト。

サポートされているデバイスを入手する

ここで使用するサンプル アプリケーションは、次のデバイスでテスト済みです。

サポートされているプラットフォームについては、TensorFlow Lite for Microcontrollers をご覧ください。

モデルのトレーニング

正弦波認識の Hello World モデル トレーニングに train.py を使用する

次のコマンドを実行します。bazel build tensorflow/lite/micro/examples/hello_world:train bazel-bin/tensorflow/lite/micro/examples/hello_world/train --save_tf_model --save_dir=/tmp/model_created/

推論を実行する

デバイスでモデルを実行するために、README.md にある手順を説明します。

Hello World README.md

以下のセクションでは、サンプルの evaluate_test.cc という単体テストについて説明します。これは、TensorFlow Lite for Microcontrollers を使用して推論を実行する方法を示しています。モデルを読み込み、推論を複数回実行します。

1. ライブラリ ヘッダーを含める

TensorFlow Lite for Microcontrollers ライブラリを使用するには、次のヘッダー ファイルを含める必要があります。

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
  • micro_mutable_op_resolver.h は、モデルを実行するためにインタープリタが使用するオペレーションを提供します。
  • micro_error_reporter.h はデバッグ情報を出力します。
  • micro_interpreter.h には、モデルを読み込んで実行するためのコードが含まれています。
  • schema_generated.h には、TensorFlow Lite の FlatBuffer モデルファイル形式のスキーマが含まれています。
  • version.h は、TensorFlow Lite スキーマのバージョニング情報を提供します。

2. モデルヘッダーを含める

TensorFlow Lite for Microcontrollers インタープリタは、モデルが C++ 配列として提供されることを想定しています。このモデルは、model.h ファイルと model.cc ファイルで定義されます。ヘッダーは次の行に含まれています。

#include "tensorflow/lite/micro/examples/hello_world/model.h"

3. 単体テスト フレームワークのヘッダーを含める

単体テストを作成するために、次の行を追加して TensorFlow Lite for Microcontrollers 単体テスト フレームワークをインクルードします。

#include "tensorflow/lite/micro/testing/micro_test.h"

テストは次のマクロを使用して定義されます。

TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
  . // add code here
  .
}

TF_LITE_MICRO_TESTS_END

次に、上記のマクロに含まれるコードについて説明します。

4. ロギングの設定

ロギングを設定するために、tflite::MicroErrorReporter インスタンスへのポインタを使用して tflite::ErrorReporter ポインタが作成されます。

tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = &micro_error_reporter;

この変数はインタープリタに渡され、インタープリタがログを書き込めるようになります。多くの場合、マイクロコントローラにはロギングのためのさまざまなメカニズムがあるため、tflite::MicroErrorReporter の実装は特定のデバイスに合わせてカスタマイズされるように設計されています。

5. モデルを読み込む

次のコードでは、model.h で宣言されている char 配列 g_model のデータを使用してモデルをインスタンス化しています。次に、モデルをチェックして、スキーマ バージョンが使用しているバージョンと互換性があることを確認します。

const tflite::Model* model = ::tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
  TF_LITE_REPORT_ERROR(error_reporter,
      "Model provided is schema version %d not equal "
      "to supported version %d.\n",
      model->version(), TFLITE_SCHEMA_VERSION);
}

6. オペレーション リゾルバをインスタンス化する

MicroMutableOpResolver インスタンスが宣言されている。これは、モデルで使用されるオペレーションを登録してアクセスするために、インタープリタによって使用されます。

using HelloWorldOpResolver = tflite::MicroMutableOpResolver<1>;

TfLiteStatus RegisterOps(HelloWorldOpResolver& op_resolver) {
  TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());
  return kTfLiteOk;

MicroMutableOpResolver には、登録されるオペレーションの数を示すテンプレート パラメータが必要です。RegisterOps 関数は、演算をリゾルバに登録します。

HelloWorldOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(RegisterOps(op_resolver));

7.メモリを割り当てる

入力配列、出力配列、中間配列に一定量のメモリを事前に割り当てる必要があります。これは、サイズ tensor_arena_sizeuint8_t 配列として提供されます。

const int tensor_arena_size = 2 * 1024;
uint8_t tensor_arena[tensor_arena_size];

必要なサイズは使用しているモデルによって異なります。また、テストによって決定しなければならない場合もあります。

8. インタープリタのインスタンス化

tflite::MicroInterpreter インスタンスを作成し、先ほど作成した変数を渡します。

tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                     tensor_arena_size, error_reporter);

9.テンソルを割り当てる

モデルのテンソルに tensor_arena からメモリを割り当てるようにインタープリタに指示します。

interpreter.AllocateTensors();

10. 入力シェイプを検証する

MicroInterpreter インスタンスは、.input(0) を呼び出すことで、モデルの入力テンソルへのポインタを提供できます。ここで、0 は最初(かつ唯一の)入力テンソルを表します。

  // Obtain a pointer to the model's input tensor
  TfLiteTensor* input = interpreter.input(0);

次に、このテンソルを調べて、その形状と型が期待どおりであることを確認します。

// Make sure the input has the properties we expect
TF_LITE_MICRO_EXPECT_NE(nullptr, input);
// The property "dims" tells us the tensor's shape. It has one element for
// each dimension. Our input is a 2D tensor containing 1 element, so "dims"
// should have size 2.
TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
// The value of each element gives the length of the corresponding tensor.
// We should expect two single element tensors (one is contained within the
// other).
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
// The input is a 32 bit floating point value
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, input->type);

列挙値 kTfLiteFloat32 は、TensorFlow Lite データ型のいずれかへの参照で、common.h で定義されています。

11. 入力値を指定する

モデルに入力を提供するには、入力テンソルの内容を次のように設定します。

input->data.f[0] = 0.;

ここでは、0 を表す浮動小数点値を入力します。

12. モデルを実行する

モデルを実行するには、tflite::MicroInterpreter インスタンスで Invoke() を呼び出します。

TfLiteStatus invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) {
  TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed\n");
}

戻り値(TfLiteStatus)を確認して、実行が成功したかどうかを判断できます。common.h で定義されている TfLiteStatus の有効な値は kTfLiteOkkTfLiteError です。

次のコードは、値が kTfLiteOk であること、つまり推論が正常に実行されたことをアサートします。

TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);

13. 出力を取得する

モデルの出力テンソルは、tflite::MicroInterpreter に対して output(0) を呼び出すことで取得できます。ここで、0 は最初(かつ唯一の)出力テンソルを表します。

この例では、モデルの出力は 2D テンソル内に含まれる単一の浮動小数点値です。

TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, output->type);

出力テンソルから値を直接読み取り、期待どおりであることをアサートできます。

// Obtain the output value from the tensor
float value = output->data.f[0];
// Check that the output value is within 0.05 of the expected value
TF_LITE_MICRO_EXPECT_NEAR(0., value, 0.05);

14. 推論を再実行する

コードの残りの部分では、推論をさらに数回実行します。各インスタンスで、入力テンソルに値を割り当て、インタープリタを呼び出し、出力テンソルから結果を読み取ります。

input->data.f[0] = 1.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.841, value, 0.05);

input->data.f[0] = 3.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.141, value, 0.05);

input->data.f[0] = 5.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(-0.959, value, 0.05);