License information
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
The MediaPipe object detection solution provides several models you can use immediately for machine learning (ML) in your application. However, if you need to detect objects not covered by the provided models, you can customize any of the provided models with your own data and MediaPipe Model Maker. This model modification tool rebuilds the model using data you provide. This method is faster than training a new model and can produce a model that is more useful for your specific application.
The following sections show you how to use Model Maker to retrain a pre-built model for object detection with your own data, which you can then use with the MediaPipe Object Detector. The example retrains a general purpose object detection model to detect android figurines in images.
Setup
This section describes key steps for setting up your development environment to retrain a model. These instructions describe how to update a model using Google Colab, and you can also use Python in your own development environment. For general information on setting up your development environment for using MediaPipe, including platform version requirements, see the Setup guide for Python.
Attention: This MediaPipe Solutions Preview is an early release. Learn more.
To install the libraries for customizing a model, run the following commands:
python --version
pip install --upgrade pip
pip install mediapipe-model-maker
Use the following code to import the required Python classes:
from google.colab import files
import os
import json
import tensorflow as tf
assert tf.__version__.startswith('2')
from mediapipe_model_maker import object_detector
Prepare data
Retraining a model for object detection requires a dataset that includes the items, or classes, that you want the completed model to be able to identify. You can do this by trimming down a public dataset to only the classes that are relevant to your use case, compiling your own dataset, or some combination of both. The dataset can be significantly smaller than what would be required to train a new model. For example, the COCO dataset used to train many reference models contains hundreds of thousands of images with 91 classes of objects. Transfer learning with Model Maker can retrain an existing model with a smaller dataset and still perform well, depending on your inference accuracy goals. These instructions use a smaller dataset containing 2 types of android figurines, or 2 classes, with 62 total training images.
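If you are starting from a public COCO-format dataset, trimming it down to just the classes you need can be done with a short script like the following sketch. The file paths and class names here are illustrative only.
import json

# Hypothetical paths and class names; adjust for your own dataset.
KEEP_CLASSES = {"android", "pig_android"}

with open("labels.json", "r") as f:
  coco = json.load(f)

kept_categories = [c for c in coco["categories"] if c["name"] in KEEP_CLASSES]
kept_category_ids = {c["id"] for c in kept_categories}

# Keep only the annotations for the remaining categories, then only the images
# that still have at least one annotation.
kept_annotations = [a for a in coco["annotations"] if a["category_id"] in kept_category_ids]
kept_image_ids = {a["image_id"] for a in kept_annotations}
kept_images = [img for img in coco["images"] if img["id"] in kept_image_ids]

with open("labels_trimmed.json", "w") as f:
  json.dump({"categories": kept_categories,
             "images": kept_images,
             "annotations": kept_annotations}, f)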
To download the example dataset, use the following code:
!wget https://storage.googleapis.com/mediapipe-tasks/object_detector/android_figurine.zip
!unzip android_figurine.zip
train_dataset_path = "android_figurine/train"
validation_dataset_path = "android_figurine/validation"
This code stores the dataset at the directory location android_figurine. The directory contains two subdirectories for the training and validation datasets, located in android_figurine/train and android_figurine/validation respectively. Each of the train and validation datasets follows the COCO dataset format described below.
Supported dataset formats
Model Maker Object Detection API supports reading the following dataset formats:
COCO format
The COCO dataset format has a data directory which stores all of the images and a single labels.json file which contains the object annotations for all images.
<dataset_dir>/
data/
<img0>.<jpg/jpeg>
<img1>.<jpg/jpeg>
...
labels.json
where labels.json is formatted as:
{
"categories":[
{"id":1, "name":<cat1_name>},
...
],
"images":[
{"id":0, "file_name":"<img0>.<jpg/jpeg>"},
...
],
"annotations":[
{"id":0, "image_id":0, "category_id":1, "bbox":[x-top left, y-top left, width, height]},
...
]
}
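As a quick sanity check, a short sketch like the following can confirm that a labels.json file has the three top-level keys this format requires; the dataset path is hypothetical.
import json
import os

# Hypothetical dataset location; point this at your own <dataset_dir>.
dataset_dir = "my_coco_dataset"

with open(os.path.join(dataset_dir, "labels.json"), "r") as f:
  labels = json.load(f)

for key in ("categories", "images", "annotations"):
  assert key in labels, f"labels.json is missing the '{key}' key"

print(f"{len(labels['images'])} images, "
      f"{len(labels['categories'])} categories, "
      f"{len(labels['annotations'])} annotations")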
PASCAL VOC format
The PASCAL VOC dataset format also has a data directory which stores all of the images; however, the annotations are split up per image into corresponding xml files in the Annotations directory.
<dataset_dir>/
data/
<file0>.<jpg/jpeg>
...
Annotations/
<file0>.xml
...
where the xml files are formatted as:
<annotation>
<filename>file0.jpg</filename>
<object>
<name>kangaroo</name>
<bndbox>
<xmin>233</xmin>
<ymin>89</ymin>
<xmax>386</xmax>
<ymax>262</ymax>
</bndbox>
</object>
<object>
...
</object>
...
</annotation>
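To see how those fields map to code, the following sketch reads a single annotation file with the Python standard library; the file path is hypothetical.
import xml.etree.ElementTree as ET

# Hypothetical path to one annotation file in your dataset.
root = ET.parse("Annotations/file0.xml").getroot()

print("image file:", root.findtext("filename"))
for obj in root.findall("object"):
  name = obj.findtext("name")
  box = obj.find("bndbox")
  xmin, ymin = int(box.findtext("xmin")), int(box.findtext("ymin"))
  xmax, ymax = int(box.findtext("xmax")), int(box.findtext("ymax"))
  print(f"  {name}: ({xmin}, {ymin}) to ({xmax}, {ymax})")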
Review dataset
Verify the dataset content by printing the categories from the labels.json file. There should be 3 total categories. Index 0 is always set to be the background class, which may be unused in the dataset. There should be two non-background categories of android and pig_android.
with open(os.path.join(train_dataset_path, "labels.json"), "r") as f:
  labels_json = json.load(f)
for category_item in labels_json["categories"]:
  print(f"{category_item['id']}: {category_item['name']}")
To better understand the dataset, plot a couple of example images along with their bounding boxes.
Visualize the training dataset
import matplotlib.pyplot as plt
from matplotlib import patches, text, patheffects
from collections import defaultdict
import math
def draw_outline(obj):
  obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])

def draw_box(ax, bb):
  patch = ax.add_patch(patches.Rectangle((bb[0], bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))
  draw_outline(patch)

def draw_text(ax, bb, txt, disp):
  text = ax.text(bb[0], (bb[1]-disp), txt, verticalalignment='top',
                 color='white', fontsize=10, weight='bold')
  draw_outline(text)

def draw_bbox(ax, annotations_list, id_to_label, image_shape):
  for annotation in annotations_list:
    cat_id = annotation["category_id"]
    bbox = annotation["bbox"]
    draw_box(ax, bbox)
    draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)

def visualize(dataset_folder, max_examples=None):
  with open(os.path.join(dataset_folder, "labels.json"), "r") as f:
    labels_json = json.load(f)
  images = labels_json["images"]
  cat_id_to_label = {item["id"]: item["name"] for item in labels_json["categories"]}
  image_annots = defaultdict(list)
  for annotation_obj in labels_json["annotations"]:
    image_id = annotation_obj["image_id"]
    image_annots[image_id].append(annotation_obj)

  if max_examples is None:
    max_examples = len(image_annots.items())
  n_rows = math.ceil(max_examples / 3)
  fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows*8))  # 3 columns (2nd index), 8x8 for each image
  for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):
    ax = axs[ind//3, ind%3]
    img = plt.imread(os.path.join(dataset_folder, "images", images[image_id]["file_name"]))
    ax.imshow(img)
    draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)
  plt.show()

visualize(train_dataset_path, 9)
Create dataset
The Dataset class has two methods for loading in COCO or PASCAL VOC datasets:
- Dataset.from_coco_folder
- Dataset.from_pascal_voc_folder

Since the android_figurines dataset is in the COCO dataset format, use the from_coco_folder method to load the dataset located at train_dataset_path and validation_dataset_path. When loading the dataset, the data will be parsed from the provided path and converted into a standardized TFRecord format which is cached for later use. You should create a cache_dir location and reuse it for all your training to avoid saving multiple caches of the same dataset.
train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir="/tmp/od_data/train")
validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir="/tmp/od_data/validation")
print("train_data size: ", train_data.size)
print("validation_data size: ", validation_data.size)
Retrain model
Once you have completed preparing your data, you can begin retraining a model to recognize the new objects, or classes, defined by your training data. The instructions below use the data prepared in the previous section to retrain an object detection model to recognize the two types of android figurines.
Set retraining options
There are a few required settings to run retraining aside from your training dataset: the output directory for the model and the model architecture. Use HParams to specify the export_dir parameter for the output directory. Use the SupportedModels class to specify the model architecture. The object detector solution supports the following model architectures:
- MobileNet-V2
- MobileNet-MultiHW-AVG
For more advanced customization of training parameters, see the Hyperparameters section below.
To set the required parameters, use the following code:
spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG
hparams = object_detector.HParams(export_dir='exported_model')
options = object_detector.ObjectDetectorOptions(
supported_model=spec,
hparams=hparams
)
Run retraining
With your training dataset and retraining options prepared, you are ready to start the retraining process. This process is resource intensive and can take a few minutes to a few hours depending on your available compute resources. Using a Google Colab environment with standard GPU runtimes, the example retraining below takes about 2 to 4 minutes.
To begin the retraining process, use the create() method with the dataset and options you previously defined:
model = object_detector.ObjectDetector.create(
train_data=train_data,
validation_data=validation_data,
options=options)
Evaluate the model performance
After training the model, evaluate it on the validation dataset and print the loss and coco_metrics. The most important metric for evaluating the model performance is typically the "AP" coco metric for Average Precision.
loss, coco_metrics = model.evaluate(validation_data, batch_size=4)
print(f"Validation loss: {loss}")
print(f"Validation coco metrics: {coco_metrics}")
Export model
After creating the model, convert and export it to the TensorFlow Lite model format for later use in an on-device application. The export also includes model metadata, which contains the label map.
model.export_model()
!ls exported_model
files.download('exported_model/model.tflite')
Model quantization
Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.
This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for object detection:
- Quantization Aware Training: 8 bit integer precision for CPU usage
- Post-Training Quantization: 16 bit floating point precision for GPU usage
Quantization aware training (int8 quantization)
Quantization aware training (QAT) is a fine-tuning step which happens after fully training your model. This technique further tunes the model while emulating inference-time quantization in order to account for the lower precision of 8-bit integer quantization. For on-device applications with a standard CPU, use int8 precision. For more information, see the TensorFlow Lite documentation.
To apply quantization aware training and export to an int8 model, create a QATHParams configuration and run the quantization_aware_training method. See the Hyperparameters section below for detailed usage of QATHParams.
qat_hparams = object_detector.QATHParams(learning_rate=0.3, batch_size=4, epochs=10, decay_steps=6, decay_rate=0.96)
model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)
qat_loss, qat_coco_metrics = model.evaluate(validation_data)
print(f"QAT validation loss: {qat_loss}")
print(f"QAT validation coco metrics: {qat_coco_metrics}")
The QAT step often requires multiple runs to tune the parameters of training. To avoid having to rerun model training using the create method, use the restore_float_ckpt method to restore the model state back to the fully trained float model (the state after running the create method) in order to run QAT again.
new_qat_hparams = object_detector.QATHParams(learning_rate=0.9, batch_size=4, epochs=15, decay_steps=5, decay_rate=0.96)
model.restore_float_ckpt()
model.quantization_aware_training(train_data, validation_data, qat_hparams=new_qat_hparams)
qat_loss, qat_coco_metrics = model.evaluate(validation_data)
print(f"QAT validation loss: {qat_loss}")
print(f"QAT validation coco metrics: {qat_coco_metrics}")
Finally, use the export_model method to export to an int8 quantized model. The export_model function will automatically export to either a float32 or int8 model depending on whether quantization_aware_training was run.
model.export_model('model_int8_qat.tflite')
!ls -lh exported_model
files.download('exported_model/model_int8_qat.tflite')
Post-training quantization (fp16 quantization)
Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is recommended for GPU usage. For more information, see the TensorFlow Lite documentation.
First, import the MediaPipe Model Maker quantization module:
from mediapipe_model_maker import quantization
Define a QuantizationConfig object using the for_float16() class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. You can further customize the quantization process by setting additional parameters for the QuantizationConfig class.
quantization_config = quantization.QuantizationConfig.for_float16()
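For reference, the parameter name below is an assumption based on the TensorFlow Lite Model Maker-style quantization config that MediaPipe Model Maker builds on; check your installed version for the exact constructor arguments. Constructing the config directly with supported_types is roughly what for_float16() does.
# Assumed constructor parameter; verify against your mediapipe_model_maker version.
manual_config = quantization.QuantizationConfig(supported_types=[tf.float16])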
Export the model using the additional quantization_config object to apply post-training quantization. Note that if you previously ran quantization_aware_training, you must first convert the model back to a float model by using restore_float_ckpt.
model.restore_float_ckpt()
model.export_model(model_name="model_fp16.tflite", quantization_config=quantization_config)
!ls -lh exported_model
files.download('exported_model/model_fp16.tflite')
Hyperparameters
You can further customize the model using the ObjectDetectorOptions class, which has three parameters for SupportedModels, ModelOptions, and HParams.
Use the SupportedModels enum class to specify the model architecture to use for training. The following model architectures are supported:
- MOBILENET_V2
- MOBILENET_V2_I320
- MOBILENET_MULTI_AVG
- MOBILENET_MULTI_AVG_I384
Use the HParams class to customize other parameters related to training and saving the model:
- learning_rate: Learning rate to use for gradient descent training. Defaults to 0.3.
- batch_size: Batch size for training. Defaults to 8.
- epochs: Number of training iterations over the dataset. Defaults to 30.
- cosine_decay_epochs: The number of epochs for cosine decay learning rate. See tf.keras.optimizers.schedules.CosineDecay for more info. Defaults to None, which is equivalent to setting it to epochs.
- cosine_decay_alpha: The alpha value for cosine decay learning rate. See tf.keras.optimizers.schedules.CosineDecay for more info. Defaults to 1.0, which means no cosine decay.
Use the ModelOptions class to customize parameters related to the model itself:
- l2_weight_decay: L2 regularization penalty used in tf.keras.regularizers.L2. Defaults to 3e-5.
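As a hedged sketch of how these classes fit together (the ModelOptions class name and keyword argument names are assumed from the descriptions above, and the values are illustrative, not recommendations):
# Illustrative values only; tune them for your own dataset.
hparams = object_detector.HParams(
    learning_rate=0.3,
    batch_size=8,
    epochs=50,
    cosine_decay_epochs=50,   # decay the learning rate over all 50 epochs
    cosine_decay_alpha=0.1,   # decay down to 10% of the initial learning rate
    export_dir='exported_model')
model_options = object_detector.ModelOptions(l2_weight_decay=3e-5)
options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_MULTI_AVG,
    model_options=model_options,
    hparams=hparams)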
Use the QATHParams class to customize training parameters for Quantization Aware Training:
- learning_rate: Learning rate to use for gradient descent QAT. Defaults to 0.3.
- batch_size: Batch size for QAT. Defaults to 8.
- epochs: Number of training iterations over the dataset. Defaults to 15.
- decay_steps: Learning rate decay steps for Exponential Decay. See tf.keras.optimizers.schedules.ExponentialDecay for more information. Defaults to 8.
- decay_rate: Learning rate decay rate for Exponential Decay. See tf.keras.optimizers.schedules.ExponentialDecay for more information. Defaults to 0.96.
Benchmarking
Below is a summary of our benchmarking results for the supported model architectures. These models were trained and evaluated on the same android figurines dataset as this notebook. When considering the model benchmarking results, there are a few important caveats to keep in mind:
- The android figurines dataset is a small and simple dataset with 62 training examples and 10 validation examples. Since the dataset is quite small, metrics may vary drastically due to variances in the training process. This dataset was provided for demo purposes and it is recommended to collect more data samples for better performing models.
- The float32 models were trained with the default HParams, and the QAT step for the int8 models was run with QATHParams(learning_rate=0.1, batch_size=4, epochs=30, decay_rate=1).
- For your own dataset, you will likely need to tune values for both HParams and QATHParams in order to achieve the best results. See the Hyperparameters section above for more information on configuring training parameters.
- All latency numbers are benchmarked on the Pixel 6.
| Model architecture | Input Image Size | Test AP (float32) | Test AP (QAT int8) | CPU Latency (float32) | CPU Latency (QAT int8) | Model Size (float32) | Model Size (QAT int8) |
|---|---|---|---|---|---|---|---|
| MobileNetV2 | 256x256 | 88.4% | 73.5% | 48ms | 16ms | 11MB | 3.2MB |
| MobileNetV2 I320 | 320x320 | 89.1% | 75.5% | 75ms | 33.38ms | 10MB | 3.3MB |
| MobileNet MultiHW AVG | 256x256 | 88.5% | 70.0% | 56ms | 19ms | 13MB | 3.6MB |
| MobileNet MultiHW AVG I384 | 384x384 | 92.7% | 73.4% | 238ms | 41ms | 13MB | 3.6MB |