En este documento, se explica cómo entrenar un modelo y ejecutar inferencias con un o un microcontrolador.
El ejemplo de Hello World
El Hello World Este ejemplo se diseñó para demostrar los conceptos básicos del uso de LiteRT. para microcontroladores. Entrenamos y ejecutamos un modelo que replica una función seno Es decir, toma un solo número como su entrada y genera como salida los valores sine. Cuando se implementa en el microcontrolador, sus predicciones se usan para parpadear las luces LED o controlar una animación.
El flujo de trabajo de extremo a extremo implica los siguientes pasos:
- Entrena un modelo (en Python): un archivo de Python para entrenar y convertir. y optimizar un modelo para usarlo en el dispositivo.
- Ejecutar inferencia (en C++ 17): una prueba de unidades de extremo a extremo ejecuta inferencias en el modelo con la biblioteca C++.
Obtén un dispositivo compatible
La aplicación de ejemplo que vamos a usar se probó en los siguientes dispositivos:
- Arduino Nano 33 BLE Sense (con el IDE de Arduino)
- SparkFun Edge (compilación directamente) desde la fuente)
- Kit STM32F746 Discovery (con Mbed)
- Adafruit EdgeBadge (con Arduino) (IDE)
- Kit de Adafruit LiteRT para microcontroladores (con el IDE de Arduino)
- Circuit Playground Bluefruit de Adafruit (con el IDE de Arduino)
- ESP32-DevKitC de Espressif (con IDF de ESP)
- ESP-EYE de Espressif (con IDF de ESP)
Obtén más información sobre las plataformas compatibles en LiteRT para microcontroladores.
Entrenar un modelo
Usa train.py para el entrenamiento de modelos de Hello World para el reconocimiento de sinwave
Ejecutar: bazel build tensorflow/lite/micro/examples/hello_world:train
bazel-bin/tensorflow/lite/micro/examples/hello_world/train --save_tf_model
--save_dir=/tmp/model_created/
Ejecuta la inferencia
Para ejecutar el modelo en tu dispositivo, seguiremos las instrucciones en la
README.md
:
En las siguientes secciones, se explican
evaluate_test.cc
:
prueba de unidades que demuestra cómo ejecutar inferencias con LiteRT para
Microcontroladores Carga el modelo y ejecuta inferencias varias veces.
1. Cómo incluir los encabezados de la biblioteca
Si quieres usar la biblioteca LiteRT para microcontroladores, debemos incluir lo siguiente: siguientes archivos de encabezado:
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
micro_mutable_op_resolver.h
Proporciona las operaciones que usa el intérprete para ejecutar el modelo.micro_error_reporter.h
genera información de depuración.micro_interpreter.h
contiene código para cargar y ejecutar modelos.schema_generated.h
Contiene el esquema de LiteRT. Formato de archivo del modeloFlatBuffer
.version.h
proporciona información sobre el control de versiones para el esquema LiteRT.
2. Incluye el encabezado del modelo
El intérprete de LiteRT para microcontroladores espera que el modelo sea
proporcionado como un array de C++. El modelo se define en archivos model.h
y model.cc
.
El encabezado se incluye con la siguiente línea:
#include "tensorflow/lite/micro/examples/hello_world/model.h"
3. Incluye el encabezado del framework de prueba de unidades
Para crear una prueba de unidades, incluimos el LiteRT Framework de prueba de unidades de microcontroladores con la siguiente línea:
#include "tensorflow/lite/micro/testing/micro_test.h"
La prueba se define con las siguientes macros:
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
. // add code here
.
}
TF_LITE_MICRO_TESTS_END
Ahora, analizaremos el código incluido en la macro anterior.
4. Configura el registro
Para configurar el registro, se crea un puntero tflite::ErrorReporter
con un puntero
a una instancia de tflite::MicroErrorReporter
:
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = µ_error_reporter;
Esta variable se pasará al intérprete, lo que le permite escribir
los registros del sistema operativo. Dado que los microcontroladores suelen tener diversos mecanismos de registro,
implementación de tflite::MicroErrorReporter
está diseñada para personalizarse
tu dispositivo en particular.
5. Carga un modelo
En el siguiente código, se crea una instancia del modelo con datos de un array char
.
g_model
, que se declara en model.h
Luego, comprobamos el modelo para asegurarnos de que
es compatible con la versión que estamos usando:
const tflite::Model* model = ::tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.\n",
model->version(), TFLITE_SCHEMA_VERSION);
}
6. Crear una instancia de agente de resolución de operaciones
R
MicroMutableOpResolver
declara la instancia. El intérprete usará esta información para registrar y
accedan a las operaciones que usa el modelo:
using HelloWorldOpResolver = tflite::MicroMutableOpResolver<1>;
TfLiteStatus RegisterOps(HelloWorldOpResolver& op_resolver) {
TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());
return kTfLiteOk;
MicroMutableOpResolver
requiere un parámetro de plantilla que indique la cantidad.
de operaciones que se registrarán. La función RegisterOps
registra las operaciones
con el agente de resolución.
HelloWorldOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(RegisterOps(op_resolver));
7. Asignar memoria
Debemos asignar previamente
una determinada cantidad de memoria
arrays intermedios. Se proporciona como un array de tamaño uint8_t
.
tensor_arena_size
const int tensor_arena_size = 2 * 1024;
uint8_t tensor_arena[tensor_arena_size];
El tamaño requerido dependerá del modelo que estés usando y es posible que debas determinado por la experimentación.
8. Crear una instancia de intérprete
Creamos una instancia tflite::MicroInterpreter
y pasamos las variables
creado anteriormente:
tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
tensor_arena_size, error_reporter);
9. Asignar tensores
Le indicamos al intérprete que asigne memoria desde tensor_arena
para el
tensores del modelo:
interpreter.AllocateTensors();
10. Validar forma de entrada
La instancia MicroInterpreter
puede proporcionarnos un puntero al espacio de nombres
tensor de entrada llamando a .input(0)
, donde 0
representa el primer (y único)
tensor de entrada:
// Obtain a pointer to the model's input tensor
TfLiteTensor* input = interpreter.input(0);
Luego, inspeccionamos este tensor para confirmar que su forma y tipo son los que esperar:
// Make sure the input has the properties we expect
TF_LITE_MICRO_EXPECT_NE(nullptr, input);
// The property "dims" tells us the tensor's shape. It has one element for
// each dimension. Our input is a 2D tensor containing 1 element, so "dims"
// should have size 2.
TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
// The value of each element gives the length of the corresponding tensor.
// We should expect two single element tensors (one is contained within the
// other).
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
// The input is a 32 bit floating point value
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, input->type);
El valor enum kTfLiteFloat32
es una referencia a uno de los LiteRT.
tipos de datos y se define en
common.h
11. Proporciona un valor de entrada
Para proporcionar una entrada al modelo, configuramos el contenido del tensor de entrada, como sigue:
input->data.f[0] = 0.;
En este caso, ingresamos un valor de punto flotante que representa 0
.
12. Ejecuta el modelo
Para ejecutar el modelo, podemos llamar a Invoke()
en nuestra tflite::MicroInterpreter
.
instancia:
TfLiteStatus invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed\n");
}
Podemos comprobar el valor que se muestra, un TfLiteStatus
, para determinar si la ejecución se
y exitoso. Los valores posibles de TfLiteStatus
, definidos en
common.h
,
son kTfLiteOk
y kTfLiteError
.
El siguiente código confirma que el valor es kTfLiteOk
, lo que significa que la inferencia
se ejecutó correctamente.
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
13. Obtén el resultado
El tensor de salida del modelo se puede obtener llamando a output(0)
en el
tflite::MicroInterpreter
, donde 0
representa la primera (y única) salida.
tensor.
En el ejemplo, el resultado del modelo es un único valor de punto flotante contenido dentro de un tensor 2D:
TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, output->type);
Podemos leer el valor directamente desde el tensor de salida y afirmar que es esperamos que:
// Obtain the output value from the tensor
float value = output->data.f[0];
// Check that the output value is within 0.05 of the expected value
TF_LITE_MICRO_EXPECT_NEAR(0., value, 0.05);
14. Vuelve a ejecutar la inferencia
El resto del código ejecuta la inferencia varias veces más. En cada instancia, asignamos un valor al tensor de entrada, invocamos el intérprete y leemos resultado del tensor de salida:
input->data.f[0] = 1.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.841, value, 0.05);
input->data.f[0] = 3.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.141, value, 0.05);
input->data.f[0] = 5.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(-0.959, value, 0.05);