TensorFlow Lite Task Library には、ビルド済みの C++、Android、iOS の各 API を、Google Kubernetes Engine 説明します。Task API インフラストラクチャを拡張して、カスタマイズされた API を構築できる モデルが既存の Task ライブラリでサポートされていない場合は、
概要
Task API インフラストラクチャは 2 層構造(一番下の C++ レイヤ) TFLite ランタイムと最上位の Java/ObjC レイヤをカプセル化して JNI またはラッパーを介して C++ レイヤと通信します。
すべての TensorFlow ロジックを C++ のみに実装することで、コストを最小限に抑え、 プラットフォーム間のワークフロー全体の簡素化に役立ちます。
Task クラスを作成するには、 BaseTaskApi TFLite モデル インターフェースと Task API 間の変換ロジックを提供 Java や ObjC のユーティリティを使用して、対応する API を作成します。あり TensorFlow の詳細がすべて非表示になり、アプリに TFLite モデルをデプロイできます ML の知識がなくても簡単に処理できます
TensorFlow Lite は、一般的なユースケース向けにいくつかの事前構築済み API を提供 ビジョンと NLP タスク。独自の Task API インフラストラクチャを使用して、他のタスク用に独自の API を作成します。
Task API インフラを使用して独自の API を構築する
C++ API
TFLite の詳細はすべて C++ API で実装されています。API オブジェクトの作成 いずれかのファクトリ関数を使用し、関数を呼び出してモデルの結果を取得する 定義します。
使用例
これは C++ 言語による
BertQuestionAnswerer
MobileBert。
char kBertModelPath[] = "path/to/model.tflite";
// Create the API from a model file
std::unique_ptr<BertQuestionAnswerer> question_answerer =
BertQuestionAnswerer::CreateFromFile(kBertModelPath);
char kContext[] = ...; // context of a question to be answered
char kQuestion[] = ...; // question to be answered
// ask a question
std::vector<QaAnswer> answers = question_answerer.Answer(kContext, kQuestion);
// answers[0].text is the best answer
API のビルド
API オブジェクトを構築するには、
BaseTaskApi
API I/O を決定する - API は同様の入力/出力を公開する必要があります。 分析できます例:
BertQuestionAnswerer
は 2 つの文字列を取ります。 入力として(std::string& context, std::string& question)
を指定し、 可能性のある回答と確率のベクトルをstd::vector<QaAnswer>
とします。この これを行うには、BaseTaskApi
の テンプレート パラメータ。 テンプレート パラメータを指定すると、BaseTaskApi::Infer
関数の入出力の型が正しいことを確認します。この関数には、 API クライアントによって直接呼び出されますが、 モデル固有の関数。この場合はBertQuestionAnswerer::Answer
です。class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Model specific function delegating calls to BaseTaskApi::Infer std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) { return Infer(context, question).value(); } }
API I/O と BigQuery の入出力テンソルとの間の変換ロジックを model - 入力型と出力型を指定すると、サブクラスも 型付き関数を実装する
BaseTaskApi::Preprocess
およびBaseTaskApi::Postprocess
。 この 2 つの関数により、 入力 および 出力 TFLiteFlatBuffer
から。サブクラスは、 API I/O テンソルから I/O テンソルへ移行されます。実装の全体を見る 例:BertQuestionAnswerer
。class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Convert API input into tensors absl::Status BertQuestionAnswerer::Preprocess( const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model const std::string& context, const std::string& query // InputType of the API ) { // Perform tokenization on input strings ... // Populate IDs, Masks and SegmentIDs to corresponding input tensors PopulateTensor(input_ids, input_tensors[0]); PopulateTensor(input_mask, input_tensors[1]); PopulateTensor(segment_ids, input_tensors[2]); return absl::OkStatus(); } // Convert output tensors into API output StatusOr<std::vector<QaAnswer>> // OutputType BertQuestionAnswerer::Postprocess( const std::vector<const TfLiteTensor*>& output_tensors, // output tensors of the model ) { // Get start/end logits of prediction result from output tensors std::vector<float> end_logits; std::vector<float> start_logits; // output_tensors[0]: end_logits FLOAT[1, 384] PopulateVector(output_tensors[0], &end_logits); // output_tensors[1]: start_logits FLOAT[1, 384] PopulateVector(output_tensors[1], &start_logits); ... std::vector<QaAnswer::Pos> orig_results; // Look up the indices from vocabulary file and build results ... return orig_results; } }
API のファクトリ関数を作成する - モデルファイルと
OpResolver
初期化するためにtflite::Interpreter
。TaskAPIFactory
には、BaseTaskApi インスタンスを作成するユーティリティ関数が用意されています。モデルに関連付けられているファイルも指定する必要があります。例:
BertQuestionAnswerer
には、トークナイザ用の追加ファイルを指定することもできます。 学びます。class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Factory function to create the API instance StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateBertQuestionAnswerer( const std::string& path_to_model, // model to passed to TaskApiFactory const std::string& path_to_vocab // additional model specific files ) { // Creates an API object by calling one of the utils from TaskAPIFactory std::unique_ptr<BertQuestionAnswerer> api_to_init; ASSIGN_OR_RETURN( api_to_init, core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( path_to_model, absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), kNumLiteThreads)); // Perform additional model specific initializations // In this case building a vocabulary vector from the vocab file. api_to_init->InitializeVocab(path_to_vocab); return api_to_init; } }
Android API
Java/Kotlin インターフェースを定義し、ロジックを委任して Android API を作成する C++ レイヤに渡します。Android API では、最初にネイティブ API をビルドする必要があります。
使用例
これは Java を使用したサンプルであり、
BertQuestionAnswerer
MobileBert。
String BERT_MODEL_FILE = "path/to/model.tflite";
String VOCAB_FILE = "path/to/vocab.txt";
// Create the API from a model file and vocabulary file
BertQuestionAnswerer bertQuestionAnswerer =
BertQuestionAnswerer.createBertQuestionAnswerer(
ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);
String CONTEXT = ...; // context of a question to be answered
String QUESTION = ...; // question to be answered
// ask a question
List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
// answers.get(0).text is the best answer
API のビルド
ネイティブ API と同様に、API オブジェクトを構築するには、クライアントが
次のような情報を
BaseTaskApi
すべての Java Task API に対する JNI 処理を提供します。
API I/O を決定する - これは通常、ネイティブ インターフェースを反映します。例:
BertQuestionAnswerer
は入力として(String context, String question)
を受け取ります。List<QaAnswer>
を出力します。この実装では、プライベート ネイティブ 同様のシグネチャを持つ関数ですが、追加のパラメータlong nativeHandle
(C++ から返されるポインタ)があります。class BertQuestionAnswerer extends BaseTaskApi { public List<QaAnswer> answer(String context, String question) { return answerNative(getNativeHandle(), context, question); } private static native List<QaAnswer> answerNative( long nativeHandle, // C++ pointer String context, String question // API I/O ); }
API のファクトリ関数を作成する - ネイティブ ファクトリもミラーリングされます。 例外として、Android ファクトリ関数も例外を渡して
Context
必要ありません。この実装では、Terraform 内のユーティリティの 1 つがTaskJniUtils
対応する C++ API オブジェクトをビルドし、そのポインタをBaseTaskApi
コンストラクタ。class BertQuestionAnswerer extends BaseTaskApi { private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "bert_question_answerer_jni"; // Extending super constructor by providing the // native handle(pointer of corresponding C++ API object) private BertQuestionAnswerer(long nativeHandle) { super(nativeHandle); } public static BertQuestionAnswerer createBertQuestionAnswerer( Context context, // Accessing Android files String pathToModel, String pathToVocab) { return new BertQuestionAnswerer( // The util first try loads the JNI module with name // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files, // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers // is called with the buffer for a C++ API object pointer TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( context, BertQuestionAnswerer::initJniWithBertByteBuffers, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, pathToModel, pathToVocab)); } // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer. // returns C++ API object pointer casted to long private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers); }
ネイティブ関数用の JNI モジュールの実装 - すべての Java ネイティブ メソッド 対応するネイティブ関数を JNI から呼び出すことで実装する。 説明します。ファクトリ関数はネイティブ API オブジェクトを作成し、 そのポインタを Java への long 型として扱うよう指示します。その後の Java API の呼び出しでは、 型ポインタが JNI に戻され、ネイティブ API オブジェクトにキャストバックされます。 その後、ネイティブ API の結果は Java の結果に戻されます。
たとえば、これは bert_question_answerer_jni 実装されます。
// Implements BertQuestionAnswerer::initJniWithBertByteBuffers extern "C" JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers( JNIEnv* env, jclass thiz, jobjectArray model_buffers) { // Convert Java ByteBuffer object into a buffer that can be read by native factory functions absl::string_view model = GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0)); // Creates the native API object absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status = BertQuestionAnswerer::CreateFromBuffer( model.data(), model.size()); if (status.ok()) { // converts the object pointer to jlong and return to Java. return reinterpret_cast<jlong>(status->release()); } else { return kInvalidPointer; } } // Implements BertQuestionAnswerer::answerNative extern "C" JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative( JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) { // Convert long to native API object pointer QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle); // Calls the native API std::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context), JStringToString(env, question)); // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswerer>) jclass qa_answer_class = env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer"); jmethodID qa_answer_ctor = env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V"); return ConvertVectorToArrayList<QaAnswer>( env, results, [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) { jstring text = env->NewStringUTF(ans.text.data()); jobject qa_answer = env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start, ans.pos.end, ans.pos.logit); env->DeleteLocalRef(text); return qa_answer; }); } // Implements BaseTaskApi::deinitJni by delete the native object extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni( JNIEnv* env, jobject thiz, jlong native_handle) { delete reinterpret_cast<QuestionAnswerer*>(native_handle); }
iOS API
ネイティブ API オブジェクトを ObjC API オブジェクトにラップして、iOS API を作成します。「 作成した API オブジェクトは ObjC でも Swift でもかまいません。iOS API では、 ネイティブ API を作成します。
使用例
ObjC を使用した例を次に示します。
TFLBertQuestionAnswerer
MobileBert
使用します。
static let mobileBertModelPath = "path/to/model.tflite";
// Create the API from a model file and vocabulary file
let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
modelPath: mobileBertModelPath)
static let context = ...; // context of a question to be answered
static let question = ...; // question to be answered
// ask a question
let answers = mobileBertAnswerer.answer(
context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
// answers.[0].text is the best answer
API のビルド
iOS API は、ネイティブ API のシンプルな ObjC ラッパーです。API のビルド方法 手順は次のとおりです。
ObjC ラッパーを定義する - ObjC クラスを定義して 対応するネイティブ API オブジェクトに追加します。ネイティブ Swift では依存関係が .mm ファイルに含まれないため、 相互運用性を確保できます。
- .h ファイル
@interface TFLBertQuestionAnswerer : NSObject // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath vocabPath:(NSString*)vocabPath NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:)); // Delegate calls to the native BertQuestionAnswerer::Answer - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context question:(NSString*)question NS_SWIFT_NAME(answer(context:question:)); }
- .mm ファイル
using BertQuestionAnswererCPP = ::tflite::task::text::BertQuestionAnswerer; @implementation TFLBertQuestionAnswerer { // define an iVar for the native API object std::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer; } // Initialize the native API object + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath vocabPath:(NSString *)vocabPath { absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer = BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath), MakeString(vocabPath)); _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer"); return [[TFLBertQuestionAnswerer alloc] initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())]; } // Calls the native API and converts C++ results into ObjC results - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question { std::vector<QaAnswerCPP> results = _bertQuestionAnswerwer->Answer(MakeString(context), MakeString(question)); return [self arrayFromVector:results]; } }