Tạo Task API của riêng bạn

Thư viện tác vụ TensorFlow Lite cung cấp sẵn API C++, Android và iOS trên cùng một cơ sở hạ tầng trừu tượng TensorFlow. Bạn có thể mở rộng cơ sở hạ tầng Task API để xây dựng các API tuỳ chỉnh nếu mô hình của bạn không được thư viện Tác vụ hiện tại hỗ trợ.

Tổng quan

Cơ sở hạ tầng API tác vụ có cấu trúc hai lớp: lớp C++ dưới cùng đóng gói thời gian chạy TFLite và lớp Java/ObjC trên cùng giao tiếp với lớp C++ thông qua JNI hoặc trình bao bọc.

Việc triển khai toàn bộ logic của TensorFlow trong C++ sẽ giúp giảm thiểu chi phí, tối đa hoá hiệu suất suy luận và đơn giản hoá quy trình làm việc tổng thể trên các nền tảng.

Để tạo một lớp Tác vụ, hãy mở rộng BaseTaskApi để cung cấp logic chuyển đổi giữa giao diện mô hình TFLite và Task API giao diện, sau đó sử dụng các tiện ích Java/ObjC để tạo API tương ứng. Bằng ẩn tất cả thông tin chi tiết của TensorFlow, bạn có thể triển khai mô hình TFLite trong ứng dụng của mình mà không cần kiến thức về máy học.

TensorFlow Lite cung cấp một số API tạo sẵn cho hầu hết các API phổ biến Nhiệm vụ về Tầm nhìn và NLP. Bạn có thể tạo API của riêng bạn cho các công việc khác bằng cơ sở hạ tầng Task API.

prebuilt_task_apis
Hình 1. API Tác vụ được tạo sẵn

Tạo API của riêng bạn bằng Task API ở dưới

API C++

Tất cả thông tin chi tiết về TFLite đều được triển khai trong API C++. Tạo đối tượng API bằng cách bằng một trong các hàm factory và nhận kết quả của mô hình bằng cách gọi các hàm được xác định trong giao diện.

Ví dụ về cách sử dụng

Dưới đây là một ví dụ về cách sử dụng C++ BertQuestionAnswerer với MobileBert.

  char kBertModelPath[] = "path/to/model.tflite";
  // Create the API from a model file
  std::unique_ptr<BertQuestionAnswerer> question_answerer =
      BertQuestionAnswerer::CreateFromFile(kBertModelPath);

  char kContext[] = ...; // context of a question to be answered
  char kQuestion[] = ...; // question to be answered
  // ask a question
  std::vector<QaAnswer> answers = question_answerer.Answer(kContext, kQuestion);
  // answers[0].text is the best answer

Xây dựng API

native_task_api
Hình 2. API Tác vụ gốc

Để tạo đối tượng API,bạn phải cung cấp những thông tin sau bằng cách mở rộng BaseTaskApi

  • Xác định API I/O – API của bạn phải hiển thị dữ liệu đầu vào/đầu ra tương tự trên các nền tảng khác nhau. ví dụ: BertQuestionAnswerer nhận 2 chuỗi (std::string& context, std::string& question) làm dữ liệu đầu vào và đầu ra vectơ xác suất và đáp án có thể có là std::vector<QaAnswer>. Chiến dịch này được thực hiện bằng cách chỉ định các kiểu dữ liệu tương ứng trong tập lệnh BaseTaskApi thông số mẫu. Với thông số mẫu đã chỉ định, BaseTaskApi::Infer sẽ có loại đầu vào/đầu ra chính xác. Chức năng này có thể được các ứng dụng API gọi trực tiếp, nhưng bạn nên gói mã này vào bên trong một hàm cụ thể cho mô hình, trong trường hợp này là BertQuestionAnswerer::Answer.

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Model specific function delegating calls to BaseTaskApi::Infer
      std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) {
        return Infer(context, question).value();
      }
    }
    
  • Cung cấp logic chuyển đổi giữa API I/O và tensor đầu vào/đầu ra của mô hình – Với các kiểu đầu vào và đầu ra được chỉ định, các lớp con cũng cần phải triển khai các hàm đã nhập BaseTaskApi::PreprocessBaseTaskApi::Postprocess. Hai hàm này cung cấp đầu vàođầu ra từ TFLite FlatBuffer. Lớp con chịu trách nhiệm chỉ định các giá trị từ các tensor API I/O đến I/O. Xem toàn bộ quá trình triển khai ví dụ trong BertQuestionAnswerer.

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Convert API input into tensors
      absl::Status BertQuestionAnswerer::Preprocess(
        const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model
        const std::string& context, const std::string& query // InputType of the API
      ) {
        // Perform tokenization on input strings
        ...
        // Populate IDs, Masks and SegmentIDs to corresponding input tensors
        PopulateTensor(input_ids, input_tensors[0]);
        PopulateTensor(input_mask, input_tensors[1]);
        PopulateTensor(segment_ids, input_tensors[2]);
        return absl::OkStatus();
      }
    
      // Convert output tensors into API output
      StatusOr<std::vector<QaAnswer>> // OutputType
      BertQuestionAnswerer::Postprocess(
        const std::vector<const TfLiteTensor*>& output_tensors, // output tensors of the model
      ) {
        // Get start/end logits of prediction result from output tensors
        std::vector<float> end_logits;
        std::vector<float> start_logits;
        // output_tensors[0]: end_logits FLOAT[1, 384]
        PopulateVector(output_tensors[0], &end_logits);
        // output_tensors[1]: start_logits FLOAT[1, 384]
        PopulateVector(output_tensors[1], &start_logits);
        ...
        std::vector<QaAnswer::Pos> orig_results;
        // Look up the indices from vocabulary file and build results
        ...
        return orig_results;
      }
    }
    
  • Tạo các hàm factory của API – Tệp mô hình và OpResolver để khởi chạy tflite::Interpreter. TaskAPIFactory cung cấp các hàm hiệu dụng để tạo các thực thể BaseTaskApi.

    Bạn cũng phải cung cấp mọi tệp được liên kết với mô hình. ví dụ: BertQuestionAnswerer cũng có thể có một tệp bổ sung cho mã của trình tạo mã thông báo từ vựng.

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Factory function to create the API instance
      StatusOr<std::unique_ptr<QuestionAnswerer>>
      BertQuestionAnswerer::CreateBertQuestionAnswerer(
          const std::string& path_to_model, // model to passed to TaskApiFactory
          const std::string& path_to_vocab  // additional model specific files
      ) {
        // Creates an API object by calling one of the utils from TaskAPIFactory
        std::unique_ptr<BertQuestionAnswerer> api_to_init;
        ASSIGN_OR_RETURN(
            api_to_init,
            core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
                path_to_model,
                absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
                kNumLiteThreads));
    
        // Perform additional model specific initializations
        // In this case building a vocabulary vector from the vocab file.
        api_to_init->InitializeVocab(path_to_vocab);
        return api_to_init;
      }
    }
    

API Android

Tạo API Android bằng cách xác định giao diện Java/Kotlin và uỷ quyền logic sang lớp C++ thông qua JNI. Android API yêu cầu tạo API gốc trước tiên.

Ví dụ về cách sử dụng

Sau đây là một ví dụ về cách sử dụng Java BertQuestionAnswerer với MobileBert.

  String BERT_MODEL_FILE = "path/to/model.tflite";
  String VOCAB_FILE = "path/to/vocab.txt";
  // Create the API from a model file and vocabulary file
    BertQuestionAnswerer bertQuestionAnswerer =
        BertQuestionAnswerer.createBertQuestionAnswerer(
            ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);

  String CONTEXT = ...; // context of a question to be answered
  String QUESTION = ...; // question to be answered
  // ask a question
  List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
  // answers.get(0).text is the best answer

Xây dựng API

android_task_api
Hình 3. API Tác vụ Android

Tương tự như API gốc, để tạo đối tượng API, ứng dụng cần cung cấp bằng cách mở rộng BaseTaskApi! cung cấp khả năng xử lý JNI cho tất cả các API Tác vụ Java.

  • Xác định API I/O – Quá trình này thường phản chiếu các giao diện gốc. ví dụ: BertQuestionAnswerer lấy (String context, String question) làm dữ liệu đầu vào và xuất ra List<QaAnswer>. Quá trình triển khai gọi một gốc riêng tư hàm có chữ ký tương tự, ngoại trừ việc có thêm tham số long nativeHandle, là con trỏ được trả về từ C++.

    class BertQuestionAnswerer extends BaseTaskApi {
      public List<QaAnswer> answer(String context, String question) {
        return answerNative(getNativeHandle(), context, question);
      }
    
      private static native List<QaAnswer> answerNative(
                                            long nativeHandle, // C++ pointer
                                            String context, String question // API I/O
                                           );
    
    }
    
  • Tạo các hàm factory của API – Việc này cũng phản ánh factory gốc ngoại trừ các hàm factory của Android cũng cần phải lấy Context để truy cập tệp. Quá trình triển khai gọi một trong các tiện ích trong TaskJniUtils để tạo đối tượng API C++ tương ứng và truyền con trỏ của đối tượng đó đến phương thức Hàm khởi tạo BaseTaskApi.

      class BertQuestionAnswerer extends BaseTaskApi {
        private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME =
                                                  "bert_question_answerer_jni";
    
        // Extending super constructor by providing the
        // native handle(pointer of corresponding C++ API object)
        private BertQuestionAnswerer(long nativeHandle) {
          super(nativeHandle);
        }
    
        public static BertQuestionAnswerer createBertQuestionAnswerer(
                                            Context context, // Accessing Android files
                                            String pathToModel, String pathToVocab) {
          return new BertQuestionAnswerer(
              // The util first try loads the JNI module with name
              // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files,
              // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers
              // is called with the buffer for a C++ API object pointer
              TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
                  context,
                  BertQuestionAnswerer::initJniWithBertByteBuffers,
                  BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
                  pathToModel,
                  pathToVocab));
        }
    
        // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
        // returns C++ API object pointer casted to long
        private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
    
      }
    
  • Triển khai mô-đun JNI cho các hàm gốc – Tất cả các phương thức gốc Java được triển khai bằng cách gọi một hàm gốc tương ứng từ JNI . Các hàm factory sẽ tạo một đối tượng API gốc và trả về con trỏ dưới dạng kiểu dài đối với Java. Trong các lệnh gọi sau này đến API Java, đoạn mã dài con trỏ kiểu dữ liệu được chuyển lại cho JNI và truyền trở lại đối tượng API gốc. Sau đó, kết quả API gốc được chuyển đổi lại thành kết quả Java.

    Ví dụ: đây là cách bert_question_answerer_jni sẽ được triển khai.

      // Implements BertQuestionAnswerer::initJniWithBertByteBuffers
      extern "C" JNIEXPORT jlong JNICALL
      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
          JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
        // Convert Java ByteBuffer object into a buffer that can be read by native factory functions
        absl::string_view model =
            GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
    
        // Creates the native API object
        absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
            BertQuestionAnswerer::CreateFromBuffer(
                model.data(), model.size());
        if (status.ok()) {
          // converts the object pointer to jlong and return to Java.
          return reinterpret_cast<jlong>(status->release());
        } else {
          return kInvalidPointer;
        }
      }
    
      // Implements BertQuestionAnswerer::answerNative
      extern "C" JNIEXPORT jobject JNICALL
      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
      JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) {
      // Convert long to native API object pointer
      QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
    
      // Calls the native API
      std::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context),
                                             JStringToString(env, question));
    
      // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswerer>)
      jclass qa_answer_class =
        env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
      jmethodID qa_answer_ctor =
        env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
      return ConvertVectorToArrayList<QaAnswer>(
        env, results,
        [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
          jstring text = env->NewStringUTF(ans.text.data());
          jobject qa_answer =
              env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
                             ans.pos.end, ans.pos.logit);
          env->DeleteLocalRef(text);
          return qa_answer;
        });
      }
    
      // Implements BaseTaskApi::deinitJni by delete the native object
      extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni(
          JNIEnv* env, jobject thiz, jlong native_handle) {
        delete reinterpret_cast<QuestionAnswerer*>(native_handle);
      }
    

API iOS

Tạo API iOS bằng cách gói đối tượng API gốc vào đối tượng API ObjC. Chiến lược phát hành đĩa đơn tạo có thể dùng trong ObjC hoặc Swift. API iOS yêu cầu cần tạo trước tiên.

Ví dụ về cách sử dụng

Dưới đây là một ví dụ về cách sử dụng ObjC TFLBertQuestionAnswerer cho MobileBert trong Swift.

  static let mobileBertModelPath = "path/to/model.tflite";
  // Create the API from a model file and vocabulary file
  let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
      modelPath: mobileBertModelPath)

  static let context = ...; // context of a question to be answered
  static let question = ...; // question to be answered
  // ask a question
  let answers = mobileBertAnswerer.answer(
      context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
  // answers.[0].text is the best answer

Xây dựng API

ios_task_api
Hình 4. API tác vụ iOS

iOS API là một trình bao bọc ObjC đơn giản ở trên API gốc. Xây dựng API bằng cách bằng cách làm theo các bước dưới đây:

  • Xác định trình bao bọc ObjC – Xác định lớp ObjC và uỷ quyền triển khai cho đối tượng API gốc tương ứng. Ghi lại mã gốc các phần phụ thuộc chỉ có thể xuất hiện trong tệp .mm do Swift không thể tương tác với C++.

    • Tệp .h
      @interface TFLBertQuestionAnswerer : NSObject
    
      // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer
      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath
                                                    vocabPath:(NSString*)vocabPath
          NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:));
    
      // Delegate calls to the native BertQuestionAnswerer::Answer
      - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
                                         question:(NSString*)question
          NS_SWIFT_NAME(answer(context:question:));
    }
    
    • Tệp .mm
      using BertQuestionAnswererCPP = ::tflite::task::text::BertQuestionAnswerer;
    
      @implementation TFLBertQuestionAnswerer {
        // define an iVar for the native API object
        std::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer;
      }
    
      // Initialize the native API object
      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath
                                              vocabPath:(NSString *)vocabPath {
        absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer =
            BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath),
                                                                MakeString(vocabPath));
        _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer");
        return [[TFLBertQuestionAnswerer alloc]
            initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())];
      }
    
      // Calls the native API and converts C++ results into ObjC results
      - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question {
        std::vector<QaAnswerCPP> results =
          _bertQuestionAnswerwer->Answer(MakeString(context), MakeString(question));
        return [self arrayFromVector:results];
      }
    }