Dokumen ini menjelaskan cara melatih model dan menjalankan inferensi menggunakan pengontrol mikro.
Contoh Halo Dunia
Contoh Hello World dirancang untuk mendemonstrasikan dasar-dasar penggunaan TensorFlow Lite untuk Mikrokontroler. Kita melatih dan menjalankan model yang mereplikasi fungsi sinus, yaitu model tersebut menggunakan satu angka sebagai inputnya dan menghasilkan nilai sinus angka. Saat di-deploy ke pengontrol mikro, prediksinya digunakan untuk berkedip LED atau mengontrol animasi.
Alur kerja end-to-end melibatkan langkah-langkah berikut:
- Latih model (di Python): File Python untuk melatih, mengonversi, dan mengoptimalkan model untuk penggunaan di perangkat.
- Menjalankan inferensi (di C++ 17): Pengujian unit menyeluruh yang menjalankan inferensi pada model menggunakan library C++.
Mendapatkan perangkat yang didukung
Aplikasi contoh yang akan kita gunakan telah diuji di perangkat berikut:
- Arduino Nano 33 BLE Sense (menggunakan Arduino IDE)
- SparkFun Edge (membangun langsung dari sumber)
- Kit Discovery STM32F746 (menggunakan Mbed)
- Adafruit EdgeBadge (menggunakan Arduino IDE)
- Kit Adafruit TensorFlow Lite for Microcontrollers (menggunakan Arduino IDE)
- Adafruit Circuit Playground Bluefruit (menggunakan Arduino IDE)
- Espressif ESP32-DevKitC (menggunakan ESP IDF)
- Espressif ESP-EYE (menggunakan ESP IDF)
Pelajari lebih lanjut platform yang didukung di TensorFlow Lite untuk Mikrokontroler.
Melatih model
Gunakan train.py untuk pelatihan model hello world untuk pengenalan sinwave
Jalankan: bazel build tensorflow/lite/micro/examples/hello_world:train
bazel-bin/tensorflow/lite/micro/examples/hello_world/train --save_tf_model
--save_dir=/tmp/model_created/
Menjalankan inferensi
Untuk menjalankan model di perangkat Anda, kami akan menelusuri petunjuk di
README.md
:
Bagian berikut ini akan memandu pengujian unit evaluate_test.cc
, contoh yang menunjukkan cara menjalankan inferensi menggunakan TensorFlow Lite untuk Mikrokontroler. Model ini akan memuat model dan menjalankan inferensi beberapa kali.
1. Menyertakan header library
Untuk menggunakan library TensorFlow Lite untuk Mikrokontroler, kita harus menyertakan file header berikut:
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
micro_mutable_op_resolver.h
menyediakan operasi yang digunakan oleh penafsir untuk menjalankan model.micro_error_reporter.h
menghasilkan informasi debug.micro_interpreter.h
berisi kode untuk memuat dan menjalankan model.schema_generated.h
berisi skema untuk format file modelFlatBuffer
TensorFlow Lite.version.h
menyediakan informasi pembuatan versi untuk skema TensorFlow Lite.
2. Sertakan header model
Penafsir TensorFlow Lite untuk Mikrokontroler mengharapkan model
disediakan sebagai array C++. Model ditentukan dalam file model.h
dan model.cc
.
Header disertakan dengan baris berikut:
#include "tensorflow/lite/micro/examples/hello_world/model.h"
3. Menyertakan header framework pengujian unit
Untuk membuat pengujian unit, kami menyertakan framework pengujian unit TensorFlow Lite untuk Mikrokontroler dengan menyertakan baris berikut:
#include "tensorflow/lite/micro/testing/micro_test.h"
Pengujian ditentukan menggunakan makro berikut:
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
. // add code here
.
}
TF_LITE_MICRO_TESTS_END
Sekarang kita membahas kode yang disertakan dalam makro di atas.
4. Menyiapkan logging
Untuk menyiapkan logging, pointer tflite::ErrorReporter
dibuat menggunakan pointer
ke instance tflite::MicroErrorReporter
:
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = µ_error_reporter;
Variabel ini akan diteruskan ke penafsir, yang memungkinkannya menulis log. Karena mikrokontroler sering memiliki berbagai mekanisme untuk logging, implementasi tflite::MicroErrorReporter
dirancang agar disesuaikan dengan
perangkat khusus Anda.
5. Memuat model
Dalam kode berikut, model dibuat menggunakan data dari array char
, g_model
, yang dideklarasikan dalam model.h
. Kemudian kita memeriksa model untuk memastikan
versi skemanya kompatibel dengan versi yang kita gunakan:
const tflite::Model* model = ::tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.\n",
model->version(), TFLITE_SCHEMA_VERSION);
}
6. Membuat instance operasi resolver
Instance
MicroMutableOpResolver
dideklarasikan. Ini akan digunakan oleh penafsir untuk mendaftarkan dan mengakses operasi yang digunakan oleh model:
using HelloWorldOpResolver = tflite::MicroMutableOpResolver<1>;
TfLiteStatus RegisterOps(HelloWorldOpResolver& op_resolver) {
TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());
return kTfLiteOk;
MicroMutableOpResolver
memerlukan parameter template yang menunjukkan jumlah
operasi yang akan didaftarkan. Fungsi RegisterOps
mendaftarkan operasi
dengan resolver.
HelloWorldOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(RegisterOps(op_resolver));
7. Alokasikan memori
Kita perlu mengalokasikan sejumlah memori tertentu untuk input, output, dan array menengah. Parameter ini disediakan sebagai array uint8_t
dengan ukuran tensor_arena_size
:
const int tensor_arena_size = 2 * 1024;
uint8_t tensor_arena[tensor_arena_size];
Ukuran yang diperlukan akan bergantung pada model yang Anda gunakan, dan mungkin perlu ditentukan melalui eksperimen.
8. Membuat instance penerjemah
Kita membuat instance tflite::MicroInterpreter
, dengan meneruskan variabel yang dibuat sebelumnya:
tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
tensor_arena_size, error_reporter);
9. Alokasikan tensor
Kita memberi tahu penafsir untuk mengalokasikan memori dari tensor_arena
untuk tensor model:
interpreter.AllocateTensors();
10. Validasi bentuk input
Instance MicroInterpreter
dapat memberi kita pointer ke tensor input
model dengan memanggil .input(0)
, dengan 0
mewakili tensor
input pertama (dan satu-satunya):
// Obtain a pointer to the model's input tensor
TfLiteTensor* input = interpreter.input(0);
Kemudian, kita memeriksa tensor ini untuk mengonfirmasi bahwa bentuk dan jenisnya adalah yang diharapkan:
// Make sure the input has the properties we expect
TF_LITE_MICRO_EXPECT_NE(nullptr, input);
// The property "dims" tells us the tensor's shape. It has one element for
// each dimension. Our input is a 2D tensor containing 1 element, so "dims"
// should have size 2.
TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
// The value of each element gives the length of the corresponding tensor.
// We should expect two single element tensors (one is contained within the
// other).
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
// The input is a 32 bit floating point value
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, input->type);
Nilai enum kTfLiteFloat32
adalah referensi ke salah satu jenis data TensorFlow Lite, dan ditentukan dalam common.h
.
11. Berikan nilai input
Untuk memberikan input ke model, kita tetapkan konten tensor input, sebagai berikut:
input->data.f[0] = 0.;
Dalam hal ini, kita memasukkan nilai floating point yang mewakili 0
.
12. Menjalankan model
Untuk menjalankan model, kita dapat memanggil Invoke()
pada instance tflite::MicroInterpreter
:
TfLiteStatus invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed\n");
}
Kita dapat memeriksa nilai yang ditampilkan, TfLiteStatus
, untuk menentukan apakah operasi berhasil atau tidak. Nilai TfLiteStatus
yang mungkin, yang ditentukan dalam
common.h
,
adalah kTfLiteOk
dan kTfLiteError
.
Kode berikut menegaskan bahwa nilainya adalah kTfLiteOk
, yang berarti inferensi berhasil dijalankan.
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
13. Mendapatkan output
Tensor output model dapat diperoleh dengan memanggil output(0)
pada
tflite::MicroInterpreter
, dengan 0
mewakili tensor output
pertama (dan satu-satunya).
Dalam contoh, output model adalah nilai floating point tunggal yang terdapat dalam tensor 2D:
TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, output->type);
Kita dapat membaca nilai langsung dari tensor output dan menegaskan bahwa nilai tersebut sesuai dengan yang diharapkan:
// Obtain the output value from the tensor
float value = output->data.f[0];
// Check that the output value is within 0.05 of the expected value
TF_LITE_MICRO_EXPECT_NEAR(0., value, 0.05);
14. Menjalankan inferensi lagi
Sisa kode menjalankan inferensi beberapa kali lagi. Pada setiap instance, kita menetapkan nilai pada tensor input, memanggil penafsir, dan membaca hasil dari tensor output:
input->data.f[0] = 1.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.841, value, 0.05);
input->data.f[0] = 3.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.141, value, 0.05);
input->data.f[0] = 5.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(-0.959, value, 0.05);