TensorFlow Lite Task 库提供了预构建的 C++、Android 和 iOS API,以抽象化基础架构为基础 TensorFlow。您可以扩展 Task API 基础架构以构建自定义 API 如果您的模型不受现有 Task 库支持,则会发生该错误。
概览
Task API 基础架构具有两层结构:底部 C++ 层 封装 TFLite 运行时和顶层 Java/ObjC 层, 通过 JNI 或封装容器与 C++ 层进行通信。
只用 C++ 实现所有 TensorFlow 逻辑可以最大限度地降低成本,最大化 推理性能并简化跨平台的整体工作流。
要创建 Task 类,请扩展 BaseTaskApi 用于在 TFLite 模型接口和 Task API 之间提供转换逻辑 接口,然后使用 Java/ObjC 实用程序创建相应的 API。包含 所有 TensorFlow 详细信息已隐藏,您可以在应用中部署 TFLite 模型 无需任何机器学习知识。
TensorFlow Lite 提供了一些预构建的 API,适用于最受欢迎的 视觉和 NLP 任务。您可以构建 您自己的 API,用于使用 Task API 基础架构执行其他任务。
使用 Task API 基础架构构建您自己的 API
C++ API
所有 TFLite 详细信息在 C++ API 中实现。通过以下操作创建 API 对象: 使用其中一个工厂函数,并通过调用函数获取模型结果 该接口所定义的变量。
用法示例
下面是一个使用 C++
BertQuestionAnswerer
用于
MobileBert。
char kBertModelPath[] = "path/to/model.tflite";
// Create the API from a model file
std::unique_ptr<BertQuestionAnswerer> question_answerer =
BertQuestionAnswerer::CreateFromFile(kBertModelPath);
char kContext[] = ...; // context of a question to be answered
char kQuestion[] = ...; // question to be answered
// ask a question
std::vector<QaAnswer> answers = question_answerer.Answer(kContext, kQuestion);
// answers[0].text is the best answer
构建 API
要构建 API 对象,您必须通过
BaseTaskApi
确定 API I/O - 您的 API 应公开类似的输入/输出 不同平台上的视频广告例如
BertQuestionAnswerer
接受两个字符串(std::string& context, std::string& question)
作为输入,并输出 可能答案和概率的向量,格式为std::vector<QaAnswer>
。这个 方法是在BaseTaskApi
的 模板参数。 在指定模板参数的情况下,BaseTaskApi::Infer
函数将具有正确的输入/输出类型。此函数可以 但最好将其封装在 模型专用函数,在本示例中为BertQuestionAnswerer::Answer
。class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Model specific function delegating calls to BaseTaskApi::Infer std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) { return Infer(context, question).value(); } }
提供 API I/O 和 model - 指定了输入和输出类型后,子类还需要 实现类型化函数
BaseTaskApi::Preprocess
和BaseTaskApi::Postprocess
。 这两个函数提供了 输入 和 输出 。FlatBuffer
子类负责将 从 API I/O 到 I/O 张量。查看完整实现 示例BertQuestionAnswerer
。class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Convert API input into tensors absl::Status BertQuestionAnswerer::Preprocess( const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model const std::string& context, const std::string& query // InputType of the API ) { // Perform tokenization on input strings ... // Populate IDs, Masks and SegmentIDs to corresponding input tensors PopulateTensor(input_ids, input_tensors[0]); PopulateTensor(input_mask, input_tensors[1]); PopulateTensor(segment_ids, input_tensors[2]); return absl::OkStatus(); } // Convert output tensors into API output StatusOr<std::vector<QaAnswer>> // OutputType BertQuestionAnswerer::Postprocess( const std::vector<const TfLiteTensor*>& output_tensors, // output tensors of the model ) { // Get start/end logits of prediction result from output tensors std::vector<float> end_logits; std::vector<float> start_logits; // output_tensors[0]: end_logits FLOAT[1, 384] PopulateVector(output_tensors[0], &end_logits); // output_tensors[1]: start_logits FLOAT[1, 384] PopulateVector(output_tensors[1], &start_logits); ... std::vector<QaAnswer::Pos> orig_results; // Look up the indices from vocabulary file and build results ... return orig_results; } }
创建 API 的工厂函数 - 模型文件和
OpResolver
来初始化tflite::Interpreter
。TaskAPIFactory
提供用于创建 BaseTaskApi 实例的实用函数。您还必须提供与模型相关联的所有文件。例如
BertQuestionAnswerer
还可以包含用于其标记生成器的附加文件, 词汇表。class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Factory function to create the API instance StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateBertQuestionAnswerer( const std::string& path_to_model, // model to passed to TaskApiFactory const std::string& path_to_vocab // additional model specific files ) { // Creates an API object by calling one of the utils from TaskAPIFactory std::unique_ptr<BertQuestionAnswerer> api_to_init; ASSIGN_OR_RETURN( api_to_init, core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( path_to_model, absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), kNumLiteThreads)); // Perform additional model specific initializations // In this case building a vocabulary vector from the vocab file. api_to_init->InitializeVocab(path_to_vocab); return api_to_init; } }
Android API
通过定义 Java/Kotlin 接口并委托逻辑来创建 Android API 通过 JNI 传递到 C++ 层。Android API 需要先构建原生 API。
用法示例
下面是一个使用 Java 的示例
BertQuestionAnswerer
用于
MobileBert。
String BERT_MODEL_FILE = "path/to/model.tflite";
String VOCAB_FILE = "path/to/vocab.txt";
// Create the API from a model file and vocabulary file
BertQuestionAnswerer bertQuestionAnswerer =
BertQuestionAnswerer.createBertQuestionAnswerer(
ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);
String CONTEXT = ...; // context of a question to be answered
String QUESTION = ...; // question to be answered
// ask a question
List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
// answers.get(0).text is the best answer
构建 API
与原生 API 类似,要构建 API 对象,客户端需要提供
来扩展
BaseTaskApi
、
它为所有 Java Task API 提供 JNI 处理功能。
确定 API I/O - 这通常会镜像原生接口。例如
BertQuestionAnswerer
将(String context, String question)
作为输入 并输出List<QaAnswer>
。该实现会调用专用原生代码 函数,不同之处在于它有一个附加参数long nativeHandle
,它是从 C++ 返回的指针。class BertQuestionAnswerer extends BaseTaskApi { public List<QaAnswer> answer(String context, String question) { return answerNative(getNativeHandle(), context, question); } private static native List<QaAnswer> answerNative( long nativeHandle, // C++ pointer String context, String question // API I/O ); }
创建 API 的工厂函数 - 这也镜像原生工厂 函数,除了 Android 工厂函数外,还需要将
Context
文件访问权限。该实现会调用TaskJniUtils
来构建相应的 C++ API 对象,并将其指针传递到BaseTaskApi
构造函数。class BertQuestionAnswerer extends BaseTaskApi { private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "bert_question_answerer_jni"; // Extending super constructor by providing the // native handle(pointer of corresponding C++ API object) private BertQuestionAnswerer(long nativeHandle) { super(nativeHandle); } public static BertQuestionAnswerer createBertQuestionAnswerer( Context context, // Accessing Android files String pathToModel, String pathToVocab) { return new BertQuestionAnswerer( // The util first try loads the JNI module with name // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files, // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers // is called with the buffer for a C++ API object pointer TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( context, BertQuestionAnswerer::initJniWithBertByteBuffers, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, pathToModel, pathToVocab)); } // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer. // returns C++ API object pointer casted to long private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers); }
为原生函数实现 JNI 模块 - 所有 Java 原生方法 通过从 JNI 调用相应的原生函数来实现 模块。工厂函数会创建一个原生 API 对象并返回 并将其指针作为指向 Java 的长类型进行传递。在以后对 Java API 的调用中,长 类型指针将传递回 JNI 并转换回原生 API 对象。 然后,原生 API 结果将转换回 Java 结果。
例如,这是 bert_question_answerer_jni 。
// Implements BertQuestionAnswerer::initJniWithBertByteBuffers extern "C" JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers( JNIEnv* env, jclass thiz, jobjectArray model_buffers) { // Convert Java ByteBuffer object into a buffer that can be read by native factory functions absl::string_view model = GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0)); // Creates the native API object absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status = BertQuestionAnswerer::CreateFromBuffer( model.data(), model.size()); if (status.ok()) { // converts the object pointer to jlong and return to Java. return reinterpret_cast<jlong>(status->release()); } else { return kInvalidPointer; } } // Implements BertQuestionAnswerer::answerNative extern "C" JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative( JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) { // Convert long to native API object pointer QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle); // Calls the native API std::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context), JStringToString(env, question)); // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswerer>) jclass qa_answer_class = env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer"); jmethodID qa_answer_ctor = env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V"); return ConvertVectorToArrayList<QaAnswer>( env, results, [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) { jstring text = env->NewStringUTF(ans.text.data()); jobject qa_answer = env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start, ans.pos.end, ans.pos.logit); env->DeleteLocalRef(text); return qa_answer; }); } // Implements BaseTaskApi::deinitJni by delete the native object extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni( JNIEnv* env, jobject thiz, jlong native_handle) { delete reinterpret_cast<QuestionAnswerer*>(native_handle); }
iOS API
通过将原生 API 对象封装到 ObjC API 对象中来创建 iOS API。通过 可以在 ObjC 或 Swift 中使用。iOS API 需要使用 优先构建原生 API
用法示例
下面是一个使用 ObjC 的示例
TFLBertQuestionAnswerer
对于 MobileBert
。
static let mobileBertModelPath = "path/to/model.tflite";
// Create the API from a model file and vocabulary file
let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
modelPath: mobileBertModelPath)
static let context = ...; // context of a question to be answered
static let question = ...; // question to be answered
// ask a question
let answers = mobileBertAnswerer.answer(
context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
// answers.[0].text is the best answer
构建 API
iOS API 是在原生 API 基础之上构建的简单 ObjC 封装容器。通过以下方式构建 API: 按以下步骤操作:
定义 ObjC 封装容器 - 定义 ObjC 类并委托 相应的原生 API 对象实现。请注意, 由于 Swift 无法 与 C++ 互操作。
- .h 文件
@interface TFLBertQuestionAnswerer : NSObject // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath vocabPath:(NSString*)vocabPath NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:)); // Delegate calls to the native BertQuestionAnswerer::Answer - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context question:(NSString*)question NS_SWIFT_NAME(answer(context:question:)); }
- .mm 文件
using BertQuestionAnswererCPP = ::tflite::task::text::BertQuestionAnswerer; @implementation TFLBertQuestionAnswerer { // define an iVar for the native API object std::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer; } // Initialize the native API object + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath vocabPath:(NSString *)vocabPath { absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer = BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath), MakeString(vocabPath)); _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer"); return [[TFLBertQuestionAnswerer alloc] initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())]; } // Calls the native API and converts C++ results into ObjC results - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question { std::vector<QaAnswerCPP> results = _bertQuestionAnswerwer->Answer(MakeString(context), MakeString(question)); return [self arrayFromVector:results]; } }