En este documento, se explica cómo entrenar un modelo y ejecutar inferencias con un microcontrolador.
El ejemplo de Hello World
El ejemplo de Hello World se diseñó para demostrar los conceptos básicos del uso de TensorFlow Lite para microcontroladores. Entrenamos y ejecutamos un modelo que replica una función de seno, es decir, que toma un solo número como entrada y genera el valor seno del número. Cuando se implementan en el microcontrolador, sus predicciones se usan para parpadear las luces LED o controlar una animación.
El flujo de trabajo de extremo a extremo implica los siguientes pasos:
- Entrenar un modelo (en Python): Es un archivo de Python para entrenar, convertir y optimizar un modelo para usarlo en el dispositivo.
- Ejecutar inferencia (en C++ 17): Es una prueba de unidades de extremo a extremo que ejecuta la inferencia en el modelo mediante la biblioteca C++.
Obtén un dispositivo compatible
La aplicación de ejemplo que usaremos se probó en los siguientes dispositivos:
- Arduino Nano 33 BLE Sense (con el IDE de Arduino)
- SparkFun Edge (compilación directamente desde la fuente)
- Kit de descubrimiento STM32F746 (con Mbed)
- Adafruit EdgeBadge (con el IDE de Arduino)
- Kit TensorFlow Lite de Adafruit para microcontroladores (con el IDE de Arduino)
- Circuit Playground Bluefruit de Adafruit (con el IDE de Arduino)
- Espressif ESP32-DevKitC (con IDF de ESP)
- Espressif ESP-EYE (con IDF de ESP)
Obtén más información sobre las plataformas compatibles en TensorFlow Lite para microcontroladores.
Entrenar un modelo
Usa train.py para el entrenamiento del modelo Hello World para el reconocimiento sinonda.
Ejecuta: bazel build tensorflow/lite/micro/examples/hello_world:train
bazel-bin/tensorflow/lite/micro/examples/hello_world/train --save_tf_model
--save_dir=/tmp/model_created/
Ejecuta la inferencia
Para ejecutar el modelo en tu dispositivo, revisaremos las instrucciones en README.md
:
En las siguientes secciones, se explica la prueba de unidades evaluate_test.cc
del ejemplo que demuestra cómo ejecutar la inferencia con TensorFlow Lite para microcontroladores. Carga el modelo y ejecuta inferencias varias veces.
1. Cómo incluir los encabezados de la biblioteca
Para usar la biblioteca de TensorFlow Lite para microcontroladores, debemos incluir los siguientes archivos de encabezado:
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
micro_mutable_op_resolver.h
proporciona las operaciones que usa el intérprete para ejecutar el modelo.micro_error_reporter.h
muestra la información de depuración.micro_interpreter.h
contiene el código para cargar y ejecutar modelos.schema_generated.h
contiene el esquema del formato de archivo del modeloFlatBuffer
de TensorFlow Lite.version.h
proporciona información sobre el control de versiones para el esquema de TensorFlow Lite.
2. Incluye el encabezado del modelo
El intérprete de TensorFlow Lite para microcontroladores espera que el modelo se proporcione como un array de C++. El modelo se define en los archivos model.h
y model.cc
.
El encabezado se incluye con la siguiente línea:
#include "tensorflow/lite/micro/examples/hello_world/model.h"
3. Incluye el encabezado del framework de prueba de unidades
Para crear una prueba de unidades, incluimos el framework de prueba de unidades de TensorFlow Lite para microcontroladores mediante la siguiente línea:
#include "tensorflow/lite/micro/testing/micro_test.h"
La prueba se define con las siguientes macros:
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
. // add code here
.
}
TF_LITE_MICRO_TESTS_END
A continuación, analizaremos el código incluido en la macro anterior.
4. Configura el registro
Para configurar el registro, se crea un puntero tflite::ErrorReporter
con un puntero a una instancia tflite::MicroErrorReporter
:
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = µ_error_reporter;
Esta variable se pasará al intérprete, lo que le permitirá escribir registros. Dado que los microcontroladores suelen tener una variedad de mecanismos de registro, la implementación de tflite::MicroErrorReporter
está diseñada para personalizarse según tu dispositivo en particular.
5. Carga un modelo
En el siguiente código, se crean instancias del modelo con datos de un array char
, g_model
, que se declara en model.h
. Luego, verificamos el modelo para asegurarnos de que la versión de su esquema sea compatible con la versión que usamos:
const tflite::Model* model = ::tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.\n",
model->version(), TFLITE_SCHEMA_VERSION);
}
6. Crear una instancia de agente de resolución de operaciones
Se declara una instancia MicroMutableOpResolver
. El intérprete la usará para registrar las operaciones que usa el modelo y
acceder a ellas:
using HelloWorldOpResolver = tflite::MicroMutableOpResolver<1>;
TfLiteStatus RegisterOps(HelloWorldOpResolver& op_resolver) {
TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());
return kTfLiteOk;
MicroMutableOpResolver
requiere un parámetro de plantilla que indique la cantidad de operaciones que se registrarán. La función RegisterOps
registra las operaciones con el agente de resolución.
HelloWorldOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(RegisterOps(op_resolver));
7. Asignar memoria
Debemos asignar previamente una cantidad determinada de memoria para la entrada, la salida y los arreglos intermedios. Esto se proporciona como un array uint8_t
de tamaño tensor_arena_size
:
const int tensor_arena_size = 2 * 1024;
uint8_t tensor_arena[tensor_arena_size];
El tamaño requerido dependerá del modelo que uses y es posible que se deba determinar mediante la experimentación.
8. Crear una instancia de intérprete
Creamos una instancia de tflite::MicroInterpreter
y pasamos las variables creadas anteriormente:
tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
tensor_arena_size, error_reporter);
9. Asignar tensores
Le indicamos al intérprete que asigne memoria desde tensor_arena
para los tensores del modelo:
interpreter.AllocateTensors();
10. Validar forma de entrada
La instancia MicroInterpreter
puede proporcionarnos un puntero para el tensor de entrada del modelo llamando a .input(0)
, en el que 0
representa el primer (y único) tensor de entrada:
// Obtain a pointer to the model's input tensor
TfLiteTensor* input = interpreter.input(0);
Luego, inspeccionamos este tensor para confirmar que su forma y tipo son los que esperamos:
// Make sure the input has the properties we expect
TF_LITE_MICRO_EXPECT_NE(nullptr, input);
// The property "dims" tells us the tensor's shape. It has one element for
// each dimension. Our input is a 2D tensor containing 1 element, so "dims"
// should have size 2.
TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
// The value of each element gives the length of the corresponding tensor.
// We should expect two single element tensors (one is contained within the
// other).
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
// The input is a 32 bit floating point value
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, input->type);
El valor enum kTfLiteFloat32
es una referencia a uno de los tipos de datos de TensorFlow Lite
y se define en
common.h
.
11. Proporciona un valor de entrada
Para proporcionar una entrada al modelo, configuramos el contenido del tensor de entrada de la siguiente manera:
input->data.f[0] = 0.;
En este caso, ingresamos un valor de punto flotante que representa 0
.
12. Ejecuta el modelo
Para ejecutar el modelo, podemos llamar a Invoke()
en nuestra instancia tflite::MicroInterpreter
:
TfLiteStatus invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed\n");
}
Podemos verificar el valor que se muestra, un TfLiteStatus
, para determinar si la ejecución fue
correcta. Los valores posibles de TfLiteStatus
, definidos en common.h
, son kTfLiteOk
y kTfLiteError
.
El siguiente código confirma que el valor es kTfLiteOk
, lo que significa que la inferencia se ejecutó de forma correcta.
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
13. Cómo obtener el resultado
El tensor de salida del modelo se puede obtener llamando a output(0)
en tflite::MicroInterpreter
, donde 0
representa el primer (y único) tensor de salida.
En el ejemplo, la salida del modelo es un valor de punto flotante único dentro de un tensor 2D:
TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, output->type);
Podemos leer el valor directamente desde el tensor de salida y confirmar que es lo que esperamos:
// Obtain the output value from the tensor
float value = output->data.f[0];
// Check that the output value is within 0.05 of the expected value
TF_LITE_MICRO_EXPECT_NEAR(0., value, 0.05);
14. Vuelve a ejecutar la inferencia
El resto del código ejecuta inferencias varias veces más. En cada instancia, asignamos un valor al tensor de entrada, invocamos el intérprete y leemos el resultado del tensor de salida:
input->data.f[0] = 1.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.841, value, 0.05);
input->data.f[0] = 3.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.141, value, 0.05);
input->data.f[0] = 5.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(-0.959, value, 0.05);