自定义运算符

由于 TensorFlow Lite 内置运算符库仅支持有限数量的 TensorFlow 运算符,因此并非每个模型都可以转换。如需了解详情,请参阅运算符兼容性

如需允许转换,用户可以针对 TensorFlow Lite 中不受支持的 TensorFlow 运算符(称为自定义运算符)提供自己的自定义实现。如果您希望将一系列不受支持的(或受支持的)TensorFlow 运算符合并为单个经过优化的一体化自定义运算符,请参阅运算符融合

使用自定义运算符包含四个步骤。

我们来通过一个端到端示例了解一下如何使用自定义运算符 tf.atan(名为 Atan,请参阅创建 TensorFlow 模型)运行模型的端到端示例。该运算符在 TensorFlow 中受支持,但在 TensorFlow Lite 中不受支持。

TensorFlow Text 运算符就是一个自定义运算符的示例。如需查看代码示例,请参阅将 TF Text 转换为 TF Lite 教程。

示例:自定义 Atan 运算符

下面我们来看一个支持 TensorFlow Lite 所不具备的 TensorFlow 运算符的示例。假设我们使用的是 Atan 运算符,并且正在为函数 y = atan(x + offset) 构建非常简单的模型,其中 offset 可以训练。

创建 TensorFlow 模型

以下代码段会训练一个简单的 TensorFlow 模型。此模型仅包含一个名为 Atan 的自定义运算符,该运算符为函数 y = atan(x + offset),其中 offset 可训练。

import tensorflow as tf

# Define training dataset and variables
x = [-8, 0.5, 2, 2.2, 201]
y = [-1.4288993, 0.98279375, 1.2490457, 1.2679114, 1.5658458]
offset = tf.Variable(0.0)

# Define a simple model which just contains a custom operator named `Atan`
@tf.function(input_signature=[tf.TensorSpec.from_tensor(tf.constant(x))])
def atan(x):
  return tf.atan(x + offset, name="Atan")

# Train model
optimizer = tf.optimizers.Adam(0.01)
def train(x, y):
    with tf.GradientTape() as t:
      predicted_y = atan(x)
      loss = tf.reduce_sum(tf.square(predicted_y - y))
    grads = t.gradient(loss, [offset])
    optimizer.apply_gradients(zip(grads, [offset]))

for i in range(1000):
    train(x, y)

print("The actual offset is: 1.0")
print("The predicted offset is:", offset.numpy())
The actual offset is: 1.0
The predicted offset is: 0.99999905

此时,如果您尝试使用默认转换器标志生成 TensorFlow Lite 模型,则会收到以下错误消息:

Error:
error: 'tf.Atan' op is neither a custom op nor a flex op.

转换为 TensorFlow Lite 模型

通过设置转换器属性 allow_custom_ops,创建带有自定义运算符的 TensorFlow Lite 模型,如下所示:

converter = tf.lite.TFLiteConverter.from_concrete_functions([atan.get_concrete_function()], atan)
converter.allow_custom_ops = True
tflite_model = converter.convert()

此时,如果您通过如下命令使用默认解释器运行该测试:

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

您仍然会收到以下错误消息:

Encountered unresolved custom op: Atan.

创建并注册运算符。

#include "third_party/tensorflow/lite/c/c_api.h"
#include "third_party/tensorflow/lite/c/c_api_opaque.h"

TensorFlow Lite 自定义运算符使用简单的纯 C API 进行定义,该 API 由不透明类型 (TfLiteRegistrationExternal) 和相关函数组成。

TfLiteRegistrationExternal 是一个不透明类型:

typedef struct TfLiteRegistrationExternal TfLiteRegistrationExternal;

TfLiteRegistrationExternal 存储运营商的身份和实现。(请注意,运算符与其运算数不同,后者存储在调用运算符的节点的 TF Lite 图节点中。)

此类型的实例是通过调用 TfLiteRegistrationExternalCreate 构造的,并可通过调用 TfLiteRegistrationExternalDelete 销毁。

运营商的身份是通过构造函数 TfLiteRegistrationExternalCreate 的参数来设置的:

TfLiteRegistrationExternal*
TfLiteRegistrationExternalCreate(
    TfLiteBuiltinOperator builtin_code,  // Normally `TfLiteBuiltinCustom`.
    const char* custom_name,  // The name of the custom op.
    int version  // Normally `1` for the first version of a custom op.
);

运算符实现可以使用以下签名定义“方法”。所有这些方法都是可选的,但为了成功评估运算符,运算符实现需要至少定义和设置 PrepareInvoke 方法(使用 setter 函数)。

// Initializes the op from serialized data.
void* Init(TfLiteOpaqueContext* context, const char* buffer, size_t length);

// Deallocates the op.
// The pointer `buffer` is the data previously returned by an Init invocation.
void Free(TfLiteOpaqueContext* context, void* buffer);

// Called when the inputs that this node depends on have been resized.
TfLiteStatus Prepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);

// Called when the node is executed. (Should read node inputs and write to
// node outputs).
TfLiteStatus Invoke(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);

// Retrieves the async kernel.
TfLiteAsyncKernel AsyncKernel(TfLiteOpaqueContext* context,
                              TfLiteOpaqueNode* node);

操作实现中的函数名称(对于 C++,为命名空间前缀)不必与上述代码段中的函数名称一致,因为 TF Lite 自定义操作 API 仅使用其地址。names事实上,我们建议您在匿名命名空间中或将其声明为静态函数。

但最好将运算符名称作为命名空间或前缀添加到以下函数名称中:

C++

namespace my_namespace::my_custom_op {
  void* Init(TfLiteOpaqueContext* context,
             const char* buffer, size_t length) { ... }
  // ... plus definitions of Free, Prepare, and Invoke ...
}
      

C

void* MyCustomOpInit(TfLiteOpaqueContext* context,
                     const char* buffer, size_t length) { ... }
// ... plus definitions of MyCustomOpFree, MyCustomOpPrepare, and
// MyCustomOpInvoke.
      

由于这是一个 C API,因此这些“方法”会以 TfLiteRegistrationExternal 类型中的 C 函数指针的形式实现,而这些指针是通过将实现函数的地址传递给相应的 setter 函数 TfLiteRegistrationExternalSetMethodName 来设置的:

void TfLiteRegistrationExternalSetInit(
    TfLiteRegistrationExternal* registration,
    void* (*init)(TfLiteOpaqueContext* context, const char* buffer,
                  size_t length));
void TfLiteRegistrationExternalSetFree(
    TfLiteRegistrationExternal* registration,
    void (*free)(TfLiteOpaqueContext* context, void* data));
void TfLiteRegistrationExternalSetPrepare(
    TfLiteRegistrationExternal* registration,
    TfLiteStatus (*prepare)(TfLiteOpaqueContext* context,
                            TfLiteOpaqueNode* node));
void TfLiteRegistrationExternalSetInvoke(
    TfLiteRegistrationExternal* registration,
    TfLiteStatus (*invoke)(TfLiteOpaqueContext* context,
                           TfLiteOpaqueNode* node));
void TfLiteRegistrationExternalSetAsyncKernel(
    TfLiteRegistrationExternal* registration,
    struct TfLiteAsyncKernel* (*async_kernel)(TfLiteOpaqueContext* context,
                                              TfLiteOpaqueNode* node));

如需详细了解 TfLiteContextTfLiteNode,请参阅 common.hTfLiteContext 提供错误报告工具以及对全局对象(包括所有张量)的访问权限。TfLiteNode 允许运算符实现访问其输入和输出。

当解释器加载模型时,它会针对图中的每个节点调用一次 Init() 方法。如果在图中多次使用某个给定的 Init(),系统会多次调用该运算。对于自定义操作,系统将提供配置缓冲区,其中包含可将参数名称映射到其值的 Flexbuffer。对于内置操作,缓冲区是空的,因为解释器已经解析操作参数。需要状态的内核实现应在此处对其进行初始化,并将所有权转移给调用方。对于每个 Init() 调用,都会有对 Free() 的相应调用,以允许实现处置它们可能在 Init() 中分配的缓冲区。

每当调整输入张量的大小时,解释器都会遍历图,以通知更改的实现。这使他们有机会调整内部缓冲区的大小,检查输入形状和类型的有效性,以及重新计算输出形状。这全部通过 Prepare() 方法完成,并且实现可以使用 TfLiteOpaqueNodeGetUserData(node) 访问其状态。

最后,每次推断运行时,解释器都会遍历调用 Invoke() 方法的图,状态在这里也以 TfLiteOpaqueNodeGetUserData(node) 的形式提供。

要实现自定义操作,可以定义这些“方法”函数,然后定义一个函数,该函数会返回通过调用 TfLiteRegistrationExternalCreate 和相关 setter 方法构造的 TfLiteRegistrationExternal 实例:

C++

namespace my_namespace::my_custom_op {
  namespace {
    void* Init(TfLiteOpaqueContext* context,
               const char* buffer, size_t length) { ... }
    void Free(TfLiteOpaqueContext* context, void* buffer) { ... }
    TfLiteStatus Prepare(TfLiteOpaqueContext* context,
                         TfLiteOpaqueNode* node) { ... }
    TfLiteStatus Invoke(TfLiteOpaqueContext* context,
                        TfLiteOpaqueNode* node) {... }
  };

  const TfLiteRegistrationExternal* MyCustomOpRegistrationExternal() {
    // Singleton instance, intentionally never destroyed.
    static const TfLiteRegistrationExternal* my_custom_op = ()[] {
        TfLiteRegistrationExternal* r =
            TfLiteRegistrationExternalCreate(
                kTfLiteBuiltinCustom, "MyCustomOp", /*version=*/ 1);
        TfLiteRegistrationExternalSetInit(r, Init);
        TfLiteRegistrationExternalSetFree(r, Free);
        TfLiteRegistrationExternalSetPrepare(r, Prepare);
        TfLiteRegistrationExternalSetInvoke(r, Eval);
        return r;
      };
    return my_custom_op;
  }

  const TfLiteRegistration* MyCustomOpRegistration() {
    static const TfLiteRegistration my_custom_op {
      .registration_external = MyCustomOpRegistrationExternal();
    };
    return my_custom_op;
  }
}  // namespace my_namespace
      

C

static void* MyCustomOpInit(TfLiteOpaqueContext* context, const char* buffer,
                     size_t length) { ... }
static void MyCustomOpFree(TfLiteOpaqueContext* context, void* buffer) { ... }
static TfLiteStatus MyCustomOpPrepare(TfLiteOpaqueContext* context,
                                      TfLiteOpaqueNode* node) { ... }
static TfLiteStatus MyCustomOpInvoke(TfLiteOpaqueContext* context,
                                     TfLiteOpaqueNode* node) {... }

static TfLiteRegistrationExternal* MyCustomOpCreate() {
  const TfLiteRegistrationExternal* r =
      TfLiteRegistrationExternalCreate(
          kTfLiteBuiltinCustom, "MyCustomOp", /*version=*/ 1);
  TfLiteRegistrationExternalSetInit(r, MyCustomOpInit);
  TfLiteRegistrationExternalSetFree(r, MyCustomOpFree);
  TfLiteRegistrationExternalSetPrepare(r, MyCustomOpPrepare);
  TfLiteRegistrationExternalSetInvoke(r, MyCustomOpEval);
  return r;
}

const TfLiteRegistrationExternal* MyCustomOpRegistrationExternal() {
  // Singleton instance, intentionally never destroyed.
  static const TfLiteRegistrationExternal* my_custom_op = MyCustomOpCreate();
  return my_custom_op;
}

const TfLiteRegistration MyCustomOpRegistration() {
  static const TfLiteRegistration my_custom_op {
    .registration_external = MyCustomOpRegistrationExternal();
  };
  return my_custom_op;
}
      

请注意,注册不是自动进行的,应明确调用您的 MyCustomOpRegistration 函数(详见下文)。虽然标准 BuiltinOpResolver(可从 :builtin_ops 目标获取)负责处理内置项的注册,但自定义操作必须在单独的自定义库中收集。

在 TensorFlow Lite 运行时中定义内核

如需在 TensorFlow Lite 中使用该操作,我们只需定义两个函数(PrepareEval),并定义第三个函数,用于构造 TfLiteRegistrationExternal

C++

namespace atan_op {
  namespace {
    TfLiteStatus AtanPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {
      TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumInputs(node), 1);
      TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumOutputs(node), 1);

      const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0);
      TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0);

      int num_dims = TfLiteOpaqueTensorNumDimensions(input);

      TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims);
      for (int i=0; i < num_dims; ++i) {
        output_size->data[i] = input->dims->data[i];
      }

      return TfLiteOpaqueContextResizeTensor(context, output, output_size);
    }

    TfLiteStatus AtanEval(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {
      const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0);
      TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0);

      float* input_data = static_cast(TfLiteOpaqueTensorData(input));
      float* output_data = static_cast(TfLiteOpaqueTensorData(output));

      size_t count = 1;
      int num_dims = TfLiteOpaqueTensorNumDimensions(input);
      for (int i = 0; i < num_dims; ++i) {
        count *= input->dims->data[i];
      }

      for (size_t i = 0; i < count; ++i) {
        output_data[i] = atan(input_data[i]);
      }
      return kTfLiteOk;
    }
  }  // anonymous namespace

  const TfLiteRegistrationExternal* AtanOpRegistrationExternal() {
    // Singleton instance, intentionally never destroyed.
    static const TfLiteRegistrationExternal* atan_op = ()[] {
        auto* r = TfLiteRegistrationExternalCreate(
            kTfLiteBuiltinCustom, "ATAN", /*version=*/ 1);
        TfLiteRegistrationExternalSetPrepare(r, Prepare);
        TfLiteRegistrationExternalSetInvoke(r, Eval);
        return r;
      };
    return atan_op;
  }

  const TfLiteRegistration AtanOpRegistration() {
    static const TfLiteRegistration atan_op {
      .registration_external = AtanOpRegistrationExternal();
    };
    return atan_op;
  }
}  // namespace atan_op
      

C

static TfLiteStatus AtanPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {
  TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumInputs(node), 1);
  TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumOutputs(node), 1);

  const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0);
  TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0);

  int num_dims = TfLiteOpaqueTensorNumDimensions(input);

  TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    output_size->data[i] = input->dims->data[i];
  }

  return TfLiteOpaqueContextResizeTensor(context, output, output_size);
}

static TfLiteStatus AtanEval(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {
  const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0);
  TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0);

  float* input_data = static_cast(TfLiteOpaqueTensorData(input));
  float* output_data = static_cast(TfLiteOpaqueTensorData(output));

  size_t count = 1;
  int num_dims = TfLiteOpaqueTensorNumDimensions(input);
  for (int i = 0; i < num_dims; ++i) {
    count *= input->dims->data[i];
  }

  for (size_t i = 0; i < count; ++i) {
    output_data[i] = atan(input_data[i]);
  }
  return kTfLiteOk;
}

static const TfLiteRegistrationExternal* AtanOpCreate() {
  TfLiteRegistrationExternal* r = TfLiteRegistrationExternalCreate(
          kTfLiteBuiltinCustom, "ATAN", /*version=*/ 1);
  TfLiteRegistrationExternalSetPrepare(r, Prepare);
  TfLiteRegistrationExternalSetInvoke(r, Eval);
  return r;
}

const TfLiteRegistrationExternal* AtanOpRegistrationExternal() {
  // Singleton instance, intentionally never destroyed.
  static const TfLiteRegistrationExternal* atan_op = AtanOpCreate();
  return atan_op;
}

const TfLiteRegistration AtanOpRegistration() {
  static const TfLiteRegistration atan_op {
    .registration_external = AtanOpRegistrationExternal();
  };
  return atan_op;
}
      

初始化 OpResolver 时,请将自定义操作添加到解析器中(如需查看示例,请参阅下文)。这将在 Tensorflow Lite 中注册运算符,以便 TensorFlow Lite 能够使用新实现。请注意,TfLiteRegistration 中的最后两个参数对应于您为自定义操作定义的 AtanPrepareAtanEval 函数。如果您使用 AtanInitAtanFree 函数分别初始化操作中使用的变量和释放空间,它们会添加到 TfLiteRegistration 的前两个参数;在本示例中,这些参数设置为 nullptr

向内核库注册运算符

现在,我们需要向内核库注册操作器。这是通过 OpResolver 实现的。解释器将在后台加载一个内核库,该内核库将被分配用于执行模型中的每个运算符。虽然默认库仅包含内置内核,但可以使用自定义库操作运算符替换/扩充该库。

用于将运营商代码和名称转换为实际代码的 OpResolver 类定义如下:

class OpResolver {
 public:
  virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
  virtual TfLiteRegistration* FindOp(const char* op) const = 0;
  ...
};

请注意,为了实现向后兼容性,此类使用较旧的具体类型 TfLiteRegistration,而不是不透明类型 TfLiteRegistrationExternal,但 TfLiteRegistration 结构体包含一个类型为 TfLiteRegistrationExternal*registration_external 字段。

MutableOpResolverBuiltinOpResolver 类派生自 OpResolver

class MutableOpResolver : public OpResolver {
 public:
  MutableOpResolver();  // Constructs an initially empty op resolver.
  void AddBuiltin(tflite::BuiltinOperator op, const TfLiteRegistration* registration) = 0;
  void AddCustom(const char* op, const TfLiteRegistration* registration) = 0;
  void AddAll(const MutableOpResolver& other);
  ...
};

class BuiltinOpResolver : public MutableOpResolver {
 public:
  BuiltinOpResolver();  // Constructs an op resolver with all the builtin ops.
};

如需常规使用(无自定义操作),您需要使用 BuiltinOpResolver 并写入:

tflite::ops::builtin::BuiltinOpResolver resolver;

如需添加上面创建的自定义操作,您可以改用 MutableOpResolver,然后调用 AddCustom(在将解析器传递给 InterpreterBuilder 之前):

tflite::ops::builtin::MutableOpResolver resolver;
resolver.AddAll(tflite::ops::builtin::BuiltinOpResolver());
resolver.AddCustom("Atan", AtanOpRegistration());

如果认为该内置操作太大,系统可能会根据指定的操作子集(可能只有指定模型中包含的操作)以代码方式生成新的 OpResolver。这相当于 TensorFlow 的选择性注册(其简单版本位于 tools 目录中)。

如果您想在 Java 中定义自定义运算符,目前需要构建自己的自定义 JNI 层,并使用此 jni 代码编译自己的 AAR。同样,如果您希望定义这些可在 Python 中使用的运算符,可以在 Python 封装容器代码中进行注册。

请注意,可以遵循与上述类似的过程来支持一组操作,而不是单个运算符。只需根据需要添加任意数量的 AddCustom 运算符即可。此外,MutableOpResolver 还允许您使用 AddBuiltin 替换内置项的实现。

测试运营商并分析其性能

如需使用 TensorFlow Lite 基准工具对操作进行性能分析,您可以使用适用于 TensorFlow Lite 的基准模型工具。为便于测试,您可以在 register.cc 中添加适当的 AddCustom 调用(如上所示),让 TensorFlow Lite 的本地 build 感知您的自定义操作

最佳实践

  1. 谨慎优化内存分配和取消分配。在 Prepare 中分配内存比在 Invoke 中更高效,在循环之前分配内存比在每次迭代中更好。使用临时张量数据,而不是自行分配数据(请参阅第 2 项)。使用指针/引用,而不是尽可能地复制。

  2. 如果某个数据结构在整个操作过程中持续存在,我们建议您使用临时张量预先分配内存。您可能需要使用 OpData 结构体来引用其他函数中的张量索引。请参阅卷积内核中的示例。示例代码段如下所示。

    struct MyOpData {
      int temp_tensor_index;
      ...
    };
    
    void* Init(TfLiteOpaqueContext* context,
        const char* buffer, size_t length) {
      auto* op_data = new MyOpData{};
      ...
      return op_data;
    }
    void Free(TfLiteOpaqueContext* context, void* buffer) {
      ...
      delete reinterpret_cast<MyOpData*>(buffer);
    }
    TfLiteStatus Prepare(TfLiteOpaqueContext* context,
                         TfLiteOpaqueNode* node) {
      ...
      auto* op_data =
          reinterpret_cast<MyOpData*>(TfLiteOpaqueNodeGetUserData(node));
      const int num_temporaries = 1;
      int temporary_tensor_indices[num_temporaries];
      TfLiteOpaqueTensorBuilder* builder = TfLiteOpaqueTensorBuilderCreate();
      TfLiteOpaqueTensorBuilderSetType(builder, kTfLiteFloat32);
      TfLiteOpaqueTensorBuilderSetAllocationType(builder, kTfLiteArenaRw);
      TfLiteOpaqueContextAddTensor(context, builder,
          &temporary_tensor_indices[0]);
      TfLiteOpaqueTensorBuilderDelete(builder);
      TfLiteOpaqueNodeSetTemporaries(node, temporary_tensor_indices,
          num_temporaries);
      op_data->temp_tensor_index = temporary_tensor_indices[0];
      ...
      return kTfLiteOk;
    }
    TfLiteStatus Invoke(TfLiteOpaqueContext* context,
                        TfLiteOpaqueNode* node) {
      ...
      auto* op_data = reinterpret_cast<MyOpData*>(
          TfLiteOpaqueNodeGetUserData(node));
      TfLiteOpaqueTensor* temp_tensor =
          TfLiteOpaqueContextGetOpaqueTensor(context,
              op_data->temp_tensor_index);
      TF_LITE_OPAQUE_ENSURE(context,
          TfLiteTensorType(temp_tensor) == kTfLiteFloat32);
      TF_LITE_OPAQUE_ENSURE(context,
          TfLiteTensorGetAllocationType(temp_Tensor) == kTfLiteArenaRw);
      void *temp_data = TfLiteTensorData(temp_tensor);
      TF_LITE_OPAQUE_ENSURE(context, temp_data != nullptr);
      ...
      return kTfLiteOk;
    }
    
  3. 如果不会浪费太多内存,最好使用静态大小的数组(或 Resize 中预先分配的 std::vector),而不是在每次执行迭代时使用动态分配的 std::vector

  4. 避免实例化尚不存在的标准库容器模板,因为它们会影响二进制文件的大小。例如,如果您的操作需要使用其他内核中不存在的 std::map,则将 std::vector 用于直接索引映射可能会有效,同时保持较小的二进制文件大小。了解其他内核使用什么来获得数据洞见(或询问)。

  5. 检查指向 malloc 返回的内存的指针。如果此指针为 nullptr,则不应使用该指针执行任何操作。如果您在函数中 malloc 并且退出了错误,请在退出之前取消分配内存。

  6. 请使用 TF_LITE_OPAQUE_ENSURE(context, condition) 检查特定条件。使用 TF_LITE_OPAQUE_ENSURE 时,您的代码不得使内存挂起,也就是说,在分配任何会泄漏的资源之前,就应该使用这些宏。