Konwertowanie modeli TensorFlow

Na tej stronie opisujemy, jak przekonwertować model TensorFlow na model TensorFlow Lite (zoptymalizowany format FlatBuffer identyfikowany przez rozszerzenie pliku .tflite) przy użyciu konwertera TensorFlow Lite.

Przepływ pracy związany z konwersjami

Poniższy schemat przedstawia ogólny przepływ pracy przy konwertowaniu modelu:

Przepływ pracy w konwerterze TFLite

Rysunek 1. Przepływ pracy osoby dokonującej konwersji.

Możesz przekonwertować model, korzystając z jednej z tych opcji:

  1. Python API (zalecany): umożliwia integrację konwersji z procesem programowania, stosowanie optymalizacji, dodawanie metadanych i wiele innych zadań, które upraszczają proces konwersji.
  2. Wiersz poleceń: obsługuje tylko konwersję modelu podstawowego.

Interfejs API Pythona

Kod pomocniczy: aby dowiedzieć się więcej o interfejsie API konwertera TensorFlow Lite, uruchom print(help(tf.lite.TFLiteConverter)).

Przekonwertuj model TensorFlow za pomocą tf.lite.TFLiteConverter. Model TensorFlow jest przechowywany za pomocą formatu SavedModel i generowany za pomocą interfejsów API tf.keras.* wysokiego poziomu (model Keras) lub niskopoziomowych interfejsów API tf.* (z których generujesz konkretne funkcje). W efekcie masz do wyboru 3 opcje (przykłady znajdziesz w kilku kolejnych sekcjach):

Poniższy przykład pokazuje, jak przekonwertować model SavedModel na model TensorFlow Lite.

import tensorflow as tf

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)

Konwertowanie modelu Keras

Poniższy przykład pokazuje, jak przekonwertować model Keras na model TensorFlow Lite.

import tensorflow as tf

# Create a model using high-level tf.keras.* APIs
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units=1, input_shape=[1]),
    tf.keras.layers.Dense(units=16, activation='relu'),
    tf.keras.layers.Dense(units=1)
])
model.compile(optimizer='sgd', loss='mean_squared_error') # compile the model
model.fit(x=[-1, 0, 1], y=[-3, -1, 1], epochs=5) # train the model
# (to generate a SavedModel) tf.saved_model.save(model, "saved_model_keras_dir")

# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)

Przekonwertuj konkretne funkcje

Poniższy przykład pokazuje, jak przekonwertować funkcje konkretne na model TensorFlow Lite.

import tensorflow as tf

# Create a model using low-level tf.* APIs
class Squared(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def __call__(self, x):
    return tf.square(x)
model = Squared()
# (ro run your model) result = Squared(5.0) # This prints "25.0"
# (to generate a SavedModel) tf.saved_model.save(model, "saved_model_tf_dir")
concrete_func = model.__call__.get_concrete_function()

# Convert the model.

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func],
                                                            model)
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)

Inne funkcje

  • Zastosuj optymalizacje. Typową optymalizowaniem jest kwantyfikacja po trenowaniu, która może jeszcze bardziej zmniejszyć czas oczekiwania i rozmiar modelu przy minimalnej utracie dokładności.

  • Dodaj metadane, które ułatwiają tworzenie kodu otoki na danej platformie podczas wdrażania modeli na urządzeniach.

Błędy konwersji

Poniżej znajdziesz typowe błędy konwersji i ich rozwiązania:

Narzędzie wiersza poleceń

Jeśli masz zainstalowany TensorFlow 2.x z pip, użyj polecenia tflite_convert. Aby wyświetlić wszystkie dostępne flagi, użyj tego polecenia:

$ tflite_convert --help

`--output_file`. Type: string. Full path of the output file.
`--saved_model_dir`. Type: string. Full path to the SavedModel directory.
`--keras_model_file`. Type: string. Full path to the Keras H5 model file.
`--enable_v1_converter`. Type: bool. (default False) Enables the converter and flags used in TF 1.x instead of TF 2.x.

You are required to provide the `--output_file` flag and either the `--saved_model_dir` or `--keras_model_file` flag.

Jeśli masz pobrane źródło TensorFlow 2.x i chcesz uruchomić z niego konwerter bez kompilowania i instalowania pakietu, możesz w poleceniu zastąpić „tflite_convert” ciągiem „bazel run tensorflow/lite/python:tflite_convert --”.

Konwertowanie obiektu SavedModel

tflite_convert \
  --saved_model_dir=/tmp/mobilenet_saved_model \
  --output_file=/tmp/mobilenet.tflite

Konwertowanie modelu Keras H5

tflite_convert \
  --keras_model_file=/tmp/mobilenet_keras_model.h5 \
  --output_file=/tmp/mobilenet.tflite