במאמר הזה נסביר איך לאמן מודל ולהריץ הסקה באמצעות מיקרו-בקר.
הדוגמה של Hello World
הדוגמה של Hello World נועדה להדגים את העקרונות הבסיסיים המוחלטים לשימוש ב-TensorFlow Lite למיקרו-בקרים. אנחנו מאמנים ומפעילים מודל שמשכפל פונקציית סינוס, כלומר, הוא לוקח מספר יחיד כקלט, ומפיק את ערך הסינוס של המספר. כשהוא מופעל במיקרו-בקר, החיזויים שלו משמשים כדי להבהב נורות LED או לשלוט באנימציה.
זרימת העבודה מקצה לקצה כוללת את השלבים הבאים:
- אימון מודל (ב-Python): קובץ פיתון כדי לאמן, להמיר ולבצע אופטימיזציה של מודל לשימוש במכשיר.
- הסקת מסקנות (ב-C++ 17): בדיקה של יחידה מקצה לקצה שמריצה מסקנות על המודל באמצעות ספריית C++.
השג מכשיר נתמך
האפליקציה לדוגמה שנשתמש בה נבדקה במכשירים הבאים:
- Arduino Nano 33 BLE Sense (באמצעות סביבת פיתוח משולבת (IDE) של Arduino)
- SparkFun Edge (פיתוח ישירות מהמקור)
- ערכת Discovery STM32F746 (באמצעות Mbed)
- Adafruit EdgeBadge (באמצעות סביבת פיתוח משולבת (IDE) ב-Arduino)
- ערכת Adafruit TensorFlow Lite למיקרו-בקרים (באמצעות Arduino IDE)
- Blufruit Circuit Playground Bluefruit (באמצעות Arduino IDE)
- Espressif ESP32-DevKitC (באמצעות ESP IDF)
- Espressif ESP-EYE (באמצעות ESP IDF)
מידע נוסף על הפלטפורמות הנתמכות זמין במאמר TensorFlow Lite למיקרו-בקרים.
אימון מודל
השתמשו ב-train.py שלום לאימון מודלים של עולם לזיהוי sinwave
מריצים את: bazel build tensorflow/lite/micro/examples/hello_world:train
bazel-bin/tensorflow/lite/micro/examples/hello_world/train --save_tf_model
--save_dir=/tmp/model_created/
הרצת הסקת
כדי להפעיל את המודל במכשיר, נעבור על ההוראות בREADME.md
:
בקטעים הבאים מוסבר על בדיקת היחידה בדוגמה evaluate_test.cc
, שמדגימה איך להריץ הסקת מסקנות באמצעות TensorFlow Lite למיקרו-בקרים. היא טוענת את המודל ומריצה מסקנות כמה פעמים.
1. הכללת כותרות הספרייה
כדי להשתמש ב-TensorFlow Lite לספריית מיקרו-בקרים, עלינו לכלול את קובצי הכותרת הבאים:
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
micro_mutable_op_resolver.h
מספקות את הפעולות שמשמשות את התרגום כדי להריץ את המודל.micro_error_reporter.h
מידע על תוצאות ניפוי הבאגים בפלט.micro_interpreter.h
מכיל קוד לטעינה ולהרצה של מודלים.schema_generated.h
מכיל את הסכימה של פורמט הקובץ של מודל TensorFlow Lite,FlatBuffer
.version.h
מספק מידע על ניהול גרסאות של סכימת TensorFlow Lite.
2. הכללת כותרת המודל
רכיב התרגום של TensorFlow Lite for Microcontrollers מצפה שהמודל יסופק כמערך C++. המודל מוגדר בקבצים model.h
ו-model.cc
.
הכותרת כלולה בשורה הבאה:
#include "tensorflow/lite/micro/examples/hello_world/model.h"
3. הכללת הכותרת של מסגרת בדיקת היחידה
כדי ליצור בדיקת יחידה, אנחנו כוללים את מסגרת הבדיקה של יחידת TensorFlow Lite למיקרו-בקרים, באמצעות השורה הבאה:
#include "tensorflow/lite/micro/testing/micro_test.h"
הבדיקה מוגדרת באמצעות פקודות המאקרו הבאות:
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
. // add code here
.
}
TF_LITE_MICRO_TESTS_END
כעת נדון בקוד שנכלל במאקרו שלמעלה.
4. הגדרה של רישום ביומן
כדי להגדיר רישום ביומן, נוצר מצביע tflite::ErrorReporter
באמצעות מצביע למופע של tflite::MicroErrorReporter
:
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = µ_error_reporter;
המשתנה הזה יועבר למתרגם, וכך הוא יוכל לכתוב יומנים. לרוב, למיקרו-בקרים יש מגוון מנגנונים לרישום ביומן, ולכן ההטמעה של tflite::MicroErrorReporter
תוכננה להיות מותאמת אישית למכשיר הספציפי שלכם.
5. טעינת מודל
בקוד הבא, המודל נוצר באמצעות נתונים ממערך char
, g_model
, שהוצהר ב-model.h
. לאחר מכן אנחנו בודקים את המודל כדי לוודא שגרסת הסכימה שלו תואמת לגרסה שבה אנחנו משתמשים:
const tflite::Model* model = ::tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.\n",
model->version(), TFLITE_SCHEMA_VERSION);
}
6. יצירת מקודד פעולות
מצהירים על מופע של MicroMutableOpResolver
. זה ישמש את המתרגם כדי לרשום את הפעולות שהמודל משתמש בהן ולגשת אליהן:
using HelloWorldOpResolver = tflite::MicroMutableOpResolver<1>;
TfLiteStatus RegisterOps(HelloWorldOpResolver& op_resolver) {
TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());
return kTfLiteOk;
צריך להגדיר פרמטר של תבנית בשדה MicroMutableOpResolver
, שמציין את מספר הפעולות שיירשמו. הפונקציה RegisterOps
רושמת את הפעולות עם המקודד.
HelloWorldOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(RegisterOps(op_resolver));
7. הקצאת זיכרון
אנחנו צריכים להקצות מראש כמות מסוימת של זיכרון למערכי קלט, פלט ומערכי ביניים. מסופק כמערך uint8_t
בגודל tensor_arena_size
:
const int tensor_arena_size = 2 * 1024;
uint8_t tensor_arena[tensor_arena_size];
הגודל הנדרש תלוי במודל שבו משתמשים, וייתכן שיהיה צורך לקבוע אותו באמצעות ניסויים.
8. תרגום מיידי של מתרגם
אנחנו יוצרים מכונת tflite::MicroInterpreter
ומעבירים את המשתנים שנוצרו קודם לכן:
tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
tensor_arena_size, error_reporter);
9. הקצאת Tensor
אנחנו מנחים את המתרגם/ת להקצות זיכרון מ-tensor_arena
עבור ה-tensors של המודל:
interpreter.AllocateTensors();
10. אימות צורת הקלט
המכונה MicroInterpreter
יכולה לספק לנו מצביע ל-tensor של הקלט של המודל באמצעות קריאה ל-.input(0)
, כאשר 0
מייצג את ה-tensor של הקלט הראשון (והיחיד):
// Obtain a pointer to the model's input tensor
TfLiteTensor* input = interpreter.input(0);
לאחר מכן אנחנו בודקים את החלק הזה כדי לוודא שהצורה והסוג שלו הם הציפיות שלנו:
// Make sure the input has the properties we expect
TF_LITE_MICRO_EXPECT_NE(nullptr, input);
// The property "dims" tells us the tensor's shape. It has one element for
// each dimension. Our input is a 2D tensor containing 1 element, so "dims"
// should have size 2.
TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
// The value of each element gives the length of the corresponding tensor.
// We should expect two single element tensors (one is contained within the
// other).
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
// The input is a 32 bit floating point value
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, input->type);
הערך kTfLiteFloat32
הוא הפניה לאחד מסוגי הנתונים ב-TensorFlow Lite, ומוגדר ב-common.h
.
11. יש לספק ערך קלט
כדי לספק קלט למודל, אנחנו מגדירים את התוכן של tensor הקלט, באופן הבא:
input->data.f[0] = 0.;
במקרה הזה, אנחנו מזינים ערך של נקודה צפה (floating-point) שמייצג את 0
.
12. הרצת המודל
כדי להפעיל את המודל, אנחנו יכולים לקרוא לפונקציה Invoke()
במכונה tflite::MicroInterpreter
:
TfLiteStatus invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed\n");
}
אנחנו יכולים לבדוק את הערך המוחזר, TfLiteStatus
, על מנת לקבוע אם ההפעלה הצליחה. הערכים האפשריים של TfLiteStatus
, שמוגדרים ב-common.h
, הם kTfLiteOk
ו-kTfLiteError
.
הקוד הבא טוען שהערך הוא kTfLiteOk
, כלומר ההסקה בוצעה בהצלחה.
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
13. קבלת הפלט
אפשר לקבל את tensor הפלט של המודל על ידי קריאה ל-output(0)
ב-tflite::MicroInterpreter
, כאשר 0
מייצג את tensor הפלט הראשון (והיחיד).
בדוגמה, הפלט של המודל הוא ערך של נקודה צפה (floating-point) יחיד שנמצא בתוך טנזור דו-ממדי:
TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, output->type);
אנחנו יכולים לקרוא את הערך ישירות מ-tensor הפלט ולהצהיר שזה מה שאנחנו מצפים:
// Obtain the output value from the tensor
float value = output->data.f[0];
// Check that the output value is within 0.05 of the expected value
TF_LITE_MICRO_EXPECT_NEAR(0., value, 0.05);
14. הרצת ההסקה שוב
שאר הקוד מבצע הסקת מסקנות מספר פעמים נוספות. בכל מקרה, אנחנו מקצים ערך לקלט ה-tensor, מפעילים את המפענח וקוראים את התוצאה מ-tensor הפלט:
input->data.f[0] = 1.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.841, value, 0.05);
input->data.f[0] = 3.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.141, value, 0.05);
input->data.f[0] = 5.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(-0.959, value, 0.05);