나만의 Task API 빌드

TensorFlow Lite 작업 라이브러리는 사전 빌드된 기능을 제공합니다. C++, Android 및 iOS API는 기본 인프라를 추상화하는 동일한 인프라를 기반으로 하며 TensorFlow를 사용할 수도 있습니다. Task API 인프라를 확장하여 맞춤설정된 API 빌드 가능 모델이 기존 작업 라이브러리에서 지원되지 않는 경우

개요

Task API 인프라의 두 계층 구조는 하단 C++ 레이어 TFLite 런타임과 최상위 Java/ObjC 레이어를 캡슐화하는 것은 JNI 또는 래퍼를 통해 C++ 레이어와 통신합니다.

C++로만 모든 TensorFlow 로직을 구현하면 비용이 최소화되고 추론 성능을 향상시키고 플랫폼 전반의 전반적인 워크플로를 간소화합니다.

Task 클래스를 만들려면 BaseTaskApi TFLite 모델 인터페이스와 Task API 간의 변환 로직을 제공합니다. 인터페이스를 설치한 다음 Java/ObjC 유틸리티를 사용하여 해당 API를 만듭니다. 다음으로 바꿉니다. 모든 TensorFlow 세부정보를 볼 수 있으므로 앱에 TFLite 모델을 배포할 수 있습니다. 머신러닝 지식이 없어도 될 수 있습니다

TensorFlow Lite는 가장 널리 사용되는 몇 가지 사전 빌드된 API를 비전 및 NLP 작업. Cloud Build를 사용하여 Task API 인프라를 사용하여 다른 작업에 자체 API를 제공할 수 있습니다.

prebuilt_task_apis
그림 1. 사전 빌드된 Task API

Task API 인프라를 사용하여 자체 API 빌드

C++ API

모든 TFLite 세부정보는 C++ API로 구현됩니다. 다음 방법으로 API 객체 생성 팩토리 함수 중 하나를 사용하고 함수를 호출하여 모델 결과를 얻습니다. 사용됩니다.

사용 예시

다음은 C++ 함수를 사용한 예입니다. BertQuestionAnswerer 드림 대상: MobileBert됩니다.

  char kBertModelPath[] = "path/to/model.tflite";
  // Create the API from a model file
  std::unique_ptr<BertQuestionAnswerer> question_answerer =
      BertQuestionAnswerer::CreateFromFile(kBertModelPath);

  char kContext[] = ...; // context of a question to be answered
  char kQuestion[] = ...; // question to be answered
  // ask a question
  std::vector<QaAnswer> answers = question_answerer.Answer(kContext, kQuestion);
  // answers[0].text is the best answer

API 빌드

native_task_api
그림 2. 네이티브 태스크 API

API 객체를 빌드하려면 BaseTaskApi 드림

  • API I/O 결정 - API가 유사한 입력/출력을 노출해야 합니다. 확인할 수 있습니다 예: BertQuestionAnswerer는 두 문자열을 사용합니다. (std::string& context, std::string& question)를 입력으로 사용하고 출력은 std::vector<QaAnswer>인 가능한 답변과 확률의 벡터입니다. 이 BaseTaskApi템플릿 매개변수입니다. 템플릿 매개변수를 지정하면 BaseTaskApi::Infer 드림 함수가 올바른 입력/출력 유형을 갖게 됩니다. 이 함수는 API 클라이언트에서 직접 호출하기는 하지만, 내부에 이를 래핑하는 것이 좋습니다. 모델별 함수(이 경우 BertQuestionAnswerer::Answer)

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Model specific function delegating calls to BaseTaskApi::Infer
      std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) {
        return Infer(context, question).value();
      }
    }
    
  • API I/O와 API I/O의 입력/출력 텐서 간에 모델 - 입력 및 출력 유형이 지정되면 서브클래스도 유형이 지정된 함수를 구현하고 BaseTaskApi::PreprocessBaseTaskApi::Postprocess 이 두 가지 함수는 입력출력 TFLite FlatBuffer에서 가져옴 서브클래스는 I/O 텐서로 변환될 수 있습니다. 전체 구현 보기 예시: BertQuestionAnswerer

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Convert API input into tensors
      absl::Status BertQuestionAnswerer::Preprocess(
        const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model
        const std::string& context, const std::string& query // InputType of the API
      ) {
        // Perform tokenization on input strings
        ...
        // Populate IDs, Masks and SegmentIDs to corresponding input tensors
        PopulateTensor(input_ids, input_tensors[0]);
        PopulateTensor(input_mask, input_tensors[1]);
        PopulateTensor(segment_ids, input_tensors[2]);
        return absl::OkStatus();
      }
    
      // Convert output tensors into API output
      StatusOr<std::vector<QaAnswer>> // OutputType
      BertQuestionAnswerer::Postprocess(
        const std::vector<const TfLiteTensor*>& output_tensors, // output tensors of the model
      ) {
        // Get start/end logits of prediction result from output tensors
        std::vector<float> end_logits;
        std::vector<float> start_logits;
        // output_tensors[0]: end_logits FLOAT[1, 384]
        PopulateVector(output_tensors[0], &end_logits);
        // output_tensors[1]: start_logits FLOAT[1, 384]
        PopulateVector(output_tensors[1], &start_logits);
        ...
        std::vector<QaAnswer::Pos> orig_results;
        // Look up the indices from vocabulary file and build results
        ...
        return orig_results;
      }
    }
    
  • API의 팩토리 함수 만들기 - 모델 파일과 OpResolver 초기화하려면 tflite::Interpreter TaskAPIFactory 드림 BaseTaskApi 인스턴스를 생성하는 유틸리티 함수를 제공합니다.

    모델과 관련된 파일도 제공해야 합니다. 예: BertQuestionAnswerer에는 tokenizer의 얻을 수 있습니다.

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Factory function to create the API instance
      StatusOr<std::unique_ptr<QuestionAnswerer>>
      BertQuestionAnswerer::CreateBertQuestionAnswerer(
          const std::string& path_to_model, // model to passed to TaskApiFactory
          const std::string& path_to_vocab  // additional model specific files
      ) {
        // Creates an API object by calling one of the utils from TaskAPIFactory
        std::unique_ptr<BertQuestionAnswerer> api_to_init;
        ASSIGN_OR_RETURN(
            api_to_init,
            core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
                path_to_model,
                absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
                kNumLiteThreads));
    
        // Perform additional model specific initializations
        // In this case building a vocabulary vector from the vocab file.
        api_to_init->InitializeVocab(path_to_vocab);
        return api_to_init;
      }
    }
    

Android API

Java/Kotlin 인터페이스를 정의하고 로직을 위임하여 Android API 만들기 JNI를 통해 C++ 레이어에 전달됩니다. Android API를 사용하려면 먼저 네이티브 API를 빌드해야 합니다.

사용 예시

다음은 Java를 사용하는 예입니다. BertQuestionAnswerer 드림 대상: MobileBert됩니다.

  String BERT_MODEL_FILE = "path/to/model.tflite";
  String VOCAB_FILE = "path/to/vocab.txt";
  // Create the API from a model file and vocabulary file
    BertQuestionAnswerer bertQuestionAnswerer =
        BertQuestionAnswerer.createBertQuestionAnswerer(
            ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);

  String CONTEXT = ...; // context of a question to be answered
  String QUESTION = ...; // question to be answered
  // ask a question
  List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
  // answers.get(0).text is the best answer

API 빌드

android_task_api
그림 3. Android 작업 API

네이티브 API와 마찬가지로 API 객체를 빌드하려면 클라이언트가 다음 정보를 BaseTaskApi님, 모든 Java Task API에 JNI 처리를 제공합니다.

  • API I/O 결정 - 일반적으로 네이티브 인터페이스를 미러링합니다. 예: BertQuestionAnswerer(String context, String question)를 입력으로 사용합니다. 그리고 List<QaAnswer>를 출력합니다. 구현은 비공개 네이티브를 호출합니다. 함수를 사용할 수 있습니다. 단, C++에서 반환된 포인터인 추가 매개변수 long nativeHandle가 있다는 점이 다릅니다.

    class BertQuestionAnswerer extends BaseTaskApi {
      public List<QaAnswer> answer(String context, String question) {
        return answerNative(getNativeHandle(), context, question);
      }
    
      private static native List<QaAnswer> answerNative(
                                            long nativeHandle, // C++ pointer
                                            String context, String question // API I/O
                                           );
    
    }
    
  • API의 팩토리 함수 만들기 - 네이티브 팩토리도 미러링합니다. 함수도 포함되어야 하며, Android 팩토리 함수도 Context 를 참조하세요. 이 구현은 TaskJniUtils 드림 를 사용하여 해당 C++ API 객체를 빌드하고 해당 포인터를 BaseTaskApi 생성자.

      class BertQuestionAnswerer extends BaseTaskApi {
        private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME =
                                                  "bert_question_answerer_jni";
    
        // Extending super constructor by providing the
        // native handle(pointer of corresponding C++ API object)
        private BertQuestionAnswerer(long nativeHandle) {
          super(nativeHandle);
        }
    
        public static BertQuestionAnswerer createBertQuestionAnswerer(
                                            Context context, // Accessing Android files
                                            String pathToModel, String pathToVocab) {
          return new BertQuestionAnswerer(
              // The util first try loads the JNI module with name
              // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files,
              // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers
              // is called with the buffer for a C++ API object pointer
              TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
                  context,
                  BertQuestionAnswerer::initJniWithBertByteBuffers,
                  BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
                  pathToModel,
                  pathToVocab));
        }
    
        // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
        // returns C++ API object pointer casted to long
        private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
    
      }
    
  • 네이티브 함수를 위한 JNI 모듈 구현 - 모든 자바 네이티브 메서드 JNI에서 해당하는 네이티브 함수를 호출하여 모듈을 마칩니다 팩토리 함수는 네이티브 API 객체를 만들고 포인터를 Java에 대한 긴 유형이라고 할 수 있습니다. 이후 Java API 호출에서 long 유형 포인터가 다시 JNI로 전달되고 네이티브 API 객체로 다시 캐스팅됩니다. 그런 다음 네이티브 API 결과가 Java 결과로 다시 변환됩니다.

    예를 들어, bert_question_answerer_jni 구현됩니다

      // Implements BertQuestionAnswerer::initJniWithBertByteBuffers
      extern "C" JNIEXPORT jlong JNICALL
      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
          JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
        // Convert Java ByteBuffer object into a buffer that can be read by native factory functions
        absl::string_view model =
            GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
    
        // Creates the native API object
        absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
            BertQuestionAnswerer::CreateFromBuffer(
                model.data(), model.size());
        if (status.ok()) {
          // converts the object pointer to jlong and return to Java.
          return reinterpret_cast<jlong>(status->release());
        } else {
          return kInvalidPointer;
        }
      }
    
      // Implements BertQuestionAnswerer::answerNative
      extern "C" JNIEXPORT jobject JNICALL
      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
      JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) {
      // Convert long to native API object pointer
      QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
    
      // Calls the native API
      std::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context),
                                             JStringToString(env, question));
    
      // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswerer>)
      jclass qa_answer_class =
        env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
      jmethodID qa_answer_ctor =
        env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
      return ConvertVectorToArrayList<QaAnswer>(
        env, results,
        [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
          jstring text = env->NewStringUTF(ans.text.data());
          jobject qa_answer =
              env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
                             ans.pos.end, ans.pos.logit);
          env->DeleteLocalRef(text);
          return qa_answer;
        });
      }
    
      // Implements BaseTaskApi::deinitJni by delete the native object
      extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni(
          JNIEnv* env, jobject thiz, jlong native_handle) {
        delete reinterpret_cast<QuestionAnswerer*>(native_handle);
      }
    

iOS API

네이티브 API 객체를 ObjC API 객체로 래핑하여 iOS API를 만듭니다. 이 생성된 API 객체를 ObjC 또는 Swift에서 사용할 수 있습니다. iOS API에는 네이티브 API를 먼저 빌드해야 합니다

사용 예시

이것은 ObjC를 사용한 예입니다. TFLBertQuestionAnswerer 드림 (MobileBert용) 있습니다.

  static let mobileBertModelPath = "path/to/model.tflite";
  // Create the API from a model file and vocabulary file
  let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
      modelPath: mobileBertModelPath)

  static let context = ...; // context of a question to be answered
  static let question = ...; // question to be answered
  // ask a question
  let answers = mobileBertAnswerer.answer(
      context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
  // answers.[0].text is the best answer

API 빌드

ios_task_api
그림 4. iOS Task API

iOS API는 네이티브 API 위에 있는 간단한 ObjC 래퍼입니다. API 빌드 방법 다음 단계를 따르세요.

  • ObjC 래퍼 정의: ObjC 클래스를 정의하고 이를 해당 네이티브 API 객체에 추가합니다. 네이티브 Swift에서 다음 기능을 지원하지 않기 때문에 종속 항목은 .mm 파일에만 나타날 수 있습니다. 상호 운용이 가능합니다.

    • .h 파일
      @interface TFLBertQuestionAnswerer : NSObject
    
      // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer
      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath
                                                    vocabPath:(NSString*)vocabPath
          NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:));
    
      // Delegate calls to the native BertQuestionAnswerer::Answer
      - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
                                         question:(NSString*)question
          NS_SWIFT_NAME(answer(context:question:));
    }
    
    • .mm 파일
      using BertQuestionAnswererCPP = ::tflite::task::text::BertQuestionAnswerer;
    
      @implementation TFLBertQuestionAnswerer {
        // define an iVar for the native API object
        std::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer;
      }
    
      // Initialize the native API object
      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath
                                              vocabPath:(NSString *)vocabPath {
        absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer =
            BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath),
                                                                MakeString(vocabPath));
        _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer");
        return [[TFLBertQuestionAnswerer alloc]
            initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())];
      }
    
      // Calls the native API and converts C++ results into ObjC results
      - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question {
        std::vector<QaAnswerCPP> results =
          _bertQuestionAnswerwer->Answer(MakeString(context), MakeString(question));
        return [self arrayFromVector:results];
      }
    }