Copyright 2024 The AI Edge Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In this Colab notebook, you'll learn how to use the TensorFlow Lite Model Maker library to train a custom object detection model capable of detecting salads within images on a mobile device.
The Model Maker library uses transfer learning to simplify the process of training a TensorFlow Lite model using a custom dataset. Because training starts from a pretrained model, retraining with your own custom dataset requires less training data and shortens the training time.
You'll use the publicly available Salads dataset, which was created from the Open Images Dataset V4.
Each image in the dataset contains objects labeled as one of the following classes:
- Baked Good
- Cheese
- Salad
- Seafood
- Tomato
For each object, the dataset provides a bounding box specifying where the object is located, together with the object's label.
Prerequisites
Install the required packages
Start by installing the required packages, including the Model Maker package (tflite-model-maker) and the pycocotools library you'll use for evaluation.
!sudo apt -y install libportaudio2
!pip install -q --use-deprecated=legacy-resolver tflite-model-maker
!pip install -q pycocotools
!pip install -q opencv-python-headless==4.1.2.30
!pip uninstall -y tensorflow && pip install -q tensorflow==2.8.0
Import the required packages.
import numpy as np
import os
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)
Prepare the dataset
Here you'll use the same dataset as the AutoML quickstart.
The Salads dataset is available at: gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv.
It contains 175 images for training, 25 images for validation, and 25 images for testing. The dataset has five classes: Salad, Seafood, Tomato, Baked goods, Cheese.
The dataset is provided in CSV format:
TRAINING,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Salad,0.0,0.0954,,,0.977,0.957,,
VALIDATION,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Seafood,0.0154,0.1538,,,1.0,0.802,,
TEST,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Tomato,0.0,0.655,,,0.231,0.839,,
- Each row corresponds to an object localized inside a larger image, with each object specifically designated as test, train, or validation data. You'll learn more about what that means in a later stage in this notebook.
- The three lines included here indicate three distinct objects located inside the same image, available at gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg.
- Each row has a different label: Salad, Seafood, Tomato, etc.
- Bounding boxes are specified for each object using the top left and bottom right vertices.
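The following stdlib-only sketch parses rows like the ones above to make the column layout concrete. It assumes the layout shown in the sample: split, image path, and label come first. The top-left x/y values sit in columns 3 and 4 and the bottom-right x/y values in columns 7 and 8, all relative to the image size (the sample values lie between 0 and 1). The other vertex columns are left empty. parse_annotation_csv and BoxAnnotation are hypothetical names, not part of Model Maker, which does this parsing for you inside DataLoader.from_csv.

```python
import csv
import io
from collections import Counter
from typing import List, NamedTuple


class BoxAnnotation(NamedTuple):
    split: str        # TRAINING, VALIDATION, or TEST
    image_path: str   # gs:// URL or local path
    label: str
    x_min: float      # top-left vertex (assumed relative to image width)
    y_min: float      # top-left vertex (assumed relative to image height)
    x_max: float      # bottom-right vertex
    y_max: float


def parse_annotation_csv(text: str) -> List[BoxAnnotation]:
    """Parse rows laid out as split,path,label,x_min,y_min,,,x_max,y_max,,"""
    annotations = []
    for row in csv.reader(io.StringIO(text)):
        if not row:
            continue  # skip blank lines
        # Columns 3-4 hold the top-left vertex and columns 7-8 the bottom-right
        # vertex; columns 5-6 and 9-10 are left empty in this layout.
        annotations.append(BoxAnnotation(
            split=row[0],
            image_path=row[1],
            label=row[2],
            x_min=float(row[3]),
            y_min=float(row[4]),
            x_max=float(row[7]),
            y_max=float(row[8]),
        ))
    return annotations


SAMPLE = """\
TRAINING,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Salad,0.0,0.0954,,,0.977,0.957,,
VALIDATION,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Seafood,0.0154,0.1538,,,1.0,0.802,,
TEST,gs://cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg,Tomato,0.0,0.655,,,0.231,0.839,,
"""

boxes = parse_annotation_csv(SAMPLE)
print(Counter(b.split for b in boxes))  # objects per split
print(boxes[2].label, boxes[2].x_max)
```

Counting objects per split this way is a quick sanity check on your own CSV before handing it to Model Maker.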
If you want to know more about how to prepare your own CSV file and the minimum requirements for creating a valid dataset, see the Preparing your training data guide for more details.
If you are new to Google Cloud, you may wonder what the gs:// URL means. These are URLs of files stored on Google Cloud Storage (GCS). If you make your files on GCS public or authenticate your client, Model Maker can read those files the same way it reads your local files.
However, you don't need to keep your images on Google Cloud to use Model Maker. You can use a local path in your CSV file and Model Maker will just work.
Quickstart
There are six steps to training an object detection model:
Step 1. Choose an object detection model architecture.
This tutorial uses the EfficientDet-Lite0 model. EfficientDet-Lite[0-4] are a family of mobile/IoT-friendly object detection models derived from the EfficientDet architecture.
Here is how the EfficientDet-Lite models compare with one another.
Model architecture | Size(MB)* | Latency(ms)** | Average Precision*** |
---|---|---|---|
EfficientDet-Lite0 | 4.4 | 37 | 25.69% |
EfficientDet-Lite1 | 5.8 | 49 | 30.55% |
EfficientDet-Lite2 | 7.2 | 69 | 33.97% |
EfficientDet-Lite3 | 11.4 | 116 | 37.70% |
EfficientDet-Lite4 | 19.9 | 260 | 41.96% |
* Size of the integer quantized models.
** Latency measured on Pixel 4 using 4 threads on CPU.
*** Average Precision is the mAP (mean Average Precision) on the COCO 2017 validation dataset.
spec = model_spec.get('efficientdet_lite0')
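One simple way to use the table is to pick the most accurate model that fits your latency and size budget. The sketch below assumes the Pixel 4 figures above carry over roughly to your target device, which they may not. EFFICIENTDET_LITE and pick_model are hypothetical helpers, not part of Model Maker.

```python
from typing import Optional

# Figures copied from the EfficientDet-Lite table above.
EFFICIENTDET_LITE = {
    'efficientdet_lite0': {'size_mb': 4.4, 'latency_ms': 37, 'ap': 25.69},
    'efficientdet_lite1': {'size_mb': 5.8, 'latency_ms': 49, 'ap': 30.55},
    'efficientdet_lite2': {'size_mb': 7.2, 'latency_ms': 69, 'ap': 33.97},
    'efficientdet_lite3': {'size_mb': 11.4, 'latency_ms': 116, 'ap': 37.70},
    'efficientdet_lite4': {'size_mb': 19.9, 'latency_ms': 260, 'ap': 41.96},
}


def pick_model(max_latency_ms: float,
               max_size_mb: Optional[float] = None) -> Optional[str]:
    """Return the most accurate model that fits the latency (and size) budget."""
    candidates = [
        (stats['ap'], name)
        for name, stats in EFFICIENTDET_LITE.items()
        if stats['latency_ms'] <= max_latency_ms
        and (max_size_mb is None or stats['size_mb'] <= max_size_mb)
    ]
    # max() compares AP first; None means no model fits the budget.
    return max(candidates)[1] if candidates else None


print(pick_model(100))
print(pick_model(300, max_size_mb=10))
```

You could then pass the returned string to model_spec.get in place of 'efficientdet_lite0'.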
Step 2. Load the dataset.
Model Maker takes input data in CSV format. Use the object_detector.DataLoader.from_csv method to load the dataset and split it into training, validation, and test images.
- Training images: These images are used to train the object detection model to recognize salad ingredients.
- Validation images: These are images that the model didn't see during the training process. You'll use them to decide when you should stop the training, to avoid overfitting.
- Test images: These images are used to evaluate the final model performance.
You can load the CSV file directly from Google Cloud Storage, but you don't need to keep your images on Google Cloud to use Model Maker. You can specify a local CSV file on your computer, and Model Maker will work just fine.
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv')
Step 3. Train the TensorFlow model with the training data.
- The EfficientDet-Lite0 model uses epochs = 50 by default, which means it will go through the training dataset 50 times. You can look at the validation accuracy during training and stop early to avoid overfitting.
- Set batch_size = 8 here so you will see that it takes 21 steps to go through the 175 images in the training dataset.
- Set train_whole_model=True to fine-tune the whole model instead of just training the head layer, which can improve accuracy. The trade-off is that training may take longer.
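The step counts follow from simple arithmetic. The sketch below assumes training drops the final partial batch (175 // 8 = 21, which matches the 21 steps quoted above). It also assumes evaluation keeps the partial batch, so the 25 test images with the default batch size of 64 still take 1 step. Both assumptions are inferred from those numbers, not taken from Model Maker's source.

```python
import math


def train_steps_per_epoch(num_images: int, batch_size: int) -> int:
    # Assumption: the last partial batch is dropped during training.
    return num_images // batch_size


def eval_steps(num_images: int, batch_size: int) -> int:
    # Assumption: a final partial batch is still evaluated.
    return math.ceil(num_images / batch_size)


print(train_steps_per_epoch(175, 8))       # 21 steps per epoch
print(train_steps_per_epoch(175, 8) * 50)  # 1050 steps over the default 50 epochs
print(eval_steps(25, 64))                  # 1 step for the 25 test images
```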
model = object_detector.create(train_data, model_spec=spec, batch_size=8, train_whole_model=True, validation_data=validation_data)
Step 4. Evaluate the model with the test data.
After training the object detection model using the images in the training dataset, use the remaining 25 images in the test dataset to evaluate how the model performs against new data it has never seen before.
As the default batch size is 64, it will take 1 step to go through the 25 images in the test dataset.
The evaluation metrics are the same as those used by COCO.
model.evaluate(test_data)
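COCO-style metrics decide whether a detection matches a ground-truth box using their intersection over union (IoU). The headline AP averages over IoU thresholds from 0.50 to 0.95 in steps of 0.05. The sketch below covers only the IoU part and assumes boxes in (ymin, xmin, ymax, xmax) order. The full evaluation done by pycocotools also handles score ranking, per-class precision/recall curves, and object-size ranges.

```python
from typing import Sequence


def iou(box_a: Sequence[float], box_b: Sequence[float]) -> float:
    """Intersection over union of two boxes given as (ymin, xmin, ymax, xmax)."""
    ymin = max(box_a[0], box_b[0])
    xmin = max(box_a[1], box_b[1])
    ymax = min(box_a[2], box_b[2])
    xmax = min(box_a[3], box_b[3])
    # Clamp to zero when the boxes do not overlap.
    inter = max(0.0, ymax - ymin) * max(0.0, xmax - xmin)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# COCO-style AP averages over these ten IoU thresholds (0.50, 0.55, ..., 0.95).
COCO_IOU_THRESHOLDS = [0.5 + 0.05 * i for i in range(10)]

pred = (0.0, 0.0, 1.0, 1.0)
truth = (0.0, 0.1, 1.0, 1.1)  # same box shifted right by 10% of its width
overlap = iou(pred, truth)
matched_at = [t for t in COCO_IOU_THRESHOLDS if overlap >= t]
print(round(overlap, 3), len(matched_at))
```

Even a small shift counts as a match at the loose thresholds but not at the strict ones, which is why COCO AP is usually lower than AP at IoU 0.5 alone.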
Step 5. Export as a TensorFlow Lite model.
Export the trained object detection model to the TensorFlow Lite format by specifying which folder you want to export the quantized model to. The default post-training quantization technique is full integer quantization.
model.export(export_dir='.')
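Full integer quantization maps float values to 8-bit integers through a scale and a zero point. The NumPy sketch below uses a per-tensor affine (asymmetric) scheme with simple min/max calibration. The TFLite converter's actual rules differ: it uses per-channel weight quantization and calibrates activations on representative data. Treat quantize_int8 and dequantize as hypothetical helpers for intuition only. The sketch also shows where a 4x size reduction comes from, since each float32 value becomes one int8 byte.

```python
import numpy as np


def quantize_int8(x):
    """Per-tensor affine (asymmetric) quantization of a float array to int8."""
    # Make sure 0.0 is inside the range so it maps exactly to an integer.
    x_min = min(float(x.min()), 0.0)
    x_max = max(float(x.max()), 0.0)
    scale = (x_max - x_min) / 255.0 or 1.0  # fall back to 1.0 for an all-zero input
    zero_point = int(round(-128 - x_min / scale))
    q = np.clip(np.round(x / scale) + zero_point, -128, 127).astype(np.int8)
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Map int8 values back to approximate float32 values."""
    return (q.astype(np.float32) - zero_point) * scale


weights = np.random.default_rng(0).normal(size=1000).astype(np.float32)
q, scale, zp = quantize_int8(weights)
print(weights.nbytes / q.nbytes)  # 4.0: one byte per value instead of four
print(np.abs(dequantize(q, scale, zp) - weights).max() <= scale)
```

The reconstruction error per value is bounded by roughly one quantization step (scale), which is the source of the accuracy drop discussed in Step 6.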
Step 6. Evaluate the TensorFlow Lite model.
Several factors can affect the model accuracy when exporting to TFLite:
- Quantization shrinks the model size by about 4 times (each 32-bit float weight becomes an 8-bit integer), at the expense of some accuracy.
- The original TensorFlow model uses per-class non-max suppression (NMS) for post-processing, while the TFLite model uses global NMS, which is much faster but less accurate. The Keras model outputs at most 100 detections, while the TFLite model outputs at most 25 detections.
Therefore you'll have to evaluate the exported TFLite model and compare its accuracy with the original TensorFlow model.
model.evaluate_tflite('model.tflite', test_data)
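The NMS difference can be illustrated with a simplified greedy NMS in NumPy. This sketches the general technique with an assumed IoU threshold of 0.5; it is not the actual post-processing op inside the TFLite model. Global NMS lets a high-scoring box suppress an overlapping box of a different class, while per-class NMS only compares boxes that share a label.

```python
import numpy as np


def _area(b):
    """Area of box(es) given as (ymin, xmin, ymax, xmax)."""
    return (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])


def iou_one_to_many(box, boxes):
    """IoU between one box and an (N, 4) array of boxes."""
    ymin = np.maximum(box[0], boxes[:, 0])
    xmin = np.maximum(box[1], boxes[:, 1])
    ymax = np.minimum(box[2], boxes[:, 2])
    xmax = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(ymax - ymin, 0, None) * np.clip(xmax - xmin, 0, None)
    return inter / (_area(box) + _area(boxes) - inter)


def nms(boxes, scores, iou_threshold=0.5, max_detections=25):
    """Greedy NMS over all boxes at once (class-agnostic, i.e. global)."""
    order = np.argsort(-scores)
    keep = []
    while order.size > 0 and len(keep) < max_detections:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        # Drop every remaining box that overlaps the kept box too much.
        order = rest[iou_one_to_many(boxes[best], boxes[rest]) < iou_threshold]
    return keep


def per_class_nms(boxes, scores, class_ids, iou_threshold=0.5, max_detections=100):
    """Run NMS separately for each class, then merge the survivors by score."""
    keep = []
    for c in np.unique(class_ids):
        idx = np.flatnonzero(class_ids == c)
        kept = nms(boxes[idx], scores[idx], iou_threshold, max_detections)
        keep.extend(int(idx[i]) for i in kept)
    keep.sort(key=lambda i: -scores[i])
    return keep[:max_detections]


boxes = np.array([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.9, 0.9]])
scores = np.array([0.9, 0.8])
class_ids = np.array([0, 1])  # e.g. Salad and Tomato
print(nms(boxes, scores))                       # global NMS
print(per_class_nms(boxes, scores, class_ids))  # per-class NMS
```

Here a Salad box and a Tomato box that overlap heavily both survive per-class NMS, but global NMS keeps only the higher-scoring one. This is one way the exported model can lose accuracy.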
You can download the TensorFlow Lite model file using the left sidebar of Colab. Right-click the model.tflite file and choose Download to download it to your local computer.
This model can be integrated into an Android or an iOS app using the ObjectDetector API of the TensorFlow Lite Task Library.
See the TFLite Object Detection sample app for more details on how the model is used in a working app.
(Optional) Test the TFLite model on your image
You can test the trained TFLite model using images from the internet.
- Replace the INPUT_IMAGE_URL below with your desired input image.
- Adjust the DETECTION_THRESHOLD to change the sensitivity of the model. A lower threshold means the model will pick up more objects, but there will also be more false detections. Meanwhile, a higher threshold means the model will only pick up objects that it has confidently detected.
Although running the model in Python currently requires some boilerplate code, integrating the model into a mobile app only requires a few lines of code.
Load the trained TFLite model and define some visualization functions
import cv2
from PIL import Image
model_path = 'model.tflite'
# Load the labels into a list
classes = ['???'] * model.model_spec.config.num_classes
label_map = model.model_spec.config.label_map
for label_id, label_name in label_map.as_dict().items():
    # label_map IDs start at 1, list indices start at 0
    classes[label_id-1] = label_name
# Define a list of colors for visualization
COLORS = np.random.randint(0, 255, size=(len(classes), 3), dtype=np.uint8)
def preprocess_image(image_path, input_size):
    """Preprocess the input image to feed to the TFLite model"""
    img = tf.io.read_file(image_path)
    img = tf.io.decode_image(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.uint8)
    original_image = img
    # Resize to the model's input size and add a batch dimension
    resized_img = tf.image.resize(img, input_size)
    resized_img = resized_img[tf.newaxis, :]
    resized_img = tf.cast(resized_img, dtype=tf.uint8)
    return resized_img, original_image
def detect_objects(interpreter, image, threshold):
    """Returns a list of detection results, each a dictionary of object info."""
    signature_fn = interpreter.get_signature_runner()
    # Feed the input image to the model
    output = signature_fn(images=image)
    # Get all outputs from the model
    count = int(np.squeeze(output['output_0']))
    scores = np.squeeze(output['output_1'])
    # Named class_ids so it does not shadow the global `classes` label list
    class_ids = np.squeeze(output['output_2'])
    boxes = np.squeeze(output['output_3'])
    results = []
    for i in range(count):
        if scores[i] >= threshold:
            result = {
                'bounding_box': boxes[i],
                'class_id': class_ids[i],
                'score': scores[i]
            }
            results.append(result)
    return results
def run_odt_and_draw_results(image_path, interpreter, threshold=0.5):
    """Run object detection on the input image and draw the detection results"""
    # Load the input shape required by the model
    _, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']
    # Load the input image and preprocess it
    preprocessed_image, original_image = preprocess_image(
        image_path,
        (input_height, input_width)
    )
    # Run object detection on the input image
    results = detect_objects(interpreter, preprocessed_image, threshold=threshold)
    # Plot the detection results on the input image
    original_image_np = original_image.numpy().astype(np.uint8)
    for obj in results:
        # Convert the object bounding box from relative coordinates to absolute
        # coordinates based on the original image resolution
        ymin, xmin, ymax, xmax = obj['bounding_box']
        xmin = int(xmin * original_image_np.shape[1])
        xmax = int(xmax * original_image_np.shape[1])
        ymin = int(ymin * original_image_np.shape[0])
        ymax = int(ymax * original_image_np.shape[0])
        # Find the class index of the current object
        class_id = int(obj['class_id'])
        # Draw the bounding box and label on the image
        color = [int(c) for c in COLORS[class_id]]
        cv2.rectangle(original_image_np, (xmin, ymin), (xmax, ymax), color, 2)
        # Make adjustments to make the label visible for all objects
        y = ymin - 15 if ymin - 15 > 15 else ymin + 15
        label = "{}: {:.0f}%".format(classes[class_id], obj['score'] * 100)
        cv2.putText(original_image_np, label, (xmin, y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    # Return the final image
    original_uint8 = original_image_np.astype(np.uint8)
    return original_uint8
Run object detection and show the detection results
INPUT_IMAGE_URL = "https://storage.googleapis.com/cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg"
DETECTION_THRESHOLD = 0.3
TEMP_FILE = '/tmp/image.png'
!wget -q -O $TEMP_FILE $INPUT_IMAGE_URL
im = Image.open(TEMP_FILE)
# Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter
im.thumbnail((512, 512), Image.LANCZOS)
im.save(TEMP_FILE, 'PNG')
# Load the TFLite model
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
# Run inference and draw detection result on the local copy of the original file
detection_result_image = run_odt_and_draw_results(
TEMP_FILE,
interpreter,
threshold=DETECTION_THRESHOLD
)
# Show the detection result
Image.fromarray(detection_result_image)
(Optional) Compile For the Edge TPU
Now that you have a quantized EfficientDet Lite model, it is possible to compile and deploy to a Coral EdgeTPU.
Step 1. Install the EdgeTPU Compiler
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
!echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list
!sudo apt-get update
!sudo apt-get install edgetpu-compiler
Step 2. Select number of Edge TPUs, Compile
The Edge TPU has 8 MB of SRAM for caching model parameters. This means that for models larger than 8 MB, inference time increases because the remaining model parameters have to be transferred during inference. One way to avoid this is model pipelining: splitting the model into segments that can each run on a dedicated Edge TPU. This can significantly improve latency.
Use the table below as a reference for the number of Edge TPUs to use. The larger models will not compile for a single TPU because their intermediate tensors can't fit in on-chip memory.
Model architecture | Minimum TPUs | Recommended TPUs |
---|---|---|
EfficientDet-Lite0 | 1 | 1 |
EfficientDet-Lite1 | 1 | 1 |
EfficientDet-Lite2 | 1 | 2 |
EfficientDet-Lite3 | 2 | 2 |
EfficientDet-Lite4 | 2 | 3 |
NUMBER_OF_TPUS = 1
!edgetpu_compiler model.tflite --num_segments=$NUMBER_OF_TPUS
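If you compile several variants, you can derive --num_segments from the table instead of hard-coding it. The sketch below assumes the table values above. EDGETPU_COUNTS and edgetpu_compile_cmd are hypothetical helpers that only build the argument list; you could then pass that list to subprocess.run.

```python
import shlex
from typing import List, Optional

# Minimum and recommended Edge TPU counts copied from the table above.
EDGETPU_COUNTS = {
    'efficientdet_lite0': (1, 1),
    'efficientdet_lite1': (1, 1),
    'efficientdet_lite2': (1, 2),
    'efficientdet_lite3': (2, 2),
    'efficientdet_lite4': (2, 3),
}


def edgetpu_compile_cmd(model_name: str, tflite_path: str = 'model.tflite',
                        num_tpus: Optional[int] = None) -> List[str]:
    """Build an edgetpu_compiler argv; num_tpus=None uses the recommended count."""
    minimum, recommended = EDGETPU_COUNTS[model_name]
    segments = recommended if num_tpus is None else num_tpus
    if segments < minimum:
        raise ValueError(f'{model_name} needs at least {minimum} Edge TPU(s)')
    return ['edgetpu_compiler', tflite_path, f'--num_segments={segments}']


print(shlex.join(edgetpu_compile_cmd('efficientdet_lite2')))
```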
Step 3. Download, Run Model
With the model(s) compiled, they can now be run on Edge TPU(s) for object detection. First, download the compiled TensorFlow Lite model file using the left sidebar of Colab. Right-click the model_edgetpu.tflite file and choose Download to download it to your local computer.
Now you can run the model in your preferred manner. Examples of detection include:
Advanced Usage
This section covers advanced usage topics like adjusting the model and the training hyperparameters.
Load the dataset
Load your own data
You can upload your own dataset to work through this tutorial. Upload your dataset by using the left sidebar in Colab.
If you prefer not to upload your dataset to the cloud, you can also locally run the library by following the guide.
Load your data with a different data format
The Model Maker library also supports the object_detector.DataLoader.from_pascal_voc method to load data in PASCAL VOC format. makesense.ai and LabelImg are tools that can annotate images and save the annotations as XML files in PASCAL VOC format:
object_detector.DataLoader.from_pascal_voc(image_dir, annotations_dir, label_map={1: "person", 2: "notperson"})
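In PASCAL VOC format, each image has its own XML file, and boxes are given in absolute pixel coordinates (xmin, ymin, xmax, ymax), unlike the 0-to-1 values in the Salads CSV above. The stdlib-only sketch below reads one such file, using a made-up annotation for illustration. Model Maker parses these files itself; read_voc_objects is a hypothetical helper for inspecting your annotations, for example to check that every name appears in your label_map.

```python
import xml.etree.ElementTree as ET
from typing import List, Tuple

# A made-up annotation file for illustration.
SAMPLE_VOC = """<annotation>
  <filename>image_0001.jpg</filename>
  <size><width>640</width><height>480</height><depth>3</depth></size>
  <object>
    <name>person</name>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
</annotation>"""


def read_voc_objects(xml_text: str) -> Tuple[str, List[Tuple[str, int, int, int, int]]]:
    """Return the image filename and (label, xmin, ymin, xmax, ymax) pixel boxes."""
    root = ET.fromstring(xml_text)
    filename = root.findtext('filename')
    objects = []
    for obj in root.iter('object'):
        box = obj.find('bndbox')
        coords = [int(float(box.findtext(k))) for k in ('xmin', 'ymin', 'xmax', 'ymax')]
        objects.append((obj.findtext('name'), *coords))
    return filename, objects


print(read_voc_objects(SAMPLE_VOC))
```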
Customize the EfficientDet model hyperparameters
The model and training pipeline parameters you can adjust are:
- model_dir: The location to save the model checkpoint files. If not set, a temporary directory will be used.
- steps_per_execution: Number of steps per training execution.
- moving_average_decay: Float. The decay to use for maintaining moving averages of the trained parameters.
- var_freeze_expr: The regular expression that maps the prefix name of variables to be frozen, meaning they remain the same during training. More specifically, the codebase uses re.match(var_freeze_expr, variable_name) to map the variables to be frozen.
- tflite_max_detections: Integer, 25 by default. The maximum number of output detections in the TFLite model.
- strategy: A string specifying which distribution strategy to use. Accepted values are 'tpu', 'gpus', None. 'tpu' means to use TPUStrategy. 'gpus' means to use MirroredStrategy for multiple GPUs. If None, use the TF default with OneDeviceStrategy.
- tpu: The Cloud TPU to use for training. This should be either the name used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 URL.
- use_xla: Use XLA even if strategy is not tpu. If strategy is tpu, XLA is always used and this flag has no effect.
- profile: Enable profile mode.
- debug: Enable debug mode.
Other parameters that can be adjusted are shown in hparams_config.py.
For instance, you can set var_freeze_expr='efficientnet', which freezes the variables with the name prefix efficientnet (the default is '(efficientnet|fpn_cells|resample_p6)'). Frozen variables keep their values unchanged throughout training.
spec = model_spec.get('efficientdet_lite0')
spec.config.var_freeze_expr = 'efficientnet'
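Because the codebase uses re.match, var_freeze_expr is matched against the start of each variable name. The sketch below shows that matching with hypothetical variable names; real names depend on the model and TensorFlow version.

```python
import re

# Hypothetical variable names, for illustration only.
VARIABLE_NAMES = [
    'efficientnet-lite0/stem/conv2d/kernel:0',
    'fpn_cells/cell_0/fnode0/op_after_combine5/conv/depthwise_kernel:0',
    'resample_p6/conv2d/kernel:0',
    'class_net/class-predict/bias:0',
    'box_net/box-predict/bias:0',
]


def frozen(names, var_freeze_expr):
    """Return the names that re.match(var_freeze_expr, name) would freeze."""
    return [n for n in names if re.match(var_freeze_expr, n)]


print(frozen(VARIABLE_NAMES, 'efficientnet'))
print(frozen(VARIABLE_NAMES, '(efficientnet|fpn_cells|resample_p6)'))
```

A pattern like 'net' freezes nothing here, because re.match only matches at the beginning of the name.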
Change the Model Architecture
You can change the model architecture by changing the model_spec. For instance, change the model_spec to the EfficientDet-Lite4 model.
spec = model_spec.get('efficientdet_lite4')
Tune the training hyperparameters
The create function is the driver function that the Model Maker library uses to create models. The model_spec parameter defines the model specification. The object_detector.EfficientDetSpec class is currently supported. The create function comprises the following steps:
- Creates the model for object detection according to model_spec.
- Trains the model. The default epochs and the default batch size are set by the epochs and batch_size variables in the model_spec object. You can also tune training hyperparameters like epochs and batch_size that affect the model accuracy. For instance:
  - epochs: Integer, 50 by default. More epochs could achieve better accuracy, but may lead to overfitting.
  - batch_size: Integer, 64 by default. The number of samples to use in one training step.
  - train_whole_model: Boolean, False by default. If true, train the whole model. Otherwise, only train the layers that do not match var_freeze_expr.
For example, you can train with fewer epochs and only the head layer. You can increase the number of epochs for better results.
model = object_detector.create(train_data, model_spec=spec, epochs=10, validation_data=validation_data)
Export to different formats
The export formats can be one or a list of the following:
- ExportFormat.TFLITE
- ExportFormat.LABEL
- ExportFormat.SAVED_MODEL
By default, it exports only the TensorFlow Lite model file, which contains the model metadata so that you can later use it in an on-device ML application. The label file is embedded in the metadata.
In many on-device ML applications, the model size is an important factor. Therefore, it is recommended that you quantize the model to make it smaller and potentially faster. For EfficientDet-Lite models, full integer quantization is used by default. Please refer to Post-training quantization for more details.
model.export(export_dir='.')
You can also choose to export other files related to the model for better examination. For instance, exporting both the saved model and the label file as follows:
model.export(export_dir='.', export_format=[ExportFormat.SAVED_MODEL, ExportFormat.LABEL])
Customize Post-training quantization on the TensorFlow Lite model
Post-training quantization is a conversion technique that can reduce model size and improve inference latency on CPUs and hardware accelerators, with a little degradation in model accuracy. Thus, it's widely used to optimize models.
The Model Maker library applies a default post-training quantization technique when exporting the model. If you want to customize post-training quantization, Model Maker also supports multiple post-training quantization options through QuantizationConfig. Take float16 quantization as an example. First, define the quantization config.
config = QuantizationConfig.for_float16()
Then export the TensorFlow Lite model with this configuration.
model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)
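float16 quantization stores weights as 16-bit floats. Compared with float32, this halves their size and keeps about three significant decimal digits. The NumPy sketch below shows that cast on normally distributed weights chosen for illustration; it is not the TFLite converter itself.

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.normal(scale=0.1, size=10_000).astype(np.float32)
weights_fp16 = weights.astype(np.float16)  # round to nearest float16

print(weights.nbytes / weights_fp16.nbytes)  # 2.0: two bytes per value instead of four

# Relative rounding error for values in float16's normal range (|x| >= 2**-14);
# subnormals are excluded because their relative error is larger.
normal = np.abs(weights) >= 2.0 ** -14
rel_err = (np.abs(weights_fp16[normal].astype(np.float32) - weights[normal])
           / np.abs(weights[normal]))
print(rel_err.max())  # bounded by 2**-11 for round-to-nearest
```

Compared with the default full integer quantization, float16 gives a smaller size reduction but usually a smaller accuracy loss.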
Read more
You can read our object detection example to learn technical details. For more information, please refer to:
- TensorFlow Lite Model Maker guide and API reference.
- Task Library: ObjectDetector for deployment.
- The end-to-end reference apps: Android, iOS, and Raspberry Pi.