自定义运算符

由于 LiteRT 内置运算符库仅支持有限数量的 TensorFlow 运算符,因此并非所有模型都可以转换。如需了解详情,请参阅运营商兼容性

为了实现转换,用户可以在 LiteRT 中提供不受支持的 TensorFlow 运算符的自定义实现(称为自定义运算符)。如果您希望将一系列不受支持(或受支持)的 TensorFlow 运算符合并为单个融合的优化自定义运算符,请参阅运算符融合

使用自定义运算符包含四个步骤。

我们来逐步了解一个端到端示例,该示例展示了如何运行具有自定义运算符 tf.atan(命名为 Atan,请参阅创建 TensorFlow 模型)的模型,该运算符在 TensorFlow 中受支持,但在 LiteRT 中不受支持。

TensorFlow Text 运算符是自定义运算符的一个示例。如需查看代码示例,请参阅将 TF Text 转换为 LiteRT 教程。

示例:自定义 Atan 运算符

我们来演示一下如何支持 LiteRT 没有的 TensorFlow 运算符。假设我们使用 Atan 运算符,并且正在为函数 y = atan(x + offset) 构建一个非常简单的模型,其中 offset 是可训练的。

创建 TensorFlow 模型

以下代码段用于训练简单的 TensorFlow 模型。此模型仅包含一个名为 Atan 的自定义运算符,该运算符是一个函数 y = atan(x + offset),其中 offset 是可训练的。

import tensorflow as tf

# Define training dataset and variables
x = [-8, 0.5, 2, 2.2, 201]
y = [-1.4288993, 0.98279375, 1.2490457, 1.2679114, 1.5658458]
offset = tf.Variable(0.0)

# Define a simple model which just contains a custom operator named `Atan`
@tf.function(input_signature=[tf.TensorSpec.from_tensor(tf.constant(x))])
def atan(x):
  return tf.atan(x + offset, name="Atan")

# Train model
optimizer = tf.optimizers.Adam(0.01)
def train(x, y):
    with tf.GradientTape() as t:
      predicted_y = atan(x)
      loss = tf.reduce_sum(tf.square(predicted_y - y))
    grads = t.gradient(loss, [offset])
    optimizer.apply_gradients(zip(grads, [offset]))

for i in range(1000):
    train(x, y)

print("The actual offset is: 1.0")
print("The predicted offset is:", offset.numpy())
The actual offset is: 1.0
The predicted offset is: 0.99999905

此时,如果您尝试使用默认转换器标志生成 LiteRT 模型,将会收到以下错误消息:

Error:
error: 'tf.Atan' op is neither a custom op nor a flex op.

转换为 LiteRT 模型

通过设置转换器属性 allow_custom_ops(如下所示),创建具有自定义运算符的 LiteRT 模型:

converter = tf.lite.TFLiteConverter.from_concrete_functions([atan.get_concrete_function()], atan)
converter.allow_custom_ops = True
tflite_model = converter.convert()

此时,如果您使用默认解释器运行该脚本,例如使用以下命令:

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

您仍会收到以下错误:

Encountered unresolved custom op: Atan.

创建并注册操作员。

#include "third_party/tensorflow/lite/c/c_api.h"
#include "third_party/tensorflow/lite/c/c_api_opaque.h"

LiteRT 自定义运算符是使用简单的纯 C API 定义的,该 API 由不透明类型 (TfLiteOperator) 和相关函数组成。

TfLiteOperator 是不透明类型:

typedef struct TfLiteOperator TfLiteOperator;

TfLiteOperator 存储了运算符的身份和实现。 (请注意,运算符不同于其操作数,后者存储在调用该运算符的节点的 LiteRT 图节点中。)

此类实例通过调用 TfLiteOperatorCreate 构建,可通过调用 TfLiteOperatorDelete 销毁。

运算符的身份是通过构造函数 TfLiteOperatorCreate 的参数设置的:

TfLiteOperator*
TfLiteOperatorCreate(
    TfLiteBuiltinOperator builtin_code,  // Normally `TfLiteBuiltinCustom`.
    const char* custom_name,  // The name of the custom op.
    int version  // Normally `1` for the first version of a custom op.
);

运营商实现可以定义具有以下签名的“方法”。 所有这些方法都是可选的,但为了成功评估运算符,运算符实现需要定义并设置(使用 setter 函数)至少 PrepareInvoke 方法。

// Initializes the op from serialized data.
void* Init(TfLiteOpaqueContext* context, const char* buffer, size_t length);

// Deallocates the op.
// The pointer `buffer` is the data previously returned by an Init invocation.
void Free(TfLiteOpaqueContext* context, void* buffer);

// Called when the inputs that this node depends on have been resized.
TfLiteStatus Prepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);

// Called when the node is executed. (Should read node inputs and write to
// node outputs).
TfLiteStatus Invoke(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node);

// Retrieves the async kernel.
TfLiteAsyncKernel AsyncKernel(TfLiteOpaqueContext* context,
                              TfLiteOpaqueNode* node);

您的运算实现中的函数名称(或命名空间前缀,对于 C++)不必与上述代码段中的函数名称匹配,因为 TF Lite 自定义运算 API 只会使用它们的地址。事实上,我们建议您在匿名命名空间中或作为静态函数声明它们。

不过,最好在这些函数名称中添加您的运算符名称作为命名空间或前缀:

C++

namespace my_namespace::my_custom_op {
  void* Init(TfLiteOpaqueContext* context,
             const char* buffer, size_t length) { ... }
  // ... plus definitions of Free, Prepare, and Invoke ...
}
      

C

void* MyCustomOpInit(TfLiteOpaqueContext* context,
                     const char* buffer, size_t length) { ... }
// ... plus definitions of MyCustomOpFree, MyCustomOpPrepare, and
// MyCustomOpInvoke.
      

由于这是一个 C API,因此这些“方法”在 TfLiteOperator 类型中以 C 函数指针的形式实现,通过将实现函数的地址传递给相应的 setter 函数 TfLiteOperatorSetMethodName 来设置:

void TfLiteOperatorSetInit(
    TfLiteOperator* operator,
    void* (*init)(TfLiteOpaqueContext* context, const char* buffer,
                  size_t length));
void TfLiteOperatorSetFree(
    TfLiteOperator* operator,
    void (*free)(TfLiteOpaqueContext* context, void* data));
void TfLiteOperatorSetPrepare(
    TfLiteOperator* operator,
    TfLiteStatus (*prepare)(TfLiteOpaqueContext* context,
                            TfLiteOpaqueNode* node));
void TfLiteOperatorSetInvoke(
    TfLiteOperator* operator,
    TfLiteStatus (*invoke)(TfLiteOpaqueContext* context,
                           TfLiteOpaqueNode* node));
void TfLiteOperatorSetAsyncKernel(
    TfLiteOperator* operator,
    struct TfLiteAsyncKernel* (*async_kernel)(TfLiteOpaqueContext* context,
                                              TfLiteOpaqueNode* node));

如需详细了解 TfLiteContextTfLiteNode,请参阅 common.hTfLiteContext 提供错误报告功能,并允许访问全局对象(包括所有张量)。TfLiteNode 允许运算符实现访问其输入和输出。

当解释器加载模型时,它会针对图中的每个节点调用一次 Init() 方法。如果 op 在图中多次使用,则给定的 Init() 将被多次调用。对于自定义操作,系统会提供一个配置缓冲区,其中包含将参数名称映射到其值的 flexbuffer。内置操作的缓冲区为空,因为解释器已解析操作参数。需要状态的内核实现应在此处初始化状态,并将所有权转移给调用方。对于每次 Init() 调用,都会有相应的 Free() 调用,从而允许实现处置可能在 Init() 中分配的缓冲区。

每当输入张量的大小调整时,解释器都会遍历图,通知实现更改。这样一来,它们就有机会调整内部缓冲区的大小、检查输入形状和类型的有效性,以及重新计算输出形状。所有这些都是通过 Prepare() 方法完成的,实现可以使用 TfLiteOpaqueNodeGetUserData(node) 访问其状态。

最后,每次运行推理时,解释器都会遍历图,调用 Invoke() 方法,此时状态也可作为 TfLiteOpaqueNodeGetUserData(node) 使用。

通过定义这些“方法”函数,然后定义一个函数(该函数通过调用 TfLiteOperatorCreate 并随后调用相关的 setter 方法来返回 TfLiteOperator 的实例),即可实现自定义操作:

C++

namespace my_namespace::my_custom_op {
  namespace {
    void* Init(TfLiteOpaqueContext* context,
               const char* buffer, size_t length) { ... }
    void Free(TfLiteOpaqueContext* context, void* buffer) { ... }
    TfLiteStatus Prepare(TfLiteOpaqueContext* context,
                         TfLiteOpaqueNode* node) { ... }
    TfLiteStatus Invoke(TfLiteOpaqueContext* context,
                        TfLiteOpaqueNode* node) {... }
  };

  const TfLiteOperator* MyCustomOperator() {
    // Singleton instance, intentionally never destroyed.
    static const TfLiteOperator* my_custom_op = ()[] {
        TfLiteOperator* r =
            TfLiteOperatorCreate(
                kTfLiteBuiltinCustom, "MyCustomOp", /*version=*/ 1);
        TfLiteOperatorSetInit(r, Init);
        TfLiteOperatorSetFree(r, Free);
        TfLiteOperatorSetPrepare(r, Prepare);
        TfLiteOperatorSetInvoke(r, Eval);
        return r;
      };
    return my_custom_op;
  }
}  // namespace my_namespace
      

C

static void* MyCustomOpInit(TfLiteOpaqueContext* context, const char* buffer,
                     size_t length) { ... }
static void MyCustomOpFree(TfLiteOpaqueContext* context, void* buffer) { ... }
static TfLiteStatus MyCustomOpPrepare(TfLiteOpaqueContext* context,
                                      TfLiteOpaqueNode* node) { ... }
static TfLiteStatus MyCustomOpInvoke(TfLiteOpaqueContext* context,
                                     TfLiteOpaqueNode* node) {... }

static TfLiteOperator* MyCustomOpCreate() {
  const TfLiteOperator* r =
      TfLiteOperatorCreate(
          kTfLiteBuiltinCustom, "MyCustomOp", /*version=*/ 1);
  TfLiteOperatorSetInit(r, MyCustomOpInit);
  TfLiteOperatorSetFree(r, MyCustomOpFree);
  TfLiteOperatorSetPrepare(r, MyCustomOpPrepare);
  TfLiteOperatorSetInvoke(r, MyCustomOpEval);
  return r;
}

const TfLiteOperator* MyCustomOperator() {
  // Singleton instance, intentionally never destroyed.
  static const TfLiteOperator* my_custom_op = MyCustomOpCreate();
  return my_custom_op;
}
      

请注意,注册不是自动进行的,您应显式调用 MyCustomOperator 函数(详见下文)。虽然标准 BuiltinOpResolver(可从 :builtin_ops 目标获取)负责注册内置函数,但自定义操作必须收集在单独的自定义库中。

在 LiteRT 运行时中定义内核

如需在 LiteRT 中使用该操作,我们只需定义两个函数(PrepareEval),以及第三个用于构建 TfLiteOperator 的函数:

C++

namespace atan_op {
  namespace {
    TfLiteStatus AtanPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {
      TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumInputs(node), 1);
      TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumOutputs(node), 1);

      const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0);
      TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0);

      int num_dims = TfLiteOpaqueTensorNumDimensions(input);

      TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims);
      for (int i=0; i < num_dims; ++i) {
        output_size->data[i] = input->dims->data[i];
      }

      return TfLiteOpaqueContextResizeTensor(context, output, output_size);
    }

    TfLiteStatus AtanEval(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {
      const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0);
      TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0);

      float* input_data = static_cast<float*>(TfLiteOpaqueTensorData(input));
      float* output_data = static_cast<float*>(TfLiteOpaqueTensorData(output));

      size_t count = 1;
      int num_dims = TfLiteOpaqueTensorNumDimensions(input);
      for (int i = 0; i < num_dims; ++i) {
        count *= input->dims->data[i];
      }

      for (size_t i = 0; i < count; ++i) {
        output_data[i] = atan(input_data[i]);
      }
      return kTfLiteOk;
    }
  }  // anonymous namespace

  const TfLiteOperator* AtanOperator() {
    // Singleton instance, intentionally never destroyed.
    static const TfLiteOperator* atan_op = ()[] {
        auto* r = TfLiteOperatorCreate(
            kTfLiteBuiltinCustom, "ATAN", /*version=*/ 1);
        TfLiteOperatorSetPrepare(r, Prepare);
        TfLiteOperatorSetInvoke(r, Eval);
        return r;
      };
    return atan_op;
  }
}  // namespace atan_op
      

C

static TfLiteStatus AtanPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {
  TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumInputs(node), 1);
  TF_LITE_OPAQUE_ENSURE_EQ(context, TfLiteOpaqueNodeNumOutputs(node), 1);

  const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0);
  TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0);

  int num_dims = TfLiteOpaqueTensorNumDimensions(input);

  TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    output_size->data[i] = input->dims->data[i];
  }

  return TfLiteOpaqueContextResizeTensor(context, output, output_size);
}

static TfLiteStatus AtanEval(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) {
  const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0);
  TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0);

  float* input_data = static_cast<float*>(TfLiteOpaqueTensorData(input));
  float* output_data = static_cast<float*>(TfLiteOpaqueTensorData(output));

  size_t count = 1;
  int num_dims = TfLiteOpaqueTensorNumDimensions(input);
  for (int i = 0; i < num_dims; ++i) {
    count *= input->dims->data[i];
  }

  for (size_t i = 0; i < count; ++i) {
    output_data[i] = atan(input_data[i]);
  }
  return kTfLiteOk;
}

static const TfLiteOperator* AtanOpCreate() {
  TfLiteOperator* r = TfLiteOperatorCreate(
          kTfLiteBuiltinCustom, "ATAN", /*version=*/ 1);
  TfLiteOperatorSetPrepare(r, Prepare);
  TfLiteOperatorSetInvoke(r, Eval);
  return r;
}

const TfLiteOperator* AtanOperator() {
  // Singleton instance, intentionally never destroyed.
  static const TfLiteOperator* atan_op = AtanOpCreate();
  return atan_op;
}
      

初始化 OpResolver 时,将自定义操作添加到解析器中(请参阅下面的示例)。这会将运算符注册到 LiteRT,以便 LiteRT 可以使用新的实现。

向内核库注册运算符

现在,我们需要向内核库注册该运算符。这是通过 OpResolver 完成的。在幕后,解释器会加载一个内核库,该库将分配给模型中的每个运算符来执行。虽然默认库仅包含内置内核,但您可以使用自定义库操作替换/扩充它。

OpResolver 类用于将运营商代码和名称转换为实际代码,其定义如下:

class OpResolver {
 public:
  virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
  virtual TfLiteRegistration* FindOp(const char* op) const = 0;
  ...
};

请注意,为了实现向后兼容性,此类使用较旧的具体类型 TfLiteRegistration 而不是不透明类型 TfLiteOperator,但 TfLiteRegistration 结构包含类型为 TfLiteOperator*registration_external 字段。

MutableOpResolverBuiltinOpResolver 类派生自 OpResolver

class MutableOpResolver : public OpResolver {
 public:
  MutableOpResolver();  // Constructs an initially empty op resolver.
  void AddAll(const MutableOpResolver& other);
  ...
};

class BuiltinOpResolver : public MutableOpResolver {
 public:
  BuiltinOpResolver();  // Constructs an op resolver with all the builtin ops.
};

常规使用(不含自定义操作)需要使用 BuiltinOpResolver 并编写以下内容:

tflite::ops::builtin::BuiltinOpResolver resolver;

如需添加上面创建的自定义操作,您可以改用 MutableOpResolver,并调用 tflite::AddOp(在将解析器传递给 InterpreterBuilder 之前):

tflite::ops::builtin::MutableOpResolver resolver;
resolver.AddAll(tflite::ops::builtin::BuiltinOpResolver());
tflite::AddOp(&resolver, AtanOpRegistration());

如果内置操作集过大,可以基于给定的操作子集(可能仅是给定模型中包含的操作)代码生成新的 OpResolver。这相当于 TensorFlow 的选择性注册(tools 目录中提供了一个简单版本)。

如果您想在 Java 中定义自己的自定义运算符,目前需要构建自己的自定义 JNI 层,并在此 JNI 代码中编译自己的 AAR。同样,如果您希望在 Python 中定义这些运算符,可以将注册信息放在 Python 封装容器代码中。

请注意,如果需要支持一组操作而非单个运算符,可以按照上述类似流程进行操作。只需根据需要添加任意数量的 AddCustom 运算符即可。此外,MutableOpResolver 还允许您使用 AddBuiltin 替换内置函数的实现。

测试和分析您的运算符

如需使用 LiteRT 基准测试工具剖析您的操作,您可以将 LiteRT 的基准测试模型工具用于此目的。出于测试目的,您可以通过向 register.cc 添加相应的 AddCustom 调用(如上所示)来使 LiteRT 的本地 build 识别您的自定义操作。

最佳做法

  1. 谨慎优化内存分配和取消分配。在 Prepare 中分配内存比在 Invoke 中分配内存更高效,并且在循环之前分配内存比在每次迭代中分配内存更好。使用临时张量数据,而不是自行分配内存 (malloc)(请参阅第 2 项)。尽可能使用指针/引用,而不是复制。

  2. 如果数据结构在整个操作期间都会保持不变,我们建议使用临时张量预先分配内存。您可能需要使用 OpData 结构体来引用其他函数中的张量索引。请参阅卷积的内核中的示例。 以下是一个代码段示例。

    struct MyOpData {
      int temp_tensor_index;
      ...
    };
    
    void* Init(TfLiteOpaqueContext* context,
        const char* buffer, size_t length) {
      auto* op_data = new MyOpData{};
      ...
      return op_data;
    }
    void Free(TfLiteOpaqueContext* context, void* buffer) {
      ...
      delete reinterpret_cast<MyOpData*>(buffer);
    }
    TfLiteStatus Prepare(TfLiteOpaqueContext* context,
                         TfLiteOpaqueNode* node) {
      ...
      auto* op_data =
          reinterpret_cast<MyOpData*>(TfLiteOpaqueNodeGetUserData(node));
      const int num_temporaries = 1;
      int temporary_tensor_indices[num_temporaries];
      TfLiteOpaqueTensorBuilder* builder = TfLiteOpaqueTensorBuilderCreate();
      TfLiteOpaqueTensorBuilderSetType(builder, kTfLiteFloat32);
      TfLiteOpaqueTensorBuilderSetAllocationType(builder, kTfLiteArenaRw);
      TfLiteOpaqueContextAddTensor(context, builder,
          &temporary_tensor_indices[0]);
      TfLiteOpaqueTensorBuilderDelete(builder);
      TfLiteOpaqueNodeSetTemporaries(node, temporary_tensor_indices,
          num_temporaries);
      op_data->temp_tensor_index = temporary_tensor_indices[0];
      ...
      return kTfLiteOk;
    }
    TfLiteStatus Invoke(TfLiteOpaqueContext* context,
                        TfLiteOpaqueNode* node) {
      ...
      auto* op_data = reinterpret_cast<MyOpData*>(
          TfLiteOpaqueNodeGetUserData(node));
      TfLiteOpaqueTensor* temp_tensor =
          TfLiteOpaqueContextGetOpaqueTensor(context,
              op_data->temp_tensor_index);
      TF_LITE_OPAQUE_ENSURE(context,
          TfLiteTensorType(temp_tensor) == kTfLiteFloat32);
      TF_LITE_OPAQUE_ENSURE(context,
          TfLiteTensorGetAllocationType(temp_Tensor) == kTfLiteArenaRw);
      void *temp_data = TfLiteTensorData(temp_tensor);
      TF_LITE_OPAQUE_ENSURE(context, temp_data != nullptr);
      ...
      return kTfLiteOk;
    }
    
  3. 如果不会浪费太多内存,建议使用静态固定大小的数组(或 Resize 中的预分配 std::vector),而不是在每次执行迭代时使用动态分配的 std::vector

  4. 避免实例化尚不存在的标准库容器模板,因为它们会影响二进制文件大小。例如,如果您的操作需要其他内核中不存在的 std::map,则使用具有直接索引映射的 std::vector 可以在保持二进制文件大小较小的情况下正常运行。查看其他内核使用的内容,以获取相关信息(或提出问题)。

  5. 检查 malloc 返回的内存指针。如果此指针为 nullptr,则不应使用该指针执行任何操作。如果您在函数中使用了 malloc 并且有错误退出,请在退出之前释放内存。

  6. 使用 TF_LITE_OPAQUE_ENSURE(context, condition) 检查特定条件。使用 TF_LITE_OPAQUE_ENSURE 时,您的代码不得留下悬挂内存,也就是说,这些宏应在分配任何会泄漏的资源之前使用。