Chuyển đổi mô hình TensorFlow

Trang này mô tả cách chuyển đổi mô hình TensorFlow thành mô hình TensorFlow Lite (một định dạng FlatBuffer được tối ưu hóa được xác định bằng đuôi tệp .tflite) bằng trình chuyển đổi TensorFlow Lite.

Quy trình chuyển đổi

Sơ đồ bên dưới minh hoạ quy trình hoạt động cấp cao để chuyển đổi mô hình của bạn:

Quy trình công việc của trình chuyển đổi TFLite

Hình 1. Quy trình làm việc của trình chuyển đổi.

Bạn có thể chuyển đổi mô hình của mình bằng một trong các tuỳ chọn sau:

  1. API Python (nên dùng): API này cho phép bạn tích hợp lượt chuyển đổi vào quy trình phát triển, áp dụng các biện pháp tối ưu hoá, thêm siêu dữ liệu và nhiều tác vụ khác giúp đơn giản hoá quá trình chuyển đổi.
  2. Dòng lệnh: Thao tác này chỉ hỗ trợ chuyển đổi mô hình cơ bản.

API Python

Đoạn mã trợ giúp: Để tìm hiểu thêm về API chuyển đổi TensorFlow Lite, hãy chạy print(help(tf.lite.TFLiteConverter)).

Chuyển đổi mô hình TensorFlow bằng tf.lite.TFLiteConverter. Mô hình TensorFlow được lưu trữ bằng định dạng savedModel và được tạo bằng API tf.keras.* cấp cao (mô hình Keras) hoặc API tf.* cấp thấp (từ đó bạn tạo các hàm cụ thể). Do đó, bạn có 3 lựa chọn sau (ví dụ như trong các phần tiếp theo):

Ví dụ sau cho thấy cách chuyển đổi SavedModel thành mô hình TensorFlow Lite.

import tensorflow as tf

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)

Chuyển đổi một mô hình Keras

Ví dụ sau cho thấy cách chuyển đổi mô hình Keras thành mô hình TensorFlow Lite.

import tensorflow as tf

# Create a model using high-level tf.keras.* APIs
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units=1, input_shape=[1]),
    tf.keras.layers.Dense(units=16, activation='relu'),
    tf.keras.layers.Dense(units=1)
])
model.compile(optimizer='sgd', loss='mean_squared_error') # compile the model
model.fit(x=[-1, 0, 1], y=[-3, -1, 1], epochs=5) # train the model
# (to generate a SavedModel) tf.saved_model.save(model, "saved_model_keras_dir")

# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)

Chuyển đổi các hàm cụ thể

Ví dụ sau cho thấy cách chuyển đổi các hàm cụ thể thành mô hình TensorFlow Lite.

import tensorflow as tf

# Create a model using low-level tf.* APIs
class Squared(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def __call__(self, x):
    return tf.square(x)
model = Squared()
# (ro run your model) result = Squared(5.0) # This prints "25.0"
# (to generate a SavedModel) tf.saved_model.save(model, "saved_model_tf_dir")
concrete_func = model.__call__.get_concrete_function()

# Convert the model.

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func],
                                                            model)
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)

Tính năng khác

Lỗi chuyển đổi

Sau đây là các lỗi chuyển đổi phổ biến và giải pháp cho các lỗi đó:

Công cụ dòng lệnh

Nếu bạn đã cài đặt TensorFlow 2.x từ pip, hãy dùng lệnh tflite_convert. Để xem tất cả các cờ có sẵn, hãy sử dụng lệnh sau:

$ tflite_convert --help

`--output_file`. Type: string. Full path of the output file.
`--saved_model_dir`. Type: string. Full path to the SavedModel directory.
`--keras_model_file`. Type: string. Full path to the Keras H5 model file.
`--enable_v1_converter`. Type: bool. (default False) Enables the converter and flags used in TF 1.x instead of TF 2.x.

You are required to provide the `--output_file` flag and either the `--saved_model_dir` or `--keras_model_file` flag.

Nếu không tải nguồn TensorFlow 2.x và muốn chạy trình chuyển đổi từ nguồn đó mà không cần tạo và cài đặt gói, bạn có thể thay thế "tflite_convert" bằng "bazel run tensorflow/lite/python:tflite_convert --" trong lệnh.

Chuyển đổi SaveModel

tflite_convert \
  --saved_model_dir=/tmp/mobilenet_saved_model \
  --output_file=/tmp/mobilenet.tflite

Chuyển đổi mô hình Keras H5

tflite_convert \
  --keras_model_file=/tmp/mobilenet_keras_model.h5 \
  --output_file=/tmp/mobilenet.tflite