独自の Task API を作成する

TensorFlow Lite Task Library には、事前構築済みの C++、Android、iOS の各 API を、インフラストラクチャを抽象化 説明します。Task API インフラストラクチャを拡張して、カスタマイズされた API を構築できる モデルが既存の Task ライブラリでサポートされていない場合は、

概要

Task API インフラストラクチャは 2 層構造(一番下の C++ レイヤ) TFLite ランタイムと最上位の Java/ObjC レイヤをカプセル化して JNI またはラッパーを介して C++ レイヤと通信します。

すべての TensorFlow ロジックを C++ のみに実装することで、コストを最小限に抑え、 プラットフォーム間のワークフロー全体の簡素化に役立ちます。

Task クラスを作成するには、 BaseTaskApi TFLite モデル インターフェースと Task API 間の変換ロジックを提供 Java や ObjC のユーティリティを使用して、対応する API を作成します。あり TensorFlow の詳細がすべて非表示になり、アプリに TFLite モデルをデプロイできます ML の知識がなくても簡単に処理できます

TensorFlow Lite は、一般的なユースケース向けにいくつかの事前構築済み API を提供 ビジョンと NLP タスク。独自の Task API インフラストラクチャを使用して、他のタスク用に独自の API を作成します。

prebuilt_task_apis
図 1. 事前構築済みのタスク API

Task API インフラを使用して独自の API を構築する

C++ API

TFLite の詳細はすべて C++ API で実装されています。API オブジェクトの作成 いずれかのファクトリ関数を使用し、関数を呼び出してモデルの結果を取得する 定義します。

使用例

これは C++ 言語による BertQuestionAnswerer MobileBert

  char kBertModelPath[] = "path/to/model.tflite";
  // Create the API from a model file
  std::unique_ptr<BertQuestionAnswerer> question_answerer =
      BertQuestionAnswerer::CreateFromFile(kBertModelPath);

  char kContext[] = ...; // context of a question to be answered
  char kQuestion[] = ...; // question to be answered
  // ask a question
  std::vector<QaAnswer> answers = question_answerer.Answer(kContext, kQuestion);
  // answers[0].text is the best answer

API のビルド

native_task_api
図 2. ネイティブ タスク API

API オブジェクトを構築するには、 BaseTaskApi

  • API I/O を決定する - API は同様の入力/出力を公開する必要があります。 分析できます例:BertQuestionAnswerer は 2 つの文字列を取ります。 入力として (std::string& context, std::string& question) を指定し、 可能性のある回答と確率のベクトルを std::vector<QaAnswer> とします。この これを行うには、BaseTaskApiテンプレート パラメータ。 テンプレート パラメータを指定すると、 BaseTaskApi::Infer 関数の入出力の型が正しいことを確認します。この関数には、 API クライアントによって直接呼び出されますが、 モデル固有の関数。この場合は BertQuestionAnswerer::Answer です。

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Model specific function delegating calls to BaseTaskApi::Infer
      std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) {
        return Infer(context, question).value();
      }
    }
    
  • API I/O と BigQuery の入出力テンソルとの間の変換ロジックを model - 入力型と出力型を指定すると、サブクラスも 型付き関数を実装する BaseTaskApi::Preprocess および BaseTaskApi::Postprocess。 この 2 つの関数により、 入力 および 出力 TFLite FlatBuffer から。サブクラスは、 API I/O テンソルから I/O テンソル。実装の全体を見る 例: BertQuestionAnswerer

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Convert API input into tensors
      absl::Status BertQuestionAnswerer::Preprocess(
        const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model
        const std::string& context, const std::string& query // InputType of the API
      ) {
        // Perform tokenization on input strings
        ...
        // Populate IDs, Masks and SegmentIDs to corresponding input tensors
        PopulateTensor(input_ids, input_tensors[0]);
        PopulateTensor(input_mask, input_tensors[1]);
        PopulateTensor(segment_ids, input_tensors[2]);
        return absl::OkStatus();
      }
    
      // Convert output tensors into API output
      StatusOr<std::vector<QaAnswer>> // OutputType
      BertQuestionAnswerer::Postprocess(
        const std::vector<const TfLiteTensor*>& output_tensors, // output tensors of the model
      ) {
        // Get start/end logits of prediction result from output tensors
        std::vector<float> end_logits;
        std::vector<float> start_logits;
        // output_tensors[0]: end_logits FLOAT[1, 384]
        PopulateVector(output_tensors[0], &end_logits);
        // output_tensors[1]: start_logits FLOAT[1, 384]
        PopulateVector(output_tensors[1], &start_logits);
        ...
        std::vector<QaAnswer::Pos> orig_results;
        // Look up the indices from vocabulary file and build results
        ...
        return orig_results;
      }
    }
    
  • API のファクトリ関数を作成する - モデルファイルと OpResolver 初期化するために tflite::InterpreterTaskAPIFactory には、BaseTaskApi インスタンスを作成するユーティリティ関数が用意されています。

    モデルに関連付けられているファイルも指定する必要があります。例: BertQuestionAnswerer には、トークナイザ用の追加ファイルを指定することもできます。 学びます。

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Factory function to create the API instance
      StatusOr<std::unique_ptr<QuestionAnswerer>>
      BertQuestionAnswerer::CreateBertQuestionAnswerer(
          const std::string& path_to_model, // model to passed to TaskApiFactory
          const std::string& path_to_vocab  // additional model specific files
      ) {
        // Creates an API object by calling one of the utils from TaskAPIFactory
        std::unique_ptr<BertQuestionAnswerer> api_to_init;
        ASSIGN_OR_RETURN(
            api_to_init,
            core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
                path_to_model,
                absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
                kNumLiteThreads));
    
        // Perform additional model specific initializations
        // In this case building a vocabulary vector from the vocab file.
        api_to_init->InitializeVocab(path_to_vocab);
        return api_to_init;
      }
    }
    

Android API

Java/Kotlin インターフェースを定義し、ロジックを委任して Android API を作成する C++ レイヤに渡します。Android API では、最初にネイティブ API をビルドする必要があります。

使用例

これは Java を使用したサンプルであり、 BertQuestionAnswerer MobileBert

  String BERT_MODEL_FILE = "path/to/model.tflite";
  String VOCAB_FILE = "path/to/vocab.txt";
  // Create the API from a model file and vocabulary file
    BertQuestionAnswerer bertQuestionAnswerer =
        BertQuestionAnswerer.createBertQuestionAnswerer(
            ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);

  String CONTEXT = ...; // context of a question to be answered
  String QUESTION = ...; // question to be answered
  // ask a question
  List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
  // answers.get(0).text is the best answer

API のビルド

android_task_api
図 3. Android Task API

ネイティブ API と同様に、API オブジェクトを構築するには、クライアントが 次のような情報を BaseTaskApi すべての Java Task API に対する JNI 処理を提供します。

  • API I/O を決定する - これは通常、ネイティブ インターフェースを反映します。例: BertQuestionAnswerer は入力として (String context, String question) を受け取ります。 List<QaAnswer> を出力します。この実装では、プライベート ネイティブ 同様のシグネチャを持つ関数ですが、追加のパラメータ long nativeHandle(C++ から返されるポインタ)があります。

    class BertQuestionAnswerer extends BaseTaskApi {
      public List<QaAnswer> answer(String context, String question) {
        return answerNative(getNativeHandle(), context, question);
      }
    
      private static native List<QaAnswer> answerNative(
                                            long nativeHandle, // C++ pointer
                                            String context, String question // API I/O
                                           );
    
    }
    
  • API のファクトリ関数を作成する - ネイティブ ファクトリもミラーリングされます。 例外として、Android ファクトリ関数も例外を渡して Context 必要ありません。この実装では、Terraform 内のユーティリティの 1 つが TaskJniUtils 対応する C++ API オブジェクトをビルドし、そのポインタを BaseTaskApi コンストラクタ。

      class BertQuestionAnswerer extends BaseTaskApi {
        private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME =
                                                  "bert_question_answerer_jni";
    
        // Extending super constructor by providing the
        // native handle(pointer of corresponding C++ API object)
        private BertQuestionAnswerer(long nativeHandle) {
          super(nativeHandle);
        }
    
        public static BertQuestionAnswerer createBertQuestionAnswerer(
                                            Context context, // Accessing Android files
                                            String pathToModel, String pathToVocab) {
          return new BertQuestionAnswerer(
              // The util first try loads the JNI module with name
              // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files,
              // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers
              // is called with the buffer for a C++ API object pointer
              TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
                  context,
                  BertQuestionAnswerer::initJniWithBertByteBuffers,
                  BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
                  pathToModel,
                  pathToVocab));
        }
    
        // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
        // returns C++ API object pointer casted to long
        private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
    
      }
    
  • ネイティブ関数用の JNI モジュールの実装 - すべての Java ネイティブ メソッド 対応するネイティブ関数を JNI から呼び出すことで実装する。 説明します。ファクトリ関数はネイティブ API オブジェクトを作成し、 そのポインタを Java への long 型として扱うよう指示します。その後の Java API の呼び出しでは、 型ポインタが JNI に戻され、ネイティブ API オブジェクトにキャストバックされます。 その後、ネイティブ API の結果は Java の結果に戻されます。

    たとえば、これは bert_question_answerer_jni 実装されます。

      // Implements BertQuestionAnswerer::initJniWithBertByteBuffers
      extern "C" JNIEXPORT jlong JNICALL
      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
          JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
        // Convert Java ByteBuffer object into a buffer that can be read by native factory functions
        absl::string_view model =
            GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
    
        // Creates the native API object
        absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
            BertQuestionAnswerer::CreateFromBuffer(
                model.data(), model.size());
        if (status.ok()) {
          // converts the object pointer to jlong and return to Java.
          return reinterpret_cast<jlong>(status->release());
        } else {
          return kInvalidPointer;
        }
      }
    
      // Implements BertQuestionAnswerer::answerNative
      extern "C" JNIEXPORT jobject JNICALL
      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
      JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) {
      // Convert long to native API object pointer
      QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
    
      // Calls the native API
      std::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context),
                                             JStringToString(env, question));
    
      // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswerer>)
      jclass qa_answer_class =
        env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
      jmethodID qa_answer_ctor =
        env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
      return ConvertVectorToArrayList<QaAnswer>(
        env, results,
        [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
          jstring text = env->NewStringUTF(ans.text.data());
          jobject qa_answer =
              env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
                             ans.pos.end, ans.pos.logit);
          env->DeleteLocalRef(text);
          return qa_answer;
        });
      }
    
      // Implements BaseTaskApi::deinitJni by delete the native object
      extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni(
          JNIEnv* env, jobject thiz, jlong native_handle) {
        delete reinterpret_cast<QuestionAnswerer*>(native_handle);
      }
    

iOS API

ネイティブ API オブジェクトを ObjC API オブジェクトにラップして、iOS API を作成します。「 作成した API オブジェクトは ObjC でも Swift でもかまいません。iOS API では、 ネイティブ API を作成します。

使用例

ObjC を使用した例を次に示します。 TFLBertQuestionAnswerer MobileBert 使用します。

  static let mobileBertModelPath = "path/to/model.tflite";
  // Create the API from a model file and vocabulary file
  let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
      modelPath: mobileBertModelPath)

  static let context = ...; // context of a question to be answered
  static let question = ...; // question to be answered
  // ask a question
  let answers = mobileBertAnswerer.answer(
      context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
  // answers.[0].text is the best answer

API のビルド

ios_task_api
図 4. iOS Task API

iOS API は、ネイティブ API のシンプルな ObjC ラッパーです。API のビルド方法 手順は次のとおりです。

  • ObjC ラッパーを定義する - ObjC クラスを定義して 対応するネイティブ API オブジェクトに追加します。ネイティブ Swift では依存関係が .mm ファイルに含まれないため、 相互運用性を確保できます。

    • .h ファイル
      @interface TFLBertQuestionAnswerer : NSObject
    
      // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer
      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath
                                                    vocabPath:(NSString*)vocabPath
          NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:));
    
      // Delegate calls to the native BertQuestionAnswerer::Answer
      - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
                                         question:(NSString*)question
          NS_SWIFT_NAME(answer(context:question:));
    }
    
    • .mm ファイル
      using BertQuestionAnswererCPP = ::tflite::task::text::BertQuestionAnswerer;
    
      @implementation TFLBertQuestionAnswerer {
        // define an iVar for the native API object
        std::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer;
      }
    
      // Initialize the native API object
      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath
                                              vocabPath:(NSString *)vocabPath {
        absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer =
            BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath),
                                                                MakeString(vocabPath));
        _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer");
        return [[TFLBertQuestionAnswerer alloc]
            initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())];
      }
    
      // Calls the native API and converts C++ results into ObjC results
      - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question {
        std::vector<QaAnswerCPP> results =
          _bertQuestionAnswerwer->Answer(MakeString(context), MakeString(question));
        return [self arrayFromVector:results];
      }
    }