Ten dokument wyjaśnia, jak wytrenować model i przeprowadzać wnioskowanie za pomocą funkcji z mikrokontrolerem.
Przykład Hello World
Witaj świecie Przykład ma zademonstrować absolutne podstawy korzystania z LiteRT dla mikrokontrolerów. Trenujemy i uruchamiamy model, który replikuje funkcję sinus, tj. jako dane wejściowe pobiera jedną liczbę, a na wyjściu generuje sinus. Po wdrożeniu w Służą one do mrugania diod LED lub sterowania animację.
Pełny przepływ pracy obejmuje te kroki:
- Wytrenuj model (w Pythonie): plik języka Python do trenowania, konwersji i optymalizować model pod kątem używania na urządzeniu.
- Uruchom wnioskowanie (w C++ 17): kompleksowy test jednostkowy, który uruchamia wnioskowanie na modelu przy użyciu biblioteki C++.
Korzystanie z obsługiwanego urządzenia
Przykładowa aplikacja, której będziemy używać, została przetestowana na następujących urządzeniach:
- Arduino Nano 33 BLE Sense (przy użyciu Arduino IDE)
- SparkFun Edge (tworzenie bezpośrednio ze źródła)
- Zestaw STM32F746 Discovery (za pomocą Mbed)
- Adafruit EdgeBadge (używając Arduino) IDE)
- Zestaw Adafruit LiteRT do mikrokontrolerów (przy użyciu Arduino IDE)
- Obwód Adafruit Playground Bluefruit (przy użyciu Arduino IDE)
- Espressif ESP32-DevKitC (z użyciem identyfikatora ESP IDF)
- Espressif ESP-EYE (z użyciem identyfikatora ESP IDF)
Więcej informacji o obsługiwanych platformach znajdziesz w LiteRT do mikrokontrolerów.
Trenuj model
Używaj train.py do trenowania modelu Hello World do rozpoznawania fal
Uruchom: bazel build tensorflow/lite/micro/examples/hello_world:train
bazel-bin/tensorflow/lite/micro/examples/hello_world/train --save_tf_model
--save_dir=/tmp/model_created/
Uruchom wnioskowanie
Aby uruchomić model na swoim urządzeniu, wykonamy instrukcje podane na
README.md
:
Cześć, Plik README.md z całego świata
Sekcje poniżej opisują
evaluate_test.cc
,
który pokazuje, jak przeprowadzać wnioskowanie z użyciem LiteRT dla
Mikrokontrolery. Wczytuje model i kilka razy przeprowadza wnioskowanie.
1. Uwzględnij nagłówki biblioteki
Aby korzystać z biblioteki LiteRT dla mikrokontrolerów, musimy uwzględnić w następujące pliki nagłówka:
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
micro_mutable_op_resolver.h
udostępnia operacje, których interpreter używa do uruchomienia modelu.micro_error_reporter.h
zwraca dane debugowania.micro_interpreter.h
zawiera kod do wczytywania i uruchamiania modeli.schema_generated.h
zawiera schemat LiteRT Format pliku modeluFlatBuffer
.version.h
udostępnia informacje o obsłudze wersji schematu LiteRT.
2. Dołącz nagłówek modelu
Interpreter LiteRT dla mikrokontrolerów wymaga, aby model
podana w postaci tablicy C++. Model jest zdefiniowany w plikach model.h
i model.cc
.
Nagłówek zawiera ten wiersz:
#include "tensorflow/lite/micro/examples/hello_world/model.h"
3. Dołącz nagłówek platformy do testów jednostkowych
Aby utworzyć test jednostkowy, uwzględniamy język LiteRT dla: Platformę do testowania jednostkowych mikrokontrolerów, dodając do niej następujący wiersz:
#include "tensorflow/lite/micro/testing/micro_test.h"
Test jest zdefiniowany za pomocą tych makr:
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
. // add code here
.
}
TF_LITE_MICRO_TESTS_END
Teraz omówimy kod zawarty w powyższym makrze.
4. Skonfiguruj logowanie
Aby skonfigurować logowanie, za pomocą wskaźnika tworzony jest wskaźnik tflite::ErrorReporter
do instancji tflite::MicroErrorReporter
:
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = µ_error_reporter;
Ta zmienna zostanie przekazana do interpretera, który pozwoli jej zapisywać
dzienników. Mikrokontrolery często mają różne mechanizmy rejestrowania,
implementacja tflite::MicroErrorReporter
jest przeznaczona do dostosowania
na Twoim urządzeniu.
5. Wczytywanie modelu
W poniższym kodzie utworzono instancję modelu z użyciem danych z tablicy char
,
g_model
, która jest zadeklarowana w model.h
. Następnie sprawdzamy model, aby upewnić się,
wersja schematu jest zgodna z używaną wersją:
const tflite::Model* model = ::tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.\n",
model->version(), TFLITE_SCHEMA_VERSION);
}
6. Program do rozpoznawania operacji tworzenia instancji
O
MicroMutableOpResolver
instancji. Tłumacz będzie używany przez tłumacza do rejestracji i
uzyskać dostęp do operacji używanych przez model:
using HelloWorldOpResolver = tflite::MicroMutableOpResolver<1>;
TfLiteStatus RegisterOps(HelloWorldOpResolver& op_resolver) {
TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());
return kTfLiteOk;
Funkcja MicroMutableOpResolver
wymaga parametru szablonu wskazującego liczbę
działań, które zostaną zarejestrowane. Funkcja RegisterOps
rejestruje operacje
z resolverem.
HelloWorldOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(RegisterOps(op_resolver));
7. Przydziel pamięć
Musimy wstępnie przydzielić pewną ilość pamięci na dane wejściowe, wyjściowe
tablice pośrednie. Ta wartość jest udostępniana w postaci tablicy uint8_t
o rozmiarze
tensor_arena_size
:
const int tensor_arena_size = 2 * 1024;
uint8_t tensor_arena[tensor_arena_size];
Wymagany rozmiar zależy od używanego modelu i może być wymagany co jest określane w ramach eksperymentów.
8. Tworzenie instancji interpretera
Tworzymy instancję tflite::MicroInterpreter
, przekazując zmienne
utworzone wcześniej:
tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
tensor_arena_size, error_reporter);
9. Przydziel tensory
Mówimy tłumaczowi, aby przydzielić pamięć z tensor_arena
dla:
tensory modelu:
interpreter.AllocateTensors();
10. Zweryfikuj kształt danych wejściowych
Instancja MicroInterpreter
może zapewnić nam wskaźnik do funkcji modelu
tensor wejściowy z wywołaniem funkcji .input(0)
, gdzie 0
reprezentuje pierwszą (i tylko) wartość
tensor wejściowy:
// Obtain a pointer to the model's input tensor
TfLiteTensor* input = interpreter.input(0);
Następnie sprawdzamy tensor, aby potwierdzić, że jego kształt i typ są zgodne oczekiwanie:
// Make sure the input has the properties we expect
TF_LITE_MICRO_EXPECT_NE(nullptr, input);
// The property "dims" tells us the tensor's shape. It has one element for
// each dimension. Our input is a 2D tensor containing 1 element, so "dims"
// should have size 2.
TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
// The value of each element gives the length of the corresponding tensor.
// We should expect two single element tensors (one is contained within the
// other).
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
// The input is a 32 bit floating point value
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, input->type);
Wartość wyliczenia kTfLiteFloat32
jest odwołaniem do wartości LiteRT
typów danych, a definicja jest definiowana w
common.h
11. Podaj wartość wejściową
Aby dostarczyć modelowi dane wejściowe, ustawiamy zawartość tensora wejściowego jako następujące:
input->data.f[0] = 0.;
W tym przypadku podajemy liczbę zmiennoprzecinkową reprezentującą 0
.
12. Uruchamianie modelu
Aby uruchomić model, możemy wywołać funkcję Invoke()
na naszej tflite::MicroInterpreter
instancja:
TfLiteStatus invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed\n");
}
Możemy sprawdzić zwracaną wartość (TfLiteStatus
), aby określić, czy uruchomienie zostało
udało się. Możliwe wartości parametru TfLiteStatus
, zdefiniowane w
common.h
,
to kTfLiteOk
i kTfLiteError
.
Ten kod potwierdza, że wartość to kTfLiteOk
, co oznacza, że wnioskowanie było
.
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
13. Uzyskiwanie danych wyjściowych
Tensor wyjściowy modelu można uzyskać, wywołując funkcję output(0)
w
tflite::MicroInterpreter
, gdzie 0
to pierwsze (i jedyne) dane wyjściowe
tensora.
W tym przykładzie danymi wyjściowymi modelu jest podana wartość zmiennoprzecinkowa w tensorze 2D:
TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, output->type);
Możemy odczytać wartość bezpośrednio z tensora wyjściowego i stwierdzić, że jest to właśnie oczekujemy:
// Obtain the output value from the tensor
float value = output->data.f[0];
// Check that the output value is within 0.05 of the expected value
TF_LITE_MICRO_EXPECT_NEAR(0., value, 0.05);
14. Ponownie uruchom wnioskowanie
Pozostała część kodu uruchamia wnioskowanie jeszcze kilka razy. W każdej instancji przypisujemy wartość do tensora wejściowego, wywołujemy interpreter i odczytuje wynik z tensora wyjściowego:
input->data.f[0] = 1.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.841, value, 0.05);
input->data.f[0] = 3.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(0.141, value, 0.05);
input->data.f[0] = 5.;
interpreter.Invoke();
value = output->data.f[0];
TF_LITE_MICRO_EXPECT_NEAR(-0.959, value, 0.05);