Integrate BERT question answerer

The Task Library BertQuestionAnswerer API loads a BERT model and answers questions based on the content of a given passage. For more information, see the example for the Question-Answer model.

Key features of the BertQuestionAnswerer API

  • Takes two text inputs, a question and a context, and outputs a list of possible answers.

  • Performs out-of-graph WordPiece or SentencePiece tokenization on the input text.
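Because tokenization happens outside the model graph, it helps to know what a WordPiece tokenizer actually does. The following is a minimal sketch of greedy longest-match-first WordPiece in plain Python. The function names and the toy vocabulary are hypothetical; a real tokenizer also splits punctuation, normalizes accents, and loads the vocabulary packaged in the model metadata. SentencePiece uses a different algorithm and is not shown.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]", max_chars=100):
    """Split one lowercase word into WordPiece sub-tokens (greedy longest match first)."""
    if len(word) > max_chars:
        return [unk_token]
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, shrinking until a vocab entry matches.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk_token]  # no valid segmentation: the whole word maps to [UNK]
        tokens.append(piece)
        start = end
    return tokens


def tokenize(text, vocab):
    """Lowercase, split on whitespace, then WordPiece each word (no punctuation handling)."""
    out = []
    for word in text.lower().split():
        out.extend(wordpiece_tokenize(word, vocab))
    return out


# Toy vocabulary, purely illustrative.
VOCAB = {"where", "is", "the", "amazon", "rain", "##forest", "##s", "[UNK]"}
print(tokenize("Where is the Amazon rainforest", VOCAB))
# ['where', 'is', 'the', 'amazon', 'rain', '##forest']
```

Words missing from the vocabulary are broken into known sub-pieces where possible ("rainforest" becomes "rain" + "##forest"), which keeps the vocabulary small while still covering rare words.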

Supported BertQuestionAnswerer models

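The following models are compatible with the BertQuestionAnswerer API:

  • Models created by TensorFlow Lite Model Maker for BERT Question Answer.

  • The pretrained BERT models on TensorFlow Hub.

  • Custom models that meet the model compatibility requirements.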

Run inference in Java

Step 1: Import Gradle dependency and other settings

Copy the .tflite model file to the assets directory of the Android module where the model will be run. Specify that the file should not be compressed, and add the TensorFlow Lite library to the module’s build.gradle file:

android {
    // Other settings

    // Specify that tflite files should not be compressed in the app APK
    aaptOptions {
        noCompress "tflite"
    }

}

dependencies {
    // Other dependencies

    // Import the Task Text Library dependency
    implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4'
}

Step 2: Run inference using the API

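import java.util.List;
import org.tensorflow.lite.task.core.BaseOptions;
import org.tensorflow.lite.task.text.qa.BertQuestionAnswerer;
import org.tensorflow.lite.task.text.qa.BertQuestionAnswerer.BertQuestionAnswererOptions;
import org.tensorflow.lite.task.text.qa.QaAnswer;

// Initialization: run on 4 CPU threads. modelFile is the .tflite file name in assets.
// createFromFileAndOptions can throw an IOException if the model cannot be loaded.
BertQuestionAnswererOptions options =
    BertQuestionAnswererOptions.builder()
        .setBaseOptions(BaseOptions.builder().setNumThreads(4).build())
        .build();
BertQuestionAnswerer answerer =
    BertQuestionAnswerer.createFromFileAndOptions(
        androidContext, modelFile, options);

// Run inference: note the argument order is (context, question)
List<QaAnswer> answers = answerer.answer(contextOfTheQuestion, questionToAsk);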

See the source code for more details.

Run inference in Swift

Step 1: Import CocoaPods

Add the TensorFlowLiteTaskText pod to your Podfile:

target 'MySwiftAppWithTaskAPI' do
  use_frameworks!
  pod 'TensorFlowLiteTaskText', '~> 0.4.4'
end

Step 2: Run inference using the API

import TensorFlowLiteTaskText

// Initialization: mobileBertModelPath is the path to the .tflite file in your bundle
let mobileBertAnswerer = TFLBertQuestionAnswerer.questionAnswerer(
      modelPath: mobileBertModelPath)

// Run inference
let answers = mobileBertAnswerer.answer(
      context: context, question: question)

See the source code for more details.

Run inference in C++

#include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h"

// Initialization. CreateFromOptions returns a StatusOr; .value() assumes creation succeeded.
BertQuestionAnswererOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<QuestionAnswerer> answerer = BertQuestionAnswerer::CreateFromOptions(options).value();

// Run inference with your inputs, `context_of_question` and `question_to_ask`.
std::vector<QaAnswer> answers = answerer->Answer(context_of_question, question_to_ask);

See the source code for more details.

Run inference in Python

Step 1: Install the pip package

pip install tflite-support

Step 2: Use the model

# Imports
from tflite_support.task import text

# Initialization
answerer = text.BertQuestionAnswerer.create_from_file(model_path)

# Run inference
bert_qa_result = answerer.answer(context, question)

See the source code for more options to configure BertQuestionAnswerer.

Example results

Here is an example of the answer results of the ALBERT model:

Context: "The Amazon rainforest, alternatively, the Amazon Jungle, also known in English as Amazonia, is a moist broadleaf tropical rainforest in the Amazon biome that covers most of the Amazon basin of South America. This basin encompasses 7,000,000 km2 (2,700,000 sq mi), of which 5,500,000 km2 (2,100,000 sq mi) are covered by the rainforest. This region includes territory belonging to nine nations."

Question: "Where is Amazon rainforest?"

Answers:

answer[0]:  'South America.'
logit: 1.84847, start_index: 39, end_index: 40
answer[1]:  'most of the Amazon basin of South America.'
logit: 1.2921, start_index: 34, end_index: 40
answer[2]:  'the Amazon basin of South America.'
logit: -0.0959535, start_index: 36, end_index: 40
answer[3]:  'the Amazon biome that covers most of the Amazon basin of South America.'
logit: -0.498558, start_index: 28, end_index: 40
answer[4]:  'Amazon basin of South America.'
logit: -0.774266, start_index: 37, end_index: 40
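A ranked list of spans like the one above is typically produced by the standard BERT (SQuAD-style) post-processing: every candidate span is scored as its start logit plus its end logit, spans that end before they start or exceed a maximum length are discarded, and the top k are kept. The sketch below illustrates that technique under the assumption that the logits have already been restricted to context tokens; it is not necessarily the Task Library's exact implementation. The helpers best_spans and detokenize, the max_answer_len default, and the toy tokens and logits are all hypothetical.

```python
import heapq


def best_spans(start_logits, end_logits, top_k=5, max_answer_len=32):
    """Return the top_k (score, start, end) spans, scored as start_logit + end_logit.

    Spans must satisfy start <= end and contain at most max_answer_len tokens.
    """
    candidates = []
    for s, s_logit in enumerate(start_logits):
        # End index runs from s up to s + max_answer_len - 1 (inclusive).
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            candidates.append((s_logit + end_logits[e], s, e))
    return heapq.nlargest(top_k, candidates)


def detokenize(tokens):
    """Join WordPiece tokens into text, gluing '##' continuations onto the previous word."""
    words = []
    for tok in tokens:
        if tok.startswith("##") and words:
            words[-1] += tok[2:]
        else:
            words.append(tok)
    return " ".join(words)


# Hypothetical context tokens and per-token logits, for illustration only.
context_tokens = ["covers", "south", "amer", "##ica"]
start_logits = [0.1, 2.0, 0.5, -1.0]
end_logits = [-0.5, 0.3, 0.2, 1.5]

for score, s, e in best_spans(start_logits, end_logits, top_k=3):
    print(f"{detokenize(context_tokens[s:e + 1])!r} logit={score:.2f} start={s} end={e}")
```

Scoring by a sum of logits lets a strong start position combine with a strong end position, which is why several overlapping spans ending at the same index (as in the ALBERT output above) can all rank highly.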

Try out the simple CLI demo tool for BertQuestionAnswerer with your own model and test data.

Model compatibility requirements

The BertQuestionAnswerer API expects a TFLite model with mandatory TFLite Model Metadata.

The metadata should meet the following requirements:

  • input_process_units for a WordPiece/SentencePiece tokenizer

  • Three input tensors named "ids", "mask" and "segment_ids" for the output of the tokenizer

  • Two output tensors named "end_logits" and "start_logits" that indicate the answer's relative position in the context
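The Task Library fills the three input tensors from the tokenizer output for you. As a rough illustration of what they usually contain for a BERT question-answering model, the sketch below packs a question/context pair as [CLS] question [SEP] context [SEP]. It assumes that layout and the [CLS]=101, [SEP]=102, [PAD]=0 ids of the standard uncased English BERT vocabulary; build_bert_inputs and the sample ids are hypothetical, and a given model's vocabulary, ordering, and sequence length may differ.

```python
def build_bert_inputs(question_ids, context_ids, seq_len,
                      cls_id=101, sep_id=102, pad_id=0):
    """Return (ids, mask, segment_ids), each padded to seq_len.

    mask is 1 for real tokens and 0 for padding; segment_ids is 0 for
    [CLS] + question + first [SEP] and 1 for context + final [SEP].
    """
    ids = [cls_id] + list(question_ids) + [sep_id] + list(context_ids) + [sep_id]
    if len(ids) > seq_len:
        # A production pipeline would usually truncate or window the context;
        # this sketch simply refuses inputs that do not fit.
        raise ValueError(f"{len(ids)} tokens do not fit in seq_len={seq_len}")
    segment_ids = [0] * (len(question_ids) + 2) + [1] * (len(context_ids) + 1)
    mask = [1] * len(ids)
    padding = seq_len - len(ids)
    return (ids + [pad_id] * padding,
            mask + [0] * padding,
            segment_ids + [0] * padding)


ids, mask, segment_ids = build_bert_inputs([7, 8], [9], seq_len=8)
print(ids)          # [101, 7, 8, 102, 9, 102, 0, 0]
print(mask)         # [1, 1, 1, 1, 1, 1, 0, 0]
print(segment_ids)  # [0, 0, 0, 0, 1, 1, 0, 0]
```

The start_logits and end_logits outputs are indexed over this same packed sequence, which is why the answer positions in the example results are token indices rather than character offsets.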