Thư viện tác vụ TensorFlow Lite cung cấp sẵn API C++, Android và iOS trên cùng một cơ sở hạ tầng trừu tượng TensorFlow. Bạn có thể mở rộng cơ sở hạ tầng Task API để xây dựng các API tuỳ chỉnh nếu mô hình của bạn không được thư viện Tác vụ hiện tại hỗ trợ.
Tổng quan
Cơ sở hạ tầng API tác vụ có cấu trúc hai lớp: lớp C++ dưới cùng đóng gói thời gian chạy TFLite và lớp Java/ObjC trên cùng giao tiếp với lớp C++ thông qua JNI hoặc trình bao bọc.
Việc triển khai toàn bộ logic của TensorFlow trong C++ sẽ giúp giảm thiểu chi phí, tối đa hoá hiệu suất suy luận và đơn giản hoá quy trình làm việc tổng thể trên các nền tảng.
Để tạo một lớp Tác vụ, hãy mở rộng BaseTaskApi để cung cấp logic chuyển đổi giữa giao diện mô hình TFLite và Task API giao diện, sau đó sử dụng các tiện ích Java/ObjC để tạo API tương ứng. Bằng ẩn tất cả thông tin chi tiết của TensorFlow, bạn có thể triển khai mô hình TFLite trong ứng dụng của mình mà không cần kiến thức về máy học.
TensorFlow Lite cung cấp một số API tạo sẵn cho hầu hết các API phổ biến Nhiệm vụ về Tầm nhìn và NLP. Bạn có thể tạo API của riêng bạn cho các công việc khác bằng cơ sở hạ tầng Task API.
Tạo API của riêng bạn bằng Task API ở dưới
API C++
Tất cả thông tin chi tiết về TFLite đều được triển khai trong API C++. Tạo đối tượng API bằng cách bằng một trong các hàm factory và nhận kết quả của mô hình bằng cách gọi các hàm được xác định trong giao diện.
Ví dụ về cách sử dụng
Dưới đây là một ví dụ về cách sử dụng C++
BertQuestionAnswerer
với
MobileBert.
char kBertModelPath[] = "path/to/model.tflite";
// Create the API from a model file
std::unique_ptr<BertQuestionAnswerer> question_answerer =
BertQuestionAnswerer::CreateFromFile(kBertModelPath);
char kContext[] = ...; // context of a question to be answered
char kQuestion[] = ...; // question to be answered
// ask a question
std::vector<QaAnswer> answers = question_answerer.Answer(kContext, kQuestion);
// answers[0].text is the best answer
Xây dựng API
Để tạo đối tượng API,bạn phải cung cấp những thông tin sau bằng cách mở rộng
BaseTaskApi
Xác định API I/O – API của bạn phải hiển thị dữ liệu đầu vào/đầu ra tương tự trên các nền tảng khác nhau. ví dụ:
BertQuestionAnswerer
nhận 2 chuỗi(std::string& context, std::string& question)
làm dữ liệu đầu vào và đầu ra vectơ xác suất và đáp án có thể có làstd::vector<QaAnswer>
. Chiến dịch này được thực hiện bằng cách chỉ định các kiểu dữ liệu tương ứng trong tập lệnhBaseTaskApi
thông số mẫu. Với thông số mẫu đã chỉ định,BaseTaskApi::Infer
sẽ có loại đầu vào/đầu ra chính xác. Chức năng này có thể được các ứng dụng API gọi trực tiếp, nhưng bạn nên gói mã này vào bên trong một hàm cụ thể cho mô hình, trong trường hợp này làBertQuestionAnswerer::Answer
.class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Model specific function delegating calls to BaseTaskApi::Infer std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) { return Infer(context, question).value(); } }
Cung cấp logic chuyển đổi giữa API I/O và tensor đầu vào/đầu ra của mô hình – Với các kiểu đầu vào và đầu ra được chỉ định, các lớp con cũng cần phải triển khai các hàm đã nhập
BaseTaskApi::Preprocess
vàBaseTaskApi::Postprocess
. Hai hàm này cung cấp đầu vào và đầu ra từ TFLiteFlatBuffer
. Lớp con chịu trách nhiệm chỉ định các giá trị từ các tensor API I/O đến I/O. Xem toàn bộ quá trình triển khai ví dụ trongBertQuestionAnswerer
.class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Convert API input into tensors absl::Status BertQuestionAnswerer::Preprocess( const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model const std::string& context, const std::string& query // InputType of the API ) { // Perform tokenization on input strings ... // Populate IDs, Masks and SegmentIDs to corresponding input tensors PopulateTensor(input_ids, input_tensors[0]); PopulateTensor(input_mask, input_tensors[1]); PopulateTensor(segment_ids, input_tensors[2]); return absl::OkStatus(); } // Convert output tensors into API output StatusOr<std::vector<QaAnswer>> // OutputType BertQuestionAnswerer::Postprocess( const std::vector<const TfLiteTensor*>& output_tensors, // output tensors of the model ) { // Get start/end logits of prediction result from output tensors std::vector<float> end_logits; std::vector<float> start_logits; // output_tensors[0]: end_logits FLOAT[1, 384] PopulateVector(output_tensors[0], &end_logits); // output_tensors[1]: start_logits FLOAT[1, 384] PopulateVector(output_tensors[1], &start_logits); ... std::vector<QaAnswer::Pos> orig_results; // Look up the indices from vocabulary file and build results ... return orig_results; } }
Tạo các hàm factory của API – Tệp mô hình và
OpResolver
để khởi chạytflite::Interpreter
.TaskAPIFactory
cung cấp các hàm hiệu dụng để tạo các thực thể BaseTaskApi.Bạn cũng phải cung cấp mọi tệp được liên kết với mô hình. ví dụ:
BertQuestionAnswerer
cũng có thể có một tệp bổ sung cho mã của trình tạo mã thông báo từ vựng.class BertQuestionAnswerer : public BaseTaskApi< std::vector<QaAnswer>, // OutputType const std::string&, const std::string& // InputTypes > { // Factory function to create the API instance StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateBertQuestionAnswerer( const std::string& path_to_model, // model to passed to TaskApiFactory const std::string& path_to_vocab // additional model specific files ) { // Creates an API object by calling one of the utils from TaskAPIFactory std::unique_ptr<BertQuestionAnswerer> api_to_init; ASSIGN_OR_RETURN( api_to_init, core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( path_to_model, absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), kNumLiteThreads)); // Perform additional model specific initializations // In this case building a vocabulary vector from the vocab file. api_to_init->InitializeVocab(path_to_vocab); return api_to_init; } }
API Android
Tạo API Android bằng cách xác định giao diện Java/Kotlin và uỷ quyền logic sang lớp C++ thông qua JNI. Android API yêu cầu tạo API gốc trước tiên.
Ví dụ về cách sử dụng
Sau đây là một ví dụ về cách sử dụng Java
BertQuestionAnswerer
với
MobileBert.
String BERT_MODEL_FILE = "path/to/model.tflite";
String VOCAB_FILE = "path/to/vocab.txt";
// Create the API from a model file and vocabulary file
BertQuestionAnswerer bertQuestionAnswerer =
BertQuestionAnswerer.createBertQuestionAnswerer(
ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);
String CONTEXT = ...; // context of a question to be answered
String QUESTION = ...; // question to be answered
// ask a question
List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
// answers.get(0).text is the best answer
Xây dựng API
Tương tự như API gốc, để tạo đối tượng API, ứng dụng cần cung cấp
bằng cách mở rộng
BaseTaskApi
!
cung cấp khả năng xử lý JNI cho tất cả các API Tác vụ Java.
Xác định API I/O – Quá trình này thường phản chiếu các giao diện gốc. ví dụ:
BertQuestionAnswerer
lấy(String context, String question)
làm dữ liệu đầu vào và xuất raList<QaAnswer>
. Quá trình triển khai gọi một gốc riêng tư hàm có chữ ký tương tự, ngoại trừ việc có thêm tham sốlong nativeHandle
, là con trỏ được trả về từ C++.class BertQuestionAnswerer extends BaseTaskApi { public List<QaAnswer> answer(String context, String question) { return answerNative(getNativeHandle(), context, question); } private static native List<QaAnswer> answerNative( long nativeHandle, // C++ pointer String context, String question // API I/O ); }
Tạo các hàm factory của API – Việc này cũng phản ánh factory gốc ngoại trừ các hàm factory của Android cũng cần phải lấy
Context
để truy cập tệp. Quá trình triển khai gọi một trong các tiện ích trongTaskJniUtils
để tạo đối tượng API C++ tương ứng và truyền con trỏ của đối tượng đó đến phương thức Hàm khởi tạoBaseTaskApi
.class BertQuestionAnswerer extends BaseTaskApi { private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "bert_question_answerer_jni"; // Extending super constructor by providing the // native handle(pointer of corresponding C++ API object) private BertQuestionAnswerer(long nativeHandle) { super(nativeHandle); } public static BertQuestionAnswerer createBertQuestionAnswerer( Context context, // Accessing Android files String pathToModel, String pathToVocab) { return new BertQuestionAnswerer( // The util first try loads the JNI module with name // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files, // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers // is called with the buffer for a C++ API object pointer TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( context, BertQuestionAnswerer::initJniWithBertByteBuffers, BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, pathToModel, pathToVocab)); } // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer. // returns C++ API object pointer casted to long private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers); }
Triển khai mô-đun JNI cho các hàm gốc – Tất cả các phương thức gốc Java được triển khai bằng cách gọi một hàm gốc tương ứng từ JNI . Các hàm factory sẽ tạo một đối tượng API gốc và trả về con trỏ dưới dạng kiểu dài đối với Java. Trong các lệnh gọi sau này đến API Java, đoạn mã dài con trỏ kiểu dữ liệu được chuyển lại cho JNI và truyền trở lại đối tượng API gốc. Sau đó, kết quả API gốc được chuyển đổi lại thành kết quả Java.
Ví dụ: đây là cách bert_question_answerer_jni sẽ được triển khai.
// Implements BertQuestionAnswerer::initJniWithBertByteBuffers extern "C" JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers( JNIEnv* env, jclass thiz, jobjectArray model_buffers) { // Convert Java ByteBuffer object into a buffer that can be read by native factory functions absl::string_view model = GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0)); // Creates the native API object absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status = BertQuestionAnswerer::CreateFromBuffer( model.data(), model.size()); if (status.ok()) { // converts the object pointer to jlong and return to Java. return reinterpret_cast<jlong>(status->release()); } else { return kInvalidPointer; } } // Implements BertQuestionAnswerer::answerNative extern "C" JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative( JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) { // Convert long to native API object pointer QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle); // Calls the native API std::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context), JStringToString(env, question)); // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswerer>) jclass qa_answer_class = env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer"); jmethodID qa_answer_ctor = env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V"); return ConvertVectorToArrayList<QaAnswer>( env, results, [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) { jstring text = env->NewStringUTF(ans.text.data()); jobject qa_answer = env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start, ans.pos.end, ans.pos.logit); env->DeleteLocalRef(text); return qa_answer; }); } // Implements BaseTaskApi::deinitJni by delete the native object extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni( JNIEnv* env, jobject thiz, jlong native_handle) { delete reinterpret_cast<QuestionAnswerer*>(native_handle); }
API iOS
Tạo API iOS bằng cách gói đối tượng API gốc vào đối tượng API ObjC. Chiến lược phát hành đĩa đơn tạo có thể dùng trong ObjC hoặc Swift. API iOS yêu cầu cần tạo trước tiên.
Ví dụ về cách sử dụng
Dưới đây là một ví dụ về cách sử dụng ObjC
TFLBertQuestionAnswerer
cho MobileBert
trong Swift.
static let mobileBertModelPath = "path/to/model.tflite";
// Create the API from a model file and vocabulary file
let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
modelPath: mobileBertModelPath)
static let context = ...; // context of a question to be answered
static let question = ...; // question to be answered
// ask a question
let answers = mobileBertAnswerer.answer(
context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
// answers.[0].text is the best answer
Xây dựng API
iOS API là một trình bao bọc ObjC đơn giản ở trên API gốc. Xây dựng API bằng cách bằng cách làm theo các bước dưới đây:
Xác định trình bao bọc ObjC – Xác định lớp ObjC và uỷ quyền triển khai cho đối tượng API gốc tương ứng. Ghi lại mã gốc các phần phụ thuộc chỉ có thể xuất hiện trong tệp .mm do Swift không thể tương tác với C++.
- Tệp .h
@interface TFLBertQuestionAnswerer : NSObject // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath vocabPath:(NSString*)vocabPath NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:)); // Delegate calls to the native BertQuestionAnswerer::Answer - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context question:(NSString*)question NS_SWIFT_NAME(answer(context:question:)); }
- Tệp .mm
using BertQuestionAnswererCPP = ::tflite::task::text::BertQuestionAnswerer; @implementation TFLBertQuestionAnswerer { // define an iVar for the native API object std::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer; } // Initialize the native API object + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath vocabPath:(NSString *)vocabPath { absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer = BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath), MakeString(vocabPath)); _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer"); return [[TFLBertQuestionAnswerer alloc] initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())]; } // Calls the native API and converts C++ results into ObjC results - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question { std::vector<QaAnswerCPP> results = _bertQuestionAnswerwer->Answer(MakeString(context), MakeString(question)); return [self arrayFromVector:results]; } }