Image classification model customization guide



The MediaPipe Model Maker package is a low-code solution for customizing on-device machine learning (ML) models.

The MediaPipe image classification solution provides several models you can use immediately for machine learning (ML) in your application. However, if you need to classify images with content not covered by the provided models, you can customize any of the provided models with your own data and MediaPipe Model Maker. This model modification tool rebuilds a portion of the model using data you provide. This method is faster than training a new model and can produce a model that is more useful for your specific application.

The following sections show you how to use Model Maker to retrain a pre-built model for image classification with your own data, which you can then use with the MediaPipe Image Classifier. The example retrains a general purpose classification model to classify images of flowers.

This notebook shows the end-to-end process of customizing an ImageNet pretrained image classification model for recognizing flowers defined in a user customized flower dataset.


Setup

This section describes key steps for setting up your development environment to retrain a model. These instructions describe how to update a model using Google Colab, but you can also use Python in your own development environment. For general information on setting up your development environment for using MediaPipe, including platform version requirements, see the Setup guide for Python.

To install the libraries for customizing a model, run the following commands:

python --version
pip install --upgrade pip
pip install mediapipe-model-maker

Use the following code to import the required Python classes:

from google.colab import files
import os
import tensorflow as tf
assert tf.__version__.startswith('2')

from mediapipe_model_maker import image_classifier

import matplotlib.pyplot as plt

Prepare data

Retraining a model for image classification requires a dataset that includes all kinds of items, or classes, that you want the completed model to be able to identify. You can do this by trimming down a public dataset to only the classes that are relevant to your use case, compiling your own data, or some combination of both. The dataset can be significantly smaller than what would be required to train a new model. For example, the ImageNet dataset used to train many reference models contains millions of images with thousands of categories. Transfer learning with Model Maker can retrain an existing model with a smaller dataset and still perform well, depending on your inference accuracy goals. These instructions use a smaller dataset containing 5 types of flowers, or 5 classes, with 600 to 800 images per class.

To download the example dataset, use the following code:

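image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')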

This code stores the downloaded images at the directory location saved in the image_path variable. That directory contains several subdirectories, each corresponding to specific class labels. Your training data should also follow this pattern: <image_path>/<label_name>/<image_names>.*.
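The directory pattern above can be checked before training. The following is a minimal sketch, not part of Model Maker: the helper name count_images_per_label, the list of image suffixes, and the demo labels are all assumptions made for illustration. It builds a small throwaway tree so that it runs without the real dataset.

```python
import os
import tempfile

# Assumed set of image suffixes; adjust to match your own data.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def count_images_per_label(root):
    """Return {label_name: image_count} for each subdirectory of root."""
    counts = {}
    for entry in sorted(os.listdir(root)):
        label_dir = os.path.join(root, entry)
        if not os.path.isdir(label_dir):
            continue  # stray top-level files are not class labels
        counts[entry] = sum(
            1 for name in os.listdir(label_dir)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts


# Hypothetical demo tree: <root>/<label_name>/<image_names>.jpg
demo_root = tempfile.mkdtemp()
for label, n in (('daisy', 3), ('roses', 2)):
    os.makedirs(os.path.join(demo_root, label))
    for i in range(n):
        open(os.path.join(demo_root, label, f'img_{i}.jpg'), 'w').close()
open(os.path.join(demo_root, 'LICENSE.txt'), 'w').close()

demo_counts = count_images_per_label(demo_root)
print(demo_counts)  # {'daisy': 3, 'roses': 2}
```

Calling count_images_per_label(image_path) on your own data gives a quick view of class balance and reveals empty or misnamed label directories.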

Review data

When preparing data for training with Model Maker, you should review the training data to make sure it is in the proper format, correctly classified, and organized in directories corresponding to classification labels. This step is optional, but recommended.

The following code block retrieves all the label names from the expected directory structure at image_path and prints them.

labels = []
for i in os.listdir(image_path):
  if os.path.isdir(os.path.join(image_path, i)):
    labels.append(i)  # each subdirectory name is a class label
print(labels)

You can review a few of the example images from each category using the following code:


NUM_EXAMPLES = 5  # number of sample images to show per label

for label in labels:
  label_dir = os.path.join(image_path, label)
  example_filenames = os.listdir(label_dir)[:NUM_EXAMPLES]
  fig, axs = plt.subplots(1, NUM_EXAMPLES, figsize=(10,2))
  for i in range(NUM_EXAMPLES):
    axs[i].imshow(plt.imread(os.path.join(label_dir, example_filenames[i])))
  fig.suptitle(f'Showing {NUM_EXAMPLES} examples for {label}')

plt.show()

Create dataset

Training data for machine learning can be large, consisting of hundreds or thousands of files which typically do not fit into available memory. You must also split it into groups for different uses: training, testing, and validation. For these reasons, Model Maker uses a Dataset class to organize training data and feed it to the retraining process.

To create a dataset, use the Dataset.from_folder method to load the data located at image_path and split it into training, testing, and validation groups:

data = image_classifier.Dataset.from_folder(image_path)
train_data, remaining_data = data.split(0.8)
test_data, validation_data = remaining_data.split(0.5)

In this example, 80% of the data is used for training, with the remaining data split in half, so that 10% of the total is used for testing, and 10% for validation.
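The arithmetic behind the two successive splits can be sketched with plain Python lists. This is an illustration only: the split helper below is hypothetical and is not the Model Maker Dataset.split implementation, which may order or round its data differently. The 3670 figure is the flowers dataset size quoted later in this guide.

```python
def split(items, fraction):
    """Split a list in two; the first part holds `fraction` of the items."""
    cut = int(len(items) * fraction)
    return items[:cut], items[cut:]


samples = list(range(3670))  # stand-in for 3670 image records

train, remaining = split(samples, 0.8)    # 80% for training
test, validation = split(remaining, 0.5)  # remaining 20% halved

print(len(train), len(test), len(validation))  # 2936 367 367
```

Because the second split applies to the remaining 20%, a fraction of 0.5 there corresponds to 10% of the full dataset for each of the testing and validation groups.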

Retrain model

Once you have prepared your data, you can begin retraining a model to build a new classification layer that can recognize the item types, or classes, defined by your training data. This type of model modification is called transfer learning. The instructions below use the data prepared in the previous section to retrain an image classification model to recognize different types of flowers.

Set retraining options

Besides your training dataset, a retraining run requires two settings: an output directory for the model and the model architecture. Use the export_dir parameter of the HParams object to specify the model output directory. Use the SupportedModels class to specify the model architecture. The image classifier solution supports the following model architectures:

  • MobileNet-V2
  • EfficientNet-Lite0
  • EfficientNet-Lite2
  • EfficientNet-Lite4

To set the required parameters, use the following code:

spec = image_classifier.SupportedModels.MOBILENET_V2
hparams = image_classifier.HParams(export_dir="exported_model")
options = image_classifier.ImageClassifierOptions(supported_model=spec, hparams=hparams)

This example code uses the MobileNetV2 model architecture, which you can learn more about from the MobileNetV2 research paper. The retraining process has many additional options; however, most of them are set for you automatically. You can learn about these optional parameters in the Retraining parameters section.

Run retraining

With your training dataset and retraining options prepared, you are ready to start the retraining process. This process is resource intensive and can take a few minutes to a few hours depending on your available compute resources. Using a Google Colab environment with standard CPU processing, the example retraining below takes about 20 minutes to train on approximately 4000 images. You can typically decrease your training time by using GPU processors.

To begin the retraining process, use the create() method with the dataset and options you previously defined:

model = image_classifier.ImageClassifier.create(
    train_data = train_data,
    validation_data = validation_data,
    options=options,
)

Evaluate performance

After retraining the model, you should evaluate it on a test dataset, which is typically a portion of your original dataset not used during training. Accuracy levels between 0.8 and 0.9 are generally considered very good, but your use case requirements may differ. You should also consider how fast the model can produce an inference. Higher accuracy frequently comes at the cost of longer inference times.

To run an evaluation of the example model, run it against the test portion of the dataset:

loss, acc = model.evaluate(test_data)
print(f'Test loss:{loss}, Test accuracy:{acc}')
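To make the accuracy figure concrete, here is a minimal sketch of a top-1 accuracy calculation: the fraction of test samples whose predicted label matches the true label. The helper top1_accuracy and the label lists are hypothetical, and the sketch does not claim to reproduce how model.evaluate() computes its metrics internally.

```python
def top1_accuracy(predicted, actual):
    """Fraction of samples whose predicted label equals the true label."""
    if len(predicted) != len(actual):
        raise ValueError('predicted and actual must have the same length')
    if not actual:
        raise ValueError('need at least one sample')
    correct = sum(p == a for p, a in zip(predicted, actual))
    return correct / len(actual)


# Hypothetical labels for five test images.
actual_labels = ['daisy', 'roses', 'tulips', 'daisy', 'roses']
predicted_labels = ['daisy', 'roses', 'daisy', 'daisy', 'roses']

demo_accuracy = top1_accuracy(predicted_labels, actual_labels)
print(demo_accuracy)  # 0.8
```

In this toy case four of five predictions are correct, which sits at the lower end of the 0.8 to 0.9 range mentioned above.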

Export model

After retraining a model, you must export it to TensorFlow Lite model format to use it with MediaPipe in your application. The export process generates required model metadata, as well as a classification label file.

To export the retrained model for use in your application, use the following command:

model.export_model()


Use the following commands with Google Colab to list the exported model files and download the model to your development environment:

!ls exported_model
files.download('exported_model/model.tflite')

Model tuning

You can use the MediaPipe Model Maker tool to further improve and adjust the model retraining with configuration options and performance techniques such as model quantization. These steps are optional. Model Maker uses reasonable default settings for all of the training configuration parameters, but if you want to further tune the model retraining, the instructions below describe the available options.

Retraining parameters

You can further customize how the retraining process runs to adjust training time and potentially increase the retrained model's performance. These parameters are optional. Use the ImageClassifierModelOptions class and the HParams class to set these additional options.

Use the ImageClassifierModelOptions class parameters to customize the existing model. It has the following customizable parameter that affects model accuracy:

  • dropout_rate: The fraction of the input units to drop. Used in dropout layer. Defaults to 0.05.

Use the HParams class to customize other parameters related to training and saving the model:

  • learning_rate: The learning rate to use for gradient descent training. Defaults to 0.001.
  • batch_size: Batch size for training. Defaults to 2.
  • epochs: Number of training iterations over the dataset. Defaults to 10.
  • steps_per_epoch: An optional integer that indicates the number of training steps per epoch. If not set, the training pipeline calculates the default steps per epoch as the training dataset size divided by batch size.
  • shuffle: True if the dataset is shuffled before training. Defaults to False.
  • do_fine_tuning: If true, the base module is trained together with the classification layer on top. This defaults to False, which means only the classification layer is trained and pre-trained weights for the base module are frozen.
  • l1_regularizer: A regularizer that applies a L1 regularization penalty. Defaults to 0.0.
  • l2_regularizer: A regularizer that applies a L2 regularization penalty. Defaults to 0.0001.
  • label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for more details. Defaults to 0.1.
  • do_data_augmentation: Whether or not the training dataset is augmented by applying random transformations such as cropping, flipping, etc. See utils.image_preprocessing for details. Defaults to True.
  • decay_samples: Number of training samples used to calculate the decay steps and create the training optimizer. Defaults to 2,560,000.
  • warmup_epochs: Number of warmup steps for a linear increasing warmup schedule on the learning rate. Used to set up warmup schedule by model_util.WarmUp. Defaults to 2.
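Two of the parameters above involve simple arithmetic that can be sketched in plain Python. This is an illustration under stated assumptions, not Model Maker code: the helper names are hypothetical, floor division for steps per epoch is an assumption (a pipeline could round up instead), the 2936 sample count assumes an 80% split of 3670 images, and smooth_labels uses the common label smoothing formula that blends a one-hot target with a uniform distribution.

```python
def steps_per_epoch(num_samples, batch_size):
    """Training dataset size divided by batch size (floor is an assumption)."""
    return num_samples // batch_size


def smooth_labels(one_hot, smoothing):
    """Blend a one-hot target with a uniform distribution over the classes."""
    num_classes = len(one_hot)
    return [y * (1.0 - smoothing) + smoothing / num_classes for y in one_hot]


# Defaults listed above: batch_size=2, epochs=10, label_smoothing=0.1.
steps = steps_per_epoch(num_samples=2936, batch_size=2)
total_steps = steps * 10
smoothed = smooth_labels([0, 0, 1, 0, 0], smoothing=0.1)

print(steps, total_steps)              # 1468 14680
print([round(v, 2) for v in smoothed])  # [0.02, 0.02, 0.92, 0.02, 0.02]
```

With label smoothing, the true class target drops slightly below 1 and every other class receives a small positive target, which discourages overconfident predictions.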

The following example code trains a model with more epochs and a higher dropout rate:

hparams = image_classifier.HParams(epochs=15, export_dir="exported_model_2")
options = image_classifier.ImageClassifierOptions(supported_model=spec, hparams=hparams)
options.model_options = image_classifier.ModelOptions(dropout_rate = 0.07)
model_2 = image_classifier.ImageClassifier.create(
    train_data = train_data,
    validation_data = validation_data,
    options=options,
)

To evaluate the newly trained model above, use the following code:

loss, accuracy = model_2.evaluate(test_data)

For more information on the general performance of the supported models, refer to the Performance benchmarks section.

Model quantization

Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 8-bit integers. This technique is widely used to further optimize models after the training process.
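The idea of mapping 32-bit floats to 8-bit integers can be sketched with a simple affine scheme, where each real value is approximated as (q - zero_point) * scale. This is a conceptual sketch in pure Python, not the code that Model Maker or TensorFlow Lite runs; the function names and the sample weights are hypothetical.

```python
def quantize_int8(values):
    """Affine-quantize floats to the int8 range [-128, 127]."""
    lo = min(min(values), 0.0)  # the representable range must include 0
    hi = max(max(values), 0.0)
    scale = (hi - lo) / 255.0
    if scale == 0.0:
        scale = 1.0  # all values are zero; any scale works
    zero_point = round(-128 - lo / scale)
    quantized = [
        max(-128, min(127, round(v / scale) + zero_point)) for v in values
    ]
    return quantized, scale, zero_point


def dequantize(quantized, scale, zero_point):
    """Map int8 values back to approximate floats."""
    return [(q - zero_point) * scale for q in quantized]


weights = [-1.0, -0.25, 0.0, 0.5, 1.0]  # hypothetical float32 weights
q_weights, scale, zero_point = quantize_int8(weights)
restored = dequantize(q_weights, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(weights, restored))

print(q_weights)
print(max_error <= scale)  # True: the round-trip error stays within one step
```

Each value now needs one byte instead of four, at the cost of a rounding error bounded by the step size. That trade-off is the source of the smaller model size and the minor accuracy loss described above.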

This section of the guide explains how to apply quantization to your retrained model. This optimization must be done as part of the Model Maker model export process, and cannot be performed on an exported model. The following example demonstrates how to use this approach to apply int8 quantization to a retrained model. For more information on post-training quantization, see the TensorFlow Lite documentation.

Import the MediaPipe Model Maker quantization module:

from mediapipe_model_maker import quantization

Define a QuantizationConfig object using the for_int8() class method. This configuration modifies a trained model to use 8-bit integers instead of larger data types, such as 32-bit floating point numbers. You can further customize the quantization process by setting additional parameters for the QuantizationConfig class.

quantization_config = quantization.QuantizationConfig.for_int8(train_data)

The for_int8() method requires a representative dataset, while the dynamic() and for_float16() quantization methods do not. The quantization process uses the representative dataset to perform model modification and you typically use your existing training dataset for this purpose. For more information on the process and options, see the TensorFlow Lite Post-training quantization guide.

Export the model using the additional quantization_config object to apply post-training quantization:

model.export_model(model_name="model_int8.tflite", quantization_config=quantization_config)

After running this command, you should have a new model_int8.tflite model file. This new, quantized model should be significantly smaller than the model.tflite file. You can compare the sizes using the following command:

ls -lh exported_model
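If you prefer to compare sizes from Python rather than the shell, a sketch like the following works with the standard library alone. The helper model_sizes is hypothetical, and the demo files are zero-filled placeholders with made-up sizes, so the printed ratio says nothing about real quantization results.

```python
import os
import tempfile


def model_sizes(directory):
    """Return {file_name: size_in_bytes} for each .tflite file in directory."""
    return {
        name: os.path.getsize(os.path.join(directory, name))
        for name in sorted(os.listdir(directory))
        if name.endswith('.tflite')
    }


# Placeholder files standing in for real exports (sizes are invented).
demo_dir = tempfile.mkdtemp()
for name, size in (('model.tflite', 4000), ('model_int8.tflite', 1000)):
    with open(os.path.join(demo_dir, name), 'wb') as f:
        f.write(b'\0' * size)

sizes = model_sizes(demo_dir)
ratio = sizes['model_int8.tflite'] / sizes['model.tflite']
print(sizes)
print(ratio)  # 0.25 for these placeholder files
```

Pointing model_sizes at your exported_model directory reports the actual sizes of model.tflite and model_int8.tflite.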

Performance benchmarks

Below is a summary of our benchmarking results for the supported model architectures. These models were trained and evaluated on the same flowers dataset as this notebook. When considering the model benchmarking results, there are a few important caveats to keep in mind:

  • The test accuracy column reflects models which were trained with the default parameters. To optimize model performance, experiment with different model and retraining parameters in order to obtain the highest test accuracy. Refer to the Retraining parameters section for more information on customizing these settings.
  • The larger model architectures, such as EfficientNet_Lite4, may not achieve the highest test accuracy on simpler datasets like the flowers dataset used in this notebook. Research suggests that these larger model architectures can outperform the others on more complex datasets like ImageNet. For more information, see the EfficientNet paper. The ImageNet dataset is more complex, with over a million training images and 1000 classes, while the flowers dataset has only 3670 training images and 5 classes.

Model architecture   Test accuracy   Model size   CPU 1-thread latency (Pixel 6)   GPU latency (Pixel 6)   EdgeTPU latency (Pixel 6)
MobileNet_V2         85.4%           8.9MB        29.12                            77.77                   31.14
EfficientNet_Lite0   91.3%           13.5MB       15.6                             9.25                    16.72
EfficientNet_Lite2   91.5%           19.2MB       35.2                             13.94                   37.52
EfficientNet_Lite4   90.8%           46.8MB       103.16                           23.14                   114.67
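One way to use these results is to filter the architectures by your own size and latency limits and then pick the most accurate of what remains. The sketch below copies the accuracy, size, and CPU latency columns from the table; the helper best_under_budget and the two budget values are hypothetical and chosen only for illustration.

```python
# (architecture, test accuracy %, model size MB, CPU 1-thread latency on Pixel 6)
BENCHMARKS = [
    ('MobileNet_V2', 85.4, 8.9, 29.12),
    ('EfficientNet_Lite0', 91.3, 13.5, 15.6),
    ('EfficientNet_Lite2', 91.5, 19.2, 35.2),
    ('EfficientNet_Lite4', 90.8, 46.8, 103.16),
]


def best_under_budget(rows, max_size_mb, max_cpu_latency):
    """Return the most accurate row within both limits, or None."""
    candidates = [
        row for row in rows
        if row[2] <= max_size_mb and row[3] <= max_cpu_latency
    ]
    return max(candidates, key=lambda row: row[1]) if candidates else None


# Hypothetical limits: at most 15 MB and a CPU latency of at most 20.
choice = best_under_budget(BENCHMARKS, max_size_mb=15.0, max_cpu_latency=20.0)
print(choice[0])  # EfficientNet_Lite0
```

Since these accuracies come from default training parameters on the flowers dataset, treat the ranking as a starting point and re-measure on your own data.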