The TensorFlow Lite Task Library provides prebuilt C++, Android, and iOS APIs on top of the same infrastructure that abstracts TensorFlow. You can extend the Task API infrastructure to build customized APIs.
Overview
The Task API infrastructure has a two-layer structure: the bottom C++ layer encapsulating the TFLite runtime, and the top Java/ObjC layer that communicates with the C++ layer through JNI or a native wrapper.

Implementing all the TensorFlow logic in C++ only minimizes cost, maximizes inference performance, and simplifies the overall workflow across platforms.
To create a Task class, extend BaseTaskApi to provide conversion logic between the TFLite model interface and the Task API interface, then use the Java/ObjC utilities to create corresponding APIs. With all TensorFlow details hidden, you can deploy the TFLite model in your apps without any machine learning knowledge.
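To make the BaseTaskApi pattern concrete, here is a hedged, self-contained sketch of its shape: the template parameters fix the API's output and input types, and Infer() chains Preprocess, a model run, and Postprocess. All names here (SketchTaskApi, FakeTensor, EchoLengthTask) are hypothetical stand-ins, not the actual Task Library types, and the "model" is a trivial copy:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a model's I/O tensor (not a TFLite type).
struct FakeTensor {
  std::vector<float> data;
};

// Minimal sketch of the BaseTaskApi idea: template parameters fix the API's
// output and input types; Infer() runs Preprocess -> model -> Postprocess.
// The real BaseTaskApi wraps a tflite::Interpreter instead of RunModel().
template <typename OutputType, typename... InputTypes>
class SketchTaskApi {
 public:
  virtual ~SketchTaskApi() = default;

  OutputType Infer(InputTypes... args) {
    FakeTensor input, output;
    Preprocess(&input, args...);
    RunModel(input, &output);  // stands in for TFLite inference
    return Postprocess(output);
  }

 protected:
  virtual void Preprocess(FakeTensor* tensor, InputTypes... args) = 0;
  virtual OutputType Postprocess(const FakeTensor& tensor) = 0;

 private:
  // Toy "model": just echoes its input tensor.
  void RunModel(const FakeTensor& in, FakeTensor* out) { out->data = in.data; }
};

// A task-specific subclass in the style of BertQuestionAnswerer: a
// model-specific function (Measure) delegates to Infer, like Answer() does.
class EchoLengthTask : public SketchTaskApi<float, const std::string&> {
 public:
  float Measure(const std::string& text) { return Infer(text); }

 protected:
  void Preprocess(FakeTensor* tensor, const std::string& text) override {
    tensor->data = {static_cast<float>(text.size())};
  }
  float Postprocess(const FakeTensor& tensor) override {
    return tensor.data[0];
  }
};
```

The point of the structure is that a subclass only writes the two conversion functions; the inference plumbing lives once in the base class.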
TensorFlow Lite provides several prebuilt APIs for the most popular vision and natural language processing tasks. You can build your own APIs for other tasks using the Task API infrastructure.
Build your own API with the Task API infrastructure
C++ API
All TFLite details are implemented in the C++ API. Create an API object by using one of the factory functions, and get model results by calling the functions defined in the interface.
Sample usage
Here is an example of using the C++ BertQuestionAnswerer for MobileBert:
```cpp
char kBertModelPath[] = "path/to/model.tflite";
// Create the API from a model file
std::unique_ptr<BertQuestionAnswerer> question_answerer =
    BertQuestionAnswerer::CreateFromFile(kBertModelPath);

char kContext[] = ...; // context of a question to be answered
char kQuestion[] = ...; // question to be answered
// ask a question
std::vector<QaAnswer> answers = question_answerer->Answer(kContext, kQuestion);
// answers[0].text is the best answer
```
Building the API
To build an API object, you must provide the following information by extending BaseTaskApi:
Determine the API I/O: Your API should expose similar input/output across different platforms. For example, BertQuestionAnswerer takes two strings (std::string& context, std::string& question) as input and outputs a vector of possible answers and probabilities as std::vector<QaAnswer>. This is done by specifying the corresponding types in the template parameters of BaseTaskApi. With the template parameters specified, the BaseTaskApi::Infer function will have the correct input/output types. This function can be called directly by API clients, but it is good practice to wrap it inside a model-specific function, in this case BertQuestionAnswerer::Answer:

```cpp
class BertQuestionAnswerer : public BaseTaskApi<
                              std::vector<QaAnswer>, // OutputType
                              const std::string&, const std::string& // InputTypes
                              > {
  // Model specific function delegating calls to BaseTaskApi::Infer
  std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) {
    return Infer(context, question).value();
  }
}
```
Provide conversion logic between API I/O and the input/output tensors of the model: With input and output types specified, the subclasses also need to implement the typed functions BaseTaskApi::Preprocess and BaseTaskApi::Postprocess. The two functions provide the inputs and outputs from the TFLite FlatBuffer. The subclass is responsible for assigning values from the API I/O to the I/O tensors. See the complete implementation example in BertQuestionAnswerer:

```cpp
class BertQuestionAnswerer : public BaseTaskApi<
                              std::vector<QaAnswer>, // OutputType
                              const std::string&, const std::string& // InputTypes
                              > {
  // Convert API input into tensors
  absl::Status BertQuestionAnswerer::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model
    const std::string& context, const std::string& query // InputType of the API
  ) {
    // Perform tokenization on input strings
    ...
    // Populate IDs, Masks and SegmentIDs to corresponding input tensors
    PopulateTensor(input_ids, input_tensors[0]);
    PopulateTensor(input_mask, input_tensors[1]);
    PopulateTensor(segment_ids, input_tensors[2]);
    return absl::OkStatus();
  }

  // Convert output tensors into API output
  StatusOr<std::vector<QaAnswer>> // OutputType
  BertQuestionAnswerer::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors // output tensors of the model
  ) {
    // Get start/end logits of prediction result from output tensors
    std::vector<float> end_logits;
    std::vector<float> start_logits;
    // output_tensors[0]: end_logits FLOAT[1, 384]
    PopulateVector(output_tensors[0], &end_logits);
    // output_tensors[1]: start_logits FLOAT[1, 384]
    PopulateVector(output_tensors[1], &start_logits);
    ...
    std::vector<QaAnswer::Pos> orig_results;
    // Look up the indices from vocabulary file and build results
    ...
    return orig_results;
  }
}
```
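The span-selection step elided by "..." in Postprocess can be sketched in isolation. This is a hedged simplification, not the actual BertQuestionAnswerer logic: the real implementation scores (start, end) pairs jointly and maps token indices back through the vocabulary, while this sketch only takes the independent argmax of each logit vector:

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Simplified span selection: pick the token positions with the highest
// start and end logits independently. (The production code ranks candidate
// spans jointly and enforces start <= end; omitted here for brevity.)
std::pair<int, int> BestSpan(const std::vector<float>& start_logits,
                             const std::vector<float>& end_logits) {
  int start = std::max_element(start_logits.begin(), start_logits.end()) -
              start_logits.begin();
  int end = std::max_element(end_logits.begin(), end_logits.end()) -
            end_logits.begin();
  return {start, end};
}
```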
Create factory functions of the API: A model file and an OpResolver are needed to initialize the tflite::Interpreter. TaskAPIFactory provides utility functions to create BaseTaskApi instances. You must also provide any files associated with the model. For example, BertQuestionAnswerer can also have an additional file for its tokenizer's vocabulary:

```cpp
class BertQuestionAnswerer : public BaseTaskApi<
                              std::vector<QaAnswer>, // OutputType
                              const std::string&, const std::string& // InputTypes
                              > {
  // Factory function to create the API instance
  StatusOr<std::unique_ptr<QuestionAnswerer>>
  BertQuestionAnswerer::CreateBertQuestionAnswerer(
      const std::string& path_to_model, // model to be passed to TaskApiFactory
      const std::string& path_to_vocab  // additional model specific files
  ) {
    // Creates an API object by calling one of the utils from TaskAPIFactory
    std::unique_ptr<BertQuestionAnswerer> api_to_init;
    ASSIGN_OR_RETURN(
        api_to_init,
        core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
            path_to_model,
            absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
            kNumLiteThreads));

    // Perform additional model specific initializations
    // In this case building a vocabulary vector from the vocab file.
    api_to_init->InitializeVocab(path_to_vocab);
    return api_to_init;
  }
}
```
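The factory above depends on TaskAPIFactory and Abseil's ASSIGN_OR_RETURN. The underlying control flow can be sketched without those dependencies: create the object, run any model-specific initialization, and surface failure before handing the instance out. ToyAnswerer and CreateToyAnswerer are hypothetical names, and returning nullptr stands in for the real code's absl::StatusOr error path:

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical task class with a post-construction initialization step,
// mirroring BertQuestionAnswerer::InitializeVocab.
class ToyAnswerer {
 public:
  bool InitializeVocab(const std::string& vocab_path) {
    vocab_loaded_ = !vocab_path.empty();
    return vocab_loaded_;
  }
  bool ready() const { return vocab_loaded_; }

 private:
  bool vocab_loaded_ = false;
};

// Sketch of the factory-function pattern: create, initialize, and return
// nullptr on failure (the real code returns absl::StatusOr instead).
std::unique_ptr<ToyAnswerer> CreateToyAnswerer(const std::string& model_path,
                                               const std::string& vocab_path) {
  if (model_path.empty()) return nullptr;  // stands in for a model-load error
  auto api = std::make_unique<ToyAnswerer>();
  if (!api->InitializeVocab(vocab_path)) return nullptr;
  return api;
}
```

Keeping initialization inside the factory means a caller can never observe a half-constructed API object, which is why the Task Library routes all construction through factory functions rather than public constructors.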
Android API
Create Android APIs by defining a Java/Kotlin interface and delegating the logic to the C++ layer through JNI. The Android API requires the native API to be built first.
Sample usage
Here is an example of using the Java BertQuestionAnswerer for MobileBert:
```java
String BERT_MODEL_FILE = "path/to/model.tflite";
String VOCAB_FILE = "path/to/vocab.txt";
// Create the API from a model file and vocabulary file
BertQuestionAnswerer bertQuestionAnswerer =
    BertQuestionAnswerer.createBertQuestionAnswerer(
        ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);

String CONTEXT = ...; // context of a question to be answered
String QUESTION = ...; // question to be answered
// ask a question
List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
// answers.get(0).text is the best answer
```
Building the API
Similar to the native APIs, to build an API object, the client needs to provide the following information by extending BaseTaskApi, which provides JNI handlings for all Java Task APIs.
Determine the API I/O: This usually mirrors the native interfaces. For example, BertQuestionAnswerer takes (String context, String question) as input and outputs List<QaAnswer>. The implementation calls a private native function with a similar signature, except that it has an extra parameter long nativeHandle, which is the pointer returned from C++:

```java
class BertQuestionAnswerer extends BaseTaskApi {
  public List<QaAnswer> answer(String context, String question) {
    return answerNative(getNativeHandle(), context, question);
  }

  private static native List<QaAnswer> answerNative(
      long nativeHandle, // C++ pointer
      String context, String question // API I/O
  );
}
```
Create factory functions of the API: This also mirrors the native factory functions, except that Android factory functions additionally need to take a Context for file access. The implementation calls one of the utilities in TaskJniUtils to build the corresponding C++ API object and pass its pointer to the BaseTaskApi constructor:

```java
class BertQuestionAnswerer extends BaseTaskApi {
  private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME =
      "bert_question_answerer_jni";

  // Extending super constructor by providing the
  // native handle(pointer of corresponding C++ API object)
  private BertQuestionAnswerer(long nativeHandle) {
    super(nativeHandle);
  }

  public static BertQuestionAnswerer createBertQuestionAnswerer(
      Context context, // Accessing Android files
      String pathToModel, String pathToVocab) {
    return new BertQuestionAnswerer(
        // The util first tries to load the JNI module with name
        // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files,
        // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers
        // is called with the buffer for a C++ API object pointer
        TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
            context,
            BertQuestionAnswerer::initJniWithBertByteBuffers,
            BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
            pathToModel,
            pathToVocab));
  }

  // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
  // returns C++ API object pointer casted to long
  private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
}
```
Implement the JNI module for the native functions: All Java native methods are implemented by calling a corresponding native function from the JNI module. The factory functions create a native API object and return its pointer as a long type to Java. In later calls to the Java API, the long type pointer is passed back to JNI and cast back to the native API object. The native API results are then converted back to Java results.
For example, this is how bert_question_answerer_jni is implemented:
```cpp
// Implements BertQuestionAnswerer::initJniWithBertByteBuffers
extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
    JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
  // Convert Java ByteBuffer object into a buffer that can be read by native factory functions
  absl::string_view model =
      GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));

  // Creates the native API object
  absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
      BertQuestionAnswerer::CreateFromBuffer(model.data(), model.size());
  if (status.ok()) {
    // converts the object pointer to jlong and return to Java.
    return reinterpret_cast<jlong>(status->release());
  } else {
    return kInvalidPointer;
  }
}

// Implements BertQuestionAnswerer::answerNative
extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
    JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) {
  // Convert long to native API object pointer
  QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);

  // Calls the native API
  std::vector<QaAnswer> results =
      question_answerer->Answer(JStringToString(env, context),
                                JStringToString(env, question));

  // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswer>)
  jclass qa_answer_class =
      env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
  jmethodID qa_answer_ctor =
      env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
  return ConvertVectorToArrayList<QaAnswer>(
      env, results,
      [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
        jstring text = env->NewStringUTF(ans.text.data());
        jobject qa_answer =
            env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
                           ans.pos.end, ans.pos.logit);
        env->DeleteLocalRef(text);
        return qa_answer;
      });
}

// Implements BaseTaskApi::deinitJni by deleting the native object
extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni(
    JNIEnv* env, jobject thiz, jlong native_handle) {
  delete reinterpret_cast<QuestionAnswerer*>(native_handle);
}
```
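The native-handle mechanism used throughout these JNI functions boils down to casting an owning pointer to a 64-bit integer and back. A minimal sketch of that round trip, using int64_t in place of jlong so it runs without JNI (NativeObject and the three functions are hypothetical names):

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Stand-in for a native API object held behind a Java-side handle.
struct NativeObject {
  std::string name;
};

// Factory: returns the heap object's address as an integer handle,
// like initJniWithBertByteBuffers returning a jlong to Java.
int64_t CreateHandle(const std::string& name) {
  return reinterpret_cast<int64_t>(new NativeObject{name});
}

// Method call: casts the handle back to the object, like answerNative.
std::string QueryHandle(int64_t handle) {
  return reinterpret_cast<NativeObject*>(handle)->name;
}

// Cleanup: mirrors deinitJni deleting the native object.
void DestroyHandle(int64_t handle) {
  delete reinterpret_cast<NativeObject*>(handle);
}
```

Because Java never dereferences the handle itself, the C++ object's layout stays completely hidden from the Java layer; only the JNI module knows how to cast the integer back to a typed pointer.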
iOS API
Create iOS APIs by wrapping a native API object into an ObjC API object. The created API object can be used in either ObjC or Swift. The iOS API requires the native API to be built first.
Sample usage
Here is an example of using the ObjC TFLBertQuestionAnswerer for MobileBert in Swift:
```swift
static let mobileBertModelPath = "path/to/model.tflite";
// Create the API from a model file and vocabulary file
let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
    modelPath: mobileBertModelPath)

static let context = ...; // context of a question to be answered
static let question = ...; // question to be answered
// ask a question
let answers = mobileBertAnswerer.answer(
    context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
// answers[0].text is the best answer
```
Building the API
The iOS API is a simple ObjC wrapper on top of the native API. Build the API by following the steps below:
Define the ObjC wrapper: Define an ObjC class and delegate the implementations to the corresponding native API object. Note that the native dependencies can only appear in a .mm file, because Swift cannot interoperate with C++ directly.
- .h file
```objc
@interface TFLBertQuestionAnswerer : NSObject

// Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer
+ (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath
                                              vocabPath:(NSString*)vocabPath
    NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:));

// Delegate calls to the native BertQuestionAnswerer::Answer
- (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
                                   question:(NSString*)question
    NS_SWIFT_NAME(answer(context:question:));

@end
```
- .mm file
```objc
using BertQuestionAnswererCPP = ::tflite::task::text::BertQuestionAnswerer;

@implementation TFLBertQuestionAnswerer {
  // define an iVar for the native API object
  std::unique_ptr<BertQuestionAnswererCPP> _bertQuestionAnswerer;
}

// Initialize the native API object
+ (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath
                                              vocabPath:(NSString*)vocabPath {
  absl::StatusOr<std::unique_ptr<BertQuestionAnswererCPP>> cQuestionAnswerer =
      BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath),
                                                          MakeString(vocabPath));
  _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer");
  return [[TFLBertQuestionAnswerer alloc]
      initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())];
}

// Calls the native API and converts C++ results into ObjC results
- (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
                                   question:(NSString*)question {
  std::vector<QaAnswerCPP> results =
      _bertQuestionAnswerer->Answer(MakeString(context), MakeString(question));
  return [self arrayFromVector:results];
}

@end
```