Text Searcher with TensorFlow Lite Model Maker

Licensed under the Apache License, Version 2.0 (the "License");


In this Colab notebook, you can learn how to use the TensorFlow Lite Model Maker library to create a TFLite Searcher model. You can use a text Searcher model to build Semantic Search or Smart Reply for your app. This type of model lets you take a text query and search for the most related entries in a text dataset, such as a database of web pages. The model returns a list of the entries in the dataset with the smallest distance scores, including metadata you specify, such as URL, page title, or other text entry identifiers. After building the model, you can deploy it onto devices (e.g. Android) and use the Task Library Searcher API to run inference with just a few lines of code.

This tutorial uses the CNN/DailyMail dataset as an example to create the TFLite Searcher model. You can also try your own dataset, as long as it uses the compatible comma-separated values (CSV) input format.

Text search using Scalable Nearest Neighbors

This tutorial uses the publicly available CNN/DailyMail non-anonymized summarization dataset, which was produced from the GitHub repo. This dataset contains over 300k news articles. That makes it a good choice for building the Searcher model and for returning a variety of related news articles for a text query at inference time.

The text Searcher model in this example uses a ScaNN (Scalable Nearest Neighbors) index file that can search for similar items from a predefined database. ScaNN achieves state-of-the-art performance for efficient vector similarity search at scale.
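ScaNN is an approximate method. To make concrete what it approximates, here is an exact (brute-force) search. It scores every database entry against the query with the negative dot product, so smaller means closer, and returns the closest entries. This is a plain-Python sketch for illustration only, not how ScaNN is implemented. The embeddings and the URL-like labels are made-up toy values.

```python
from typing import List, Sequence, Tuple


def negative_dot_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Negative dot product, so that a smaller value means "closer"."""
    return -sum(x * y for x, y in zip(a, b))


def brute_force_search(
    query: Sequence[float],
    database: List[Tuple[str, Sequence[float]]],
    k: int = 2,
) -> List[Tuple[str, float]]:
    """Scores every (metadata, embedding) entry and returns the k closest."""
    scored = [(meta, negative_dot_distance(query, emb)) for meta, emb in database]
    scored.sort(key=lambda item: item[1])  # smallest distance first
    return scored[:k]


# Toy 3-dimensional "embeddings" with URL-like metadata (made-up values).
database = [
    ("url-a", [0.9, 0.1, 0.0]),
    ("url-b", [0.0, 1.0, 0.0]),
    ("url-c", [0.6, 0.0, 0.8]),
]
print(brute_force_search([1.0, 0.0, 0.0], database, k=2))
```

Brute force costs one distance computation per entry for every query. On a dataset of hundreds of thousands of articles, that cost is what ScaNN's partitioning and quantization are designed to cut.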

Highlights and urls in this dataset are used in this colab to create the model:

  1. Highlights are the text used to generate the embedding feature vectors, which are then used for search.
  2. URLs are the results returned to users when a search matches the related highlights.

This tutorial saves this data to a CSV file and then uses the CSV file to build the model. Here are several examples from the dataset.

Highlights | Urls
Hawaiian Airlines again lands at No. 1 in on-time performance. The Airline Quality Rankings Report looks at the 14 largest U.S. airlines. ExpressJet and American Airlines had the worst on-time performance. Virgin America had the best baggage handling; Southwest had lowest complaint rate. | http://www.cnn.com/2013/04/08/travel/airline-quality-report
European football's governing body reveals list of countries bidding to host 2020 finals. The 60th anniversary edition of the finals will be hosted by 13 countries. Thirty-two countries are considering bids to host 2020 matches. UEFA will announce host cities on September 25. | http://edition.cnn.com:80/2013/09/20/sport/football/football-euro-2020-bid-countries/index.html?
Once octopus-hunter Dylan Mayer has now also signed a petition of 5,000 divers banning their hunt at Seacrest Park. Decision by Washington Department of Fish and Wildlife could take months. | http://www.dailymail.co.uk:80/news/article-2238423/Dylan-Mayer-Washington-considers-ban-Octopus-hunting-diver-caught-ate-Puget-Sound.html?
Galaxy was observed 420 million years after the Big Bang. found by NASA’s Hubble Space Telescope, Spitzer Space Telescope, and one of nature’s own natural 'zoom lenses' in space. | http://www.dailymail.co.uk/sciencetech/article-2233883/The-furthest-object-seen-Record-breaking-image-shows-galaxy-13-3-BILLION-light-years-Earth.html

Setup

Start by installing the required packages: the Model Maker package (tflite-model-maker), plus gdown for downloading the dataset from Google Drive.

sudo apt -y install libportaudio2
pip install -q tflite-model-maker
pip install gdown

Import the required packages.

from tflite_model_maker import searcher

Prepare the dataset

This tutorial uses the CNN / Daily Mail summarization dataset from the GitHub repo.

First, download the CNN and Daily Mail stories and the list of URLs, then extract the archives. If the download from Google Drive fails, wait a few minutes and try again. Alternatively, download the files manually and upload them to the Colab.

gdown https://drive.google.com/uc?id=0BwmD_VLjROrfTHk4NFg2SndKcjQ
gdown https://drive.google.com/uc?id=0BwmD_VLjROrfM1BxdkxVaTY2bWs

wget -O all_train.txt https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt
tar xzf cnn_stories.tgz
tar xzf dailymail_stories.tgz

Then, save the data to a CSV file that can be loaded by the tflite_model_maker library. The code is based on the logic used to load this data in tensorflow_datasets. We can't use tensorflow_datasets directly because it doesn't include the URLs used in this Colab.
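The core of that loading logic is pulling the highlights out of each story file. Below is a simplified sketch that assumes the .story layout of the CNN/DailyMail release: the article text comes first, and each highlight sentence follows an "@highlight" marker line. It is an illustrative re-implementation, not the exact tensorflow_datasets code. The END_TOKENS set and the helper names are assumptions.

```python
from typing import List

# Characters treated as valid sentence endings (an assumption modeled on
# common CNN/DailyMail preprocessing scripts), including curly quotes.
END_TOKENS = (".", "!", "?", "...", "'", "`", '"', "\u2019", "\u201d", ")")


def fix_missing_period(line: str) -> str:
    """Appends a period to non-empty, non-marker lines lacking an end token."""
    if not line or "@highlight" in line or line.endswith(END_TOKENS):
        return line
    return line + "."


def get_highlights(story_text: str) -> str:
    """Returns the sentences that follow "@highlight" markers, one per line."""
    highlights: List[str] = []
    next_is_highlight = False
    for raw_line in story_text.splitlines():
        line = fix_missing_period(raw_line.strip())
        if not line:
            continue  # skip blank lines between markers and sentences
        if line.startswith("@highlight"):
            next_is_highlight = True
        elif next_is_highlight:
            highlights.append(line)
            next_is_highlight = False
    return "\n".join(highlights)


story = "Article body\n\n@highlight\n\nFirst point\n\n@highlight\n\nSecond point!\n"
print(get_highlights(story))
```

Adding missing periods matters because the highlights are later joined into one string. Without a period, two highlights would run together as a single sentence.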

Because generating embedding feature vectors for the whole dataset takes a long time, only the first 5% of the CNN and Daily Mail stories are selected by default for demo purposes. You can adjust the fraction. You can also try the pre-built TFLite model, which was built from 50% of the CNN and Daily Mail stories.

Save the highlights and urls to the CSV file
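A sketch of this step is shown below. It is not the original notebook code, and it rests on three assumptions:

- Each story file is named after the SHA-1 hex digest of its URL in all_train.txt, as in the CNN/DailyMail preprocessing repo.
- Those URLs may carry a web.archive.org prefix ending in "id_/", which is stripped so that the CSV stores the original article URL.
- The archives extract to cnn/stories and dailymail/stories.

All helper names are hypothetical. CNN_FRACTION and DAILYMAIL_FRACTION correspond to the 5% default described above.

```python
import csv
import hashlib
import os
from typing import Callable, Dict, List

CNN_FRACTION = 0.05        # fraction of CNN stories to keep (demo default)
DAILYMAIL_FRACTION = 0.05  # fraction of Daily Mail stories to keep


def url_hash(url: str) -> str:
    """SHA-1 hex digest of a URL (assumed to match the .story file names)."""
    return hashlib.sha1(url.encode("utf-8")).hexdigest()


def display_url(url: str) -> str:
    """Strips an assumed web.archive.org '...id_/' prefix, if present."""
    marker = "id_/"
    return url[url.find(marker) + len(marker):] if marker in url else url


def build_hash_to_url(url_lines: List[str]) -> Dict[str, str]:
    """Maps each URL's hash to the URL that should be stored as metadata."""
    urls = [line.strip() for line in url_lines if line.strip()]
    return {url_hash(u): display_url(u) for u in urls}


def select_stories(folder: str, hash_to_url: Dict[str, str], fraction: float) -> List[str]:
    """Keeps story files whose name is a known URL hash, then the first fraction."""
    files = sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if os.path.splitext(name)[0] in hash_to_url
    )
    return files[: int(fraction * len(files))]


def write_csv(
    csv_path: str,
    story_files: List[str],
    hash_to_url: Dict[str, str],
    extract_highlights: Callable[[str], str],
) -> None:
    """Writes one (highlights, urls) row per story file."""
    with open(csv_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=["highlights", "urls"])
        writer.writeheader()
        for path in story_files:
            with open(path, encoding="utf-8") as f:
                highlights = extract_highlights(f.read())
            story_hash = os.path.splitext(os.path.basename(path))[0]
            writer.writerow({"highlights": highlights, "urls": hash_to_url[story_hash]})
```

In the Colab, you would:

1. Build hash_to_url from the lines of all_train.txt.
2. Concatenate select_stories("cnn/stories", hash_to_url, CNN_FRACTION) with select_stories("dailymail/stories", hash_to_url, DAILYMAIL_FRACTION).
3. Call write_csv("cnn_dailymail.csv", files, hash_to_url, extract_highlights), where extract_highlights is any function that returns a story's highlight text.

Sorting the file names makes the selected subset the same across runs.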

Build the text Searcher model

Create a text Searcher model by loading a dataset, creating a model with the data and exporting the TFLite model.

Step 1. Load the dataset

Model Maker takes the text dataset and the corresponding metadata of each text string (such as urls in this example) in the CSV format. It embeds the text strings into feature vectors using the user-specified embedder model.

In this demo, we build the Searcher model using Universal Sentence Encoder, a state-of-the-art sentence embedding model that has already been retrained in a separate Colab. The model is optimized for on-device inference performance and takes only 6 ms to embed a query string (measured on Pixel 6). Alternatively, you can use a quantized version, which is smaller but takes 38 ms for each embedding.

wget -O universal_sentence_encoder.tflite https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/text_embedder.tflite

Create a searcher.TextDataLoader instance and use the data_loader.load_from_csv method to load the dataset. This step takes about 10 minutes because it generates the embedding feature vector for each text entry one by one. You can also upload your own CSV file and load it to build a customized model.

Specify the name of text column and metadata column in the CSV file.

  • Text is used to generate the embedding feature vectors.
  • Metadata is the content returned when a search matches the corresponding text.
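Before loading a CSV of your own, it can help to check that both column names exist. The helper below is hypothetical and not part of Model Maker, which does its own loading. It uses only Python's csv module to show the pairing: text_column selects the strings to embed, and metadata_column selects what a search returns.

```python
import csv
from typing import List, Tuple


def read_text_and_metadata(
    csv_path: str, text_column: str, metadata_column: str
) -> List[Tuple[str, str]]:
    """Returns (text, metadata) pairs from the two named CSV columns."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        missing = {text_column, metadata_column} - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"CSV is missing column(s): {sorted(missing)}")
        return [(row[text_column], row[metadata_column]) for row in reader]
```

For the file generated above, read_text_and_metadata("cnn_dailymail.csv", "highlights", "urls") would yield one (highlights, url) pair per story.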

Here are the first 4 rows of the CNN-DailyMail CSV file generated above.

highlights | urls
Syrian official: Obama climbed to the top of the tree, doesn't know how to get down. Obama sends a letter to the heads of the House and Senate. Obama to seek congressional approval on military action against Syria. Aim is to determine whether CW were used, not by whom, says U.N. spokesman. | http://www.cnn.com/2013/08/31/world/meast/syria-civil-war/
Usain Bolt wins third gold of world championship. Anchors Jamaica to 4x100m relay victory. Eighth gold at the championships for Bolt. Jamaica double up in women's 4x100m relay. | http://edition.cnn.com/2013/08/18/sport/athletics-bolt-jamaica-gold
The employee in agency's Kansas City office is among hundreds of "virtual" workers. The employee's travel to and from the mainland U.S. last year cost more than $24,000. The telecommuting program, like all GSA practices, is under review. | http://www.cnn.com:80/2012/08/23/politics/gsa-hawaii-teleworking
NEW: A Canadian doctor says she was part of a team examining Harry Burkhart in 2010. NEW: Diagnosis: "autism, severe anxiety, post-traumatic stress disorder and depression" Burkhart is also suspected in a German arson probe, officials say. Prosecutors believe the German national set a string of fires in Los Angeles. | http://edition.cnn.com:80/2012/01/05/justice/california-arson/index.html?

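# Create a data loader that embeds text with the Universal Sentence Encoder
# model downloaded above and L2-normalizes each embedding.
data_loader = searcher.TextDataLoader.create("universal_sentence_encoder.tflite", l2_normalize=True)
# Embed the "highlights" column and attach each row's "urls" value as metadata.
data_loader.load_from_csv("cnn_dailymail.csv", text_column="highlights", metadata_column="urls")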
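The l2_normalize=True argument pairs naturally with the "dot_product" distance used in Step 2. This assumes the option scales each embedding to unit L2 length, as its name suggests. Once vectors have unit length, their dot product equals their cosine similarity, so ranking by dot product ranks by angle alone. The pure-Python illustration below only mirrors the math and does not call Model Maker.

```python
import math
from typing import List, Sequence


def l2_normalize(vector: Sequence[float]) -> List[float]:
    """Scales a vector to unit L2 norm (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return list(vector) if norm == 0 else [x / norm for x in vector]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


a = l2_normalize([3.0, 4.0])  # -> [0.6, 0.8]
b = l2_normalize([4.0, 3.0])  # -> [0.8, 0.6]
# Cosine similarity of the raw vectors: dot product divided by both norms (5 * 5).
cosine = dot([3.0, 4.0], [4.0, 3.0]) / (5.0 * 5.0)
print(dot(a, b), cosine)  # both approximately 0.96
```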

For image use cases, you can create a searcher.ImageDataLoader instance and then use data_loader.load_from_folder to load images from a folder. The searcher.ImageDataLoader instance must be created with a TFLite embedder model, because that model is used to encode queries into feature vectors and is exported with the TFLite Searcher model. For instance:

data_loader = searcher.ImageDataLoader.create("mobilenet_v2_035_96_embedder_with_metadata.tflite")
data_loader.load_from_folder("food/")

Step 2. Create the Searcher model

  • Configure ScaNN options. See the API doc for more details.
  • Create the Searcher model from the data and the ScaNN options. See the in-depth examination to learn more about the ScaNN algorithm.

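# Configure the ScaNN index: dot-product distance, a 140-leaf partitioning
# tree with 4 leaves searched per query, and asymmetric-hashing scoring.
scann_options = searcher.ScaNNOptions(
    distance_measure="dot_product",
    tree=searcher.Tree(num_leaves=140, num_leaves_to_search=4),
    score_ah=searcher.ScoreAH(dimensions_per_block=1, anisotropic_quantization_threshold=0.2))
# Build the Searcher model from the embedded data and the ScaNN options.
model = searcher.Searcher.create_from_data(data_loader, scann_options)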

In the above example, we define the following options:

  • distance_measure: we use "dot_product" to measure the distance between two embedding vectors. Note that we actually compute the negative dot product value to preserve the notion that "smaller is closer".

  • tree: the dataset is divided into 140 partitions (roughly the square root of the data size), and 4 of them are searched during retrieval, which is roughly 3% of the dataset.

  • score_ah: we quantize the float embeddings to int8 values with the same dimension to save space.
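The tree and score_ah options can be made concrete with a little arithmetic. The sketch below assumes equal-sized partitions. It computes the square-root rule of thumb and the share of the dataset scanned (4 / 140 is about 2.9%). It also shows a plain symmetric int8 scalar quantization of one embedding. That quantizer only illustrates trading float precision for space; ScaNN's score_ah uses anisotropic quantization, which is not reproduced here. All function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def suggested_num_leaves(num_entries: int) -> int:
    """Rule of thumb from the text: about sqrt(dataset size) partitions."""
    return max(1, round(math.sqrt(num_entries)))


def fraction_searched(num_leaves: int, num_leaves_to_search: int) -> float:
    """Approximate share of the dataset scanned, assuming equal-sized leaves."""
    return num_leaves_to_search / num_leaves


def quantize_int8(vector: Sequence[float]) -> Tuple[List[int], float]:
    """Symmetric scalar quantization: each value ~= code * scale, code in [-127, 127]."""
    max_abs = max((abs(x) for x in vector), default=0.0)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(x / scale))) for x in vector]
    return codes, scale


print(suggested_num_leaves(19600))  # 140
print(fraction_searched(140, 4))    # about 0.0286, i.e. roughly 3%
print(quantize_int8([0.5, -0.25, 0.0]))
```

Each int8 code takes one byte instead of four for a float32, which is where the space saving comes from.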

Step 3. Export the TFLite model

Then you can export the TFLite Searcher model.

model.export(
      export_filename="searcher.tflite",
      userinfo="",
      export_format=searcher.ExportFormat.TFLITE)
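As a quick sanity check on the exported file, note that TFLite models are FlatBuffers whose file identifier, "TFL3", is stored in bytes 4 to 7. The helper name below is hypothetical. It checks only the header, not whether the embedded ScaNN index or metadata is valid.

```python
def looks_like_tflite(path: str) -> bool:
    """Checks for the 'TFL3' FlatBuffer file identifier at bytes 4-7."""
    with open(path, "rb") as f:
        header = f.read(8)
    return len(header) == 8 and header[4:8] == b"TFL3"
```

After the export above, looks_like_tflite("searcher.tflite") would be expected to return True.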

Test the TFLite model on your query

You can test the exported TFLite model using custom query text. To query text using the Searcher model, initialize the model and run a search with a text phrase, as follows:

from tflite_support.task import text

# Initializes a TextSearcher object. A distinct variable name avoids shadowing
# the Model Maker searcher module imported earlier.
text_searcher = text.TextSearcher.create_from_file("searcher.tflite")

# Searches the input query.
results = text_searcher.search("The Airline Quality Rankings Report looks at the 14 largest U.S. airlines.")
print(results)
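The printed result object can be hard to scan. Below is a minimal helper, top_results, which is hypothetical. It assumes, based on the tflite_support Task Library Python API, that the search result exposes a nearest_neighbors list whose entries carry metadata as bytes and a float distance.

```python
from typing import Any, List, Tuple


def top_results(result: Any, k: int = 5) -> List[Tuple[str, float]]:
    """Returns (decoded metadata, distance) for the first k nearest neighbors.

    Assumes result.nearest_neighbors is a list of objects with .metadata
    (bytes or str) and .distance (float).
    """
    pairs = []
    for neighbor in result.nearest_neighbors[:k]:
        metadata = neighbor.metadata
        if isinstance(metadata, (bytes, bytearray)):
            metadata = metadata.decode("utf-8")
        pairs.append((metadata, neighbor.distance))
    return pairs
```

For example, loop over top_results(results) and print each distance with its URL. Smaller distances are closer, consistent with the negative dot product used when building the index.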

See the Task Library documentation for more information about how to integrate the model into various platforms.

Read more

For more information, please refer to: