# Copyright 2024 The AI Edge Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
When deploying a LiteRT machine learning model to a device or mobile app, you may want to enable the model to be improved or personalized based on input from the device or end user. Using on-device training techniques allows you to update a model without data leaving your users' devices, improving user privacy, and without requiring users to update the device software.
For example, you may have a model in your mobile app that recognizes fashion items, but you want users to get improved recognition performance over time based on their interests. Enabling on-device training allows the model to get better at recognizing a particular style of shoe or shoe brand for users who are interested in shoes, the more often they use your app.
This tutorial shows you how to construct a LiteRT model that can be incrementally trained and improved within an installed Android app.
Setup
This tutorial uses Python to train and convert a TensorFlow model before incorporating it into an Android app. Get started by installing and importing the following packages.
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
print("TensorFlow version:", tf.__version__)
TensorFlow version: 2.8.0
Classify images of clothing
This example code uses the Fashion MNIST dataset to train a neural network model for classifying images of clothing. This dataset contains 70,000 small (28 x 28 pixel) grayscale images, split into 60,000 training images and 10,000 test images. The images cover 10 different categories of fashion items, including dresses, shirts, and sandals.
You can explore this dataset in more depth in the Keras classification tutorial.
Build a model for on-device training
LiteRT models typically expose only a single function (or signature) that allows you to call the model to run an inference. For a model to be trained and used on a device, you must be able to perform several separate operations, including train, infer, save, and restore functions for the model. You can enable this functionality by first extending your TensorFlow model to have multiple functions, and then exposing those functions as signatures when you convert your model to the LiteRT model format.
The code example below shows you how to add the following functions to a TensorFlow model:
- train: trains the model with training data.
- infer: runs inference on the input.
- save: saves the trainable weights to the file system.
- restore: loads the trainable weights from the file system.
IMG_SIZE = 28


class Model(tf.Module):

    def __init__(self):
        self.model = tf.keras.Sequential([
            tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE), name='flatten'),
            tf.keras.layers.Dense(128, activation='relu', name='dense_1'),
            tf.keras.layers.Dense(10, name='dense_2')
        ])

        self.model.compile(
            optimizer='sgd',
            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))

    # The `train` function takes a batch of input images and labels.
    @tf.function(input_signature=[
        tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
        tf.TensorSpec([None, 10], tf.float32),
    ])
    def train(self, x, y):
        with tf.GradientTape() as tape:
            prediction = self.model(x)
            loss = self.model.loss(y, prediction)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.model.optimizer.apply_gradients(
            zip(gradients, self.model.trainable_variables))
        result = {"loss": loss}
        return result

    @tf.function(input_signature=[
        tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
    ])
    def infer(self, x):
        logits = self.model(x)
        probabilities = tf.nn.softmax(logits, axis=-1)
        return {
            "output": probabilities,
            "logits": logits
        }

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
    def save(self, checkpoint_path):
        tensor_names = [weight.name for weight in self.model.weights]
        tensors_to_save = [weight.read_value() for weight in self.model.weights]
        tf.raw_ops.Save(
            filename=checkpoint_path, tensor_names=tensor_names,
            data=tensors_to_save, name='save')
        return {
            "checkpoint_path": checkpoint_path
        }

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
    def restore(self, checkpoint_path):
        restored_tensors = {}
        for var in self.model.weights:
            restored = tf.raw_ops.Restore(
                file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,
                name='restore')
            var.assign(restored)
            restored_tensors[var.name] = restored
        return restored_tensors
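Conceptually, the four methods are separate entry points that read and write the same set of variables. This shared state is what allows a train call to change the result of a later infer call. The pure-Python toy below is only an analogy and rests on simplifying assumptions: a one-weight linear model, plain floats instead of tensors, and a hypothetical ToyModel class with a get_signature_runner helper. It is not TensorFlow or LiteRT code.

```python
class ToyModel:
    """Hypothetical toy: named entry points sharing one weight set (y = w * x)."""

    def __init__(self, w=0.0, learning_rate=0.1):
        self.weights = {"w": w}
        self.learning_rate = learning_rate
        # Map signature names to bound methods, loosely like the
        # `signatures=` dict passed to tf.saved_model.save.
        self.signatures = {
            "train": self.train,
            "infer": self.infer,
            "save": self.save,
            "restore": self.restore,
        }

    def infer(self, x):
        return {"output": self.weights["w"] * x}

    def train(self, x, y):
        # Squared-error loss and its hand-derived gradient with respect to w.
        prediction = self.weights["w"] * x
        loss = (prediction - y) ** 2
        grad = 2 * (prediction - y) * x
        self.weights["w"] -= self.learning_rate * grad  # one SGD step
        return {"loss": loss}

    def save(self):
        # Toy simplification: return a copy of the weights instead of writing a file.
        return dict(self.weights)

    def restore(self, checkpoint):
        self.weights.update(checkpoint)
        return dict(self.weights)

    def get_signature_runner(self, name):
        return self.signatures[name]


model = ToyModel()
train = model.get_signature_runner("train")
infer = model.get_signature_runner("infer")
for _ in range(20):
    train(x=1.0, y=3.0)
print(infer(x=1.0)["output"])  # close to 3.0 after training
```

Because every runner closes over the same weights dictionary, calling train changes what infer returns. The real signatures behave the same way: they share the model's variables.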
The train function in the code above uses the GradientTape class to record operations for automatic differentiation. For more information on how to use this class, see the Introduction to gradients and automatic differentiation.
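This model uses categorical cross-entropy on logits (from_logits=True). For that loss, the gradient that automatic differentiation computes with respect to the logits has a simple closed form: softmax(logits) minus the one-hot label. The short pure-Python sketch below is independent of TensorFlow and uses illustrative values. It checks the formula against a central finite-difference estimate.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_from_logits(one_hot, logits):
    probs = softmax(logits)
    # Only the true class contributes when the label is one-hot.
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs) if y > 0)


def analytic_grad(one_hot, logits):
    # dL/dz_j = softmax(z)_j - y_j
    return [p - y for p, y in zip(softmax(logits), one_hot)]


def numeric_grad(one_hot, logits, eps=1e-6):
    # Central finite differences, perturbing one logit at a time.
    grads = []
    for j in range(len(logits)):
        plus = list(logits)
        minus = list(logits)
        plus[j] += eps
        minus[j] -= eps
        diff = (cross_entropy_from_logits(one_hot, plus)
                - cross_entropy_from_logits(one_hot, minus))
        grads.append(diff / (2 * eps))
    return grads


logits = [2.0, -1.0, 0.5]   # illustrative logits for 3 classes
label = [0.0, 0.0, 1.0]     # one-hot label for class 2
print(analytic_grad(label, logits))
print(numeric_grad(label, logits))
```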
You could use the Model.train_step method of the Keras model here instead of a from-scratch implementation. Note, however, that the loss (and metrics) returned by Model.train_step are running averages and should be reset regularly, typically at the start of each epoch. See Customize Model.fit for details.
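A minimal mean tracker shows why a running-average loss needs resetting. The RunningMean class below is a hypothetical stand-in similar in spirit to tf.keras.metrics.Mean, and the per-batch losses are illustrative. Without a reset, the value reported after the second epoch still includes the larger losses from the first epoch.

```python
class RunningMean:
    """Hypothetical running-average tracker (not the Keras implementation)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1
        return self.total / self.count


epoch_losses = [[1.0, 0.8, 0.6], [0.4, 0.3, 0.2]]  # illustrative per-batch losses

never_reset = RunningMean()
for batch_losses in epoch_losses:
    for loss in batch_losses:
        stale = never_reset.update(loss)

per_epoch = RunningMean()
for batch_losses in epoch_losses:
    per_epoch.reset()  # reset at the start of each epoch
    for loss in batch_losses:
        fresh = per_epoch.update(loss)

print(f"without reset: {stale:.3f}, with reset: {fresh:.3f}")
# prints: without reset: 0.550, with reset: 0.300
```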
Prepare the data
Get the Fashion MNIST dataset for training your model.
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
Preprocess the dataset
Pixel values in this dataset are between 0 and 255, and must be normalized to a value between 0 and 1 for processing by the model. Divide the values by 255 to make this adjustment.
train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)
Convert the data labels to categorical values by performing one-hot encoding.
train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)
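Both preprocessing steps are element-wise operations. The sketch below expresses them with plain Python lists and hypothetical normalize and one_hot helpers, rather than NumPy or tf.keras.utils.to_categorical. One difference to keep in mind: to_categorical infers the number of classes as the largest label plus one unless you pass num_classes, whereas this sketch fixes it at 10 for Fashion MNIST.

```python
NUM_CLASSES = 10  # Fashion MNIST has 10 categories


def normalize(pixels):
    # Map 8-bit pixel values in [0, 255] to floats in [0.0, 1.0].
    return [p / 255.0 for p in pixels]


def one_hot(label, num_classes=NUM_CLASSES):
    # Label k becomes a vector with 1.0 at index k and 0.0 elsewhere.
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


print(normalize([0, 51, 255]))  # [0.0, 0.2, 1.0]
print(one_hot(3))  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```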
Train the model
Before converting and setting up your LiteRT model, complete the initial training of your model using the preprocessed dataset and the train signature method. The following code runs model training for 100 epochs, processing batches of 100 images at a time. After every 10 epochs, it displays the loss value of the last batch in that epoch. Because this training run processes quite a bit of data, it may take a few minutes to finish.
NUM_EPOCHS = 100
BATCH_SIZE = 100
epochs = np.arange(1, NUM_EPOCHS + 1, 1)
losses = np.zeros([NUM_EPOCHS])
m = Model()

train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_ds = train_ds.batch(BATCH_SIZE)

for i in range(NUM_EPOCHS):
    for x, y in train_ds:
        result = m.train(x, y)

    losses[i] = result['loss']
    if (i + 1) % 10 == 0:
        print(f"Finished {i+1} epochs")
        print(f"  loss: {losses[i]:.3f}")

# Save the trained weights to a checkpoint.
m.save('/tmp/model.ckpt')
Finished 10 epochs
  loss: 0.428
Finished 20 epochs
  loss: 0.378
Finished 30 epochs
  loss: 0.344
Finished 40 epochs
  loss: 0.317
Finished 50 epochs
  loss: 0.299
Finished 60 epochs
  loss: 0.283
Finished 70 epochs
  loss: 0.266
Finished 80 epochs
  loss: 0.252
Finished 90 epochs
  loss: 0.240
Finished 100 epochs
  loss: 0.230
{'checkpoint_path': <tf.Tensor: shape=(), dtype=string, numpy=b'/tmp/model.ckpt'>}
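Here is how the numbers in this run work out. Splitting 60,000 training images into batches of 100 gives 600 calls to train per epoch, or 60,000 calls over 100 epochs. Also, losses[i] is assigned after the inner loop finishes, so each recorded value is the loss of the last batch in that epoch rather than an epoch average. The stdlib sketch below illustrates both points, using a hypothetical batch_losses list in place of real training output.

```python
import math

NUM_IMAGES = 60_000
NUM_EPOCHS = 100
BATCH_SIZE = 100

# Dataset.batch() keeps a final partial batch by default, hence ceil.
steps_per_epoch = math.ceil(NUM_IMAGES / BATCH_SIZE)
total_steps = steps_per_epoch * NUM_EPOCHS
print(steps_per_epoch, total_steps)  # 600 60000

# Hypothetical per-batch losses for one epoch.
batch_losses = [0.50, 0.40, 0.45, 0.30]
last_batch_loss = batch_losses[-1]  # what the training loop records
epoch_mean_loss = sum(batch_losses) / len(batch_losses)  # an alternative summary
print(last_batch_loss, epoch_mean_loss)
```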
plt.plot(epochs, losses, label='Pre-training')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch')
plt.ylabel('Loss [Cross Entropy]')
plt.legend();
Convert model to LiteRT format
After you have extended your TensorFlow model to enable additional functions for on-device training and completed initial training of the model, you can convert it to LiteRT format. The following code first saves your model as a SavedModel that includes the set of signatures you use with the LiteRT model on a device (train, infer, save, and restore). It then converts that SavedModel to LiteRT format and keeps the result in tflite_model.
SAVED_MODEL_DIR = "saved_model"
tf.saved_model.save(
m,
SAVED_MODEL_DIR,
signatures={
'train':
m.train.get_concrete_function(),
'infer':
m.infer.get_concrete_function(),
'save':
m.save.get_concrete_function(),
'restore':
m.restore.get_concrete_function(),
})
# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
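converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable LiteRT ops.
    # Enable TensorFlow ops that have no LiteRT builtin, such as the Save and
    # Restore ops used by the save and restore signatures.
    tf.lite.OpsSet.SELECT_TF_OPS
]
# Keep the model's variables as mutable resource variables so training can update them.
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()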
Set up the LiteRT signatures
The LiteRT model you converted in the previous step contains several function signatures. You can access them through the tf.lite.Interpreter class and invoke each restore, train, save, and infer signature separately.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
infer = interpreter.get_signature_runner("infer")
Compare the output of the original model and the converted LiteRT model:
logits_original = m.infer(x=train_images[:1])['logits'][0]
logits_lite = infer(x=train_images[:1])['logits'][0]
def compare_logits(logits):
    width = 0.35
    offset = width / 2
    assert len(logits) == 2

    keys = list(logits.keys())
    plt.bar(x=np.arange(len(logits[keys[0]])) - offset,
            height=logits[keys[0]], width=width, label=keys[0])
    plt.bar(x=np.arange(len(logits[keys[1]])) + offset,
            height=logits[keys[1]], width=width, label=keys[1])
    plt.legend()
    plt.grid(True)
    plt.ylabel('Logit')
    plt.xlabel('ClassID')

    delta = np.sum(np.abs(logits[keys[0]] - logits[keys[1]]))
    plt.title(f"Total difference: {delta:.3g}")

compare_logits({'Original': logits_original, 'Lite': logits_lite})
Above, you can see that the behavior of the model is not changed by the conversion to TFLite.
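If you prefer a numeric check to the bar chart, you can compare the two logit vectors directly. The sketch below uses plain Python lists and a tolerance of 1e-4, which is an illustrative assumption rather than a value from this tutorial. In the notebook, you could pass np.asarray(logits_original).tolist() and logits_lite.tolist().

```python
def total_abs_difference(a, b):
    """Sum of |a_i - b_i|, the quantity compare_logits shows as the total difference."""
    if len(a) != len(b):
        raise ValueError("logit vectors must have the same length")
    return sum(abs(x - y) for x, y in zip(a, b))


def logits_match(a, b, tol=1e-4):
    """True when every element-wise gap is within tol (tol is an illustrative choice)."""
    if len(a) != len(b):
        raise ValueError("logit vectors must have the same length")
    return all(abs(x - y) <= tol for x, y in zip(a, b))


# Illustrative values standing in for the original and converted logits.
original = [1.5, -0.25, 3.0]
lite = [1.5, -0.25, 3.00001]
print(total_abs_difference(original, lite))
print(logits_match(original, lite))  # True
print(logits_match(original, [1.5, -0.25, 3.1]))  # False
```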
Retrain the model on a device
After converting your model to LiteRT and deploying it with your app, you can retrain the model on a device using new data and the train signature method of your model. Each training run generates a new set of weights that you can save for re-use and further improvement of the model, as shown in the next section.
On Android, you can perform on-device training with LiteRT using either the Java or C++ APIs. In Java, use the Interpreter class to load a model and drive model training tasks. The following example shows how to run the training procedure using the runSignature method:
// Requires java.nio.ByteBuffer, java.nio.ByteOrder, java.nio.FloatBuffer,
// java.util.ArrayList, java.util.HashMap, java.util.List, java.util.Map,
// and org.tensorflow.lite.Interpreter. modelBuffer holds the loaded .tflite model.
try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    int NUM_EPOCHS = 100;
    int BATCH_SIZE = 100;
    int IMG_HEIGHT = 28;
    int IMG_WIDTH = 28;
    int NUM_TRAININGS = 60000;
    int NUM_BATCHES = NUM_TRAININGS / BATCH_SIZE;

    List<FloatBuffer> trainImageBatches = new ArrayList<>(NUM_BATCHES);
    List<FloatBuffer> trainLabelBatches = new ArrayList<>(NUM_BATCHES);

    // Prepare training batches. FloatBuffer has no allocateDirect, so allocate
    // a direct ByteBuffer (4 bytes per float) and view it as a FloatBuffer.
    for (int i = 0; i < NUM_BATCHES; ++i) {
        FloatBuffer trainImages = ByteBuffer.allocateDirect(BATCH_SIZE * IMG_HEIGHT * IMG_WIDTH * 4)
            .order(ByteOrder.nativeOrder()).asFloatBuffer();
        FloatBuffer trainLabels = ByteBuffer.allocateDirect(BATCH_SIZE * 10 * 4)
            .order(ByteOrder.nativeOrder()).asFloatBuffer();
        // Fill the data values...
        trainImages.rewind();
        trainLabels.rewind();
        trainImageBatches.add(trainImages);
        trainLabelBatches.add(trainLabels);
    }

    // Run training for a few steps.
    float[] losses = new float[NUM_EPOCHS];
    for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
        for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
            Map<String, Object> inputs = new HashMap<>();
            inputs.put("x", trainImageBatches.get(batchIdx));
            inputs.put("y", trainLabelBatches.get(batchIdx));

            Map<String, Object> outputs = new HashMap<>();
            FloatBuffer loss = FloatBuffer.allocate(1);
            outputs.put("loss", loss);

            interpreter.runSignature(inputs, outputs, "train");

            // Record the last loss.
            if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
        }

        // Print the loss output for every 10 epochs.
        if ((epoch + 1) % 10 == 0) {
            System.out.println(
                "Finished " + (epoch + 1) + " epochs, current loss: " + losses[epoch]);
        }
    }

    // ...
}
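The input buffers for the train signature hold float32 values, which take 4 bytes each. The quick size check below uses the same constants as the Java snippet. Holding all 600 prepared batches in memory at once takes roughly 190 MB, which is worth keeping in mind on memory-constrained devices.

```python
FLOAT32_BYTES = 4
BATCH_SIZE = 100
IMG_HEIGHT = IMG_WIDTH = 28
NUM_CLASSES = 10
NUM_TRAININGS = 60_000

image_floats = BATCH_SIZE * IMG_HEIGHT * IMG_WIDTH  # floats per image batch
label_floats = BATCH_SIZE * NUM_CLASSES  # floats per one-hot label batch
num_batches = NUM_TRAININGS // BATCH_SIZE

image_bytes = image_floats * FLOAT32_BYTES
label_bytes = label_floats * FLOAT32_BYTES
all_batches_bytes = num_batches * (image_bytes + label_bytes)

print(image_floats, image_bytes)  # 78400 313600
print(label_floats, label_bytes)  # 1000 4000
print(all_batches_bytes)  # 190560000
```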
You can see a complete code example of model retraining inside an Android app in the model personalization demo app.
Run training for a few epochs to improve or personalize the model. In practice, you would run this additional training using data collected on the device. For simplicity, this example uses the same training data as the previous training step.
train = interpreter.get_signature_runner("train")

NUM_EPOCHS = 50
BATCH_SIZE = 100
more_epochs = np.arange(epochs[-1] + 1, epochs[-1] + NUM_EPOCHS + 1, 1)
more_losses = np.zeros([NUM_EPOCHS])

for i in range(NUM_EPOCHS):
    for x, y in train_ds:
        result = train(x=x, y=y)
    more_losses[i] = result['loss']
    if (i + 1) % 10 == 0:
        print(f"Finished {i+1} epochs")
        print(f"  loss: {more_losses[i]:.3f}")
Finished 10 epochs
  loss: 0.223
Finished 20 epochs
  loss: 0.216
Finished 30 epochs
  loss: 0.210
Finished 40 epochs
  loss: 0.204
Finished 50 epochs
  loss: 0.198
plt.plot(epochs, losses, label='Pre-training')
plt.plot(more_epochs, more_losses, label='On device')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch')
plt.ylabel('Loss [Cross Entropy]')
plt.legend();
Above you can see that the on-device training picks up exactly where the pretraining stopped.
Save the trained weights
When you complete a training run on a device, the model updates the set of weights it is using in memory. Using the save signature method you created in your LiteRT model, you can save these weights to a checkpoint file for later reuse and further improvement of your model.
save = interpreter.get_signature_runner("save")
save(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.bytes_))
{'checkpoint_path': array(b'/tmp/model.ckpt', dtype=object)}
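The save signature stores each weight under its variable name (the tensor_names passed to tf.raw_ops.Save). The restore signature then looks each weight up again by that name (tensor_name in tf.raw_ops.Restore). The stdlib sketch below is only a rough analogy: it writes a name-to-values mapping to a JSON file and reads it back. The JSON format, the file name, and the weight names and values are assumptions for illustration, and real TensorFlow checkpoints use a different binary format.

```python
import json
import os
import tempfile


def save_checkpoint(path, weights):
    # Store every weight under its name, like tensor_names in the save signature.
    with open(path, "w") as f:
        json.dump(weights, f)
    return {"checkpoint_path": path}


def restore_checkpoint(path, weights):
    # Look each weight up by name and overwrite the in-memory value.
    with open(path) as f:
        stored = json.load(f)
    for name in weights:
        weights[name] = stored[name]
    return weights


trained = {"dense_1/bias:0": [0.1, -0.2], "dense_2/bias:0": [0.3]}  # illustrative
path = os.path.join(tempfile.mkdtemp(), "toy_checkpoint.json")
save_checkpoint(path, trained)

fresh = {"dense_1/bias:0": [0.0, 0.0], "dense_2/bias:0": [0.0]}
restore_checkpoint(path, fresh)
print(fresh == trained)  # True
```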
In your Android application, you can store the generated weights as a checkpoint file in the internal storage space allocated for your app.
try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    // Conduct the training jobs.

    // Export the trained weights as a checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    interpreter.runSignature(inputs, outputs, "save");
}
Restore the trained weights
Any time you create an interpreter from a TFLite model, the interpreter will initially load the original model weights.
So after you've done some training and saved a checkpoint file, you'll need to run the restore signature method to load the checkpoint.
A good rule is "Anytime you create an Interpreter for a model, if the checkpoint exists, load it". If you need to reset the model to the baseline behavior, just delete the checkpoint and create a fresh interpreter.
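The rule above can be written as a small helper. In this sketch, create_interpreter and restore are hypothetical callables that stand in for creating a tf.lite.Interpreter and calling its restore signature runner. A plain dict plays the role of the interpreter, so only the control flow is meaningful.

```python
import os
import tempfile


def load_model(checkpoint_path, create_interpreter, restore):
    """Create a fresh interpreter and apply the checkpoint only if it exists."""
    interpreter = create_interpreter()
    if os.path.exists(checkpoint_path):
        restore(interpreter, checkpoint_path)
    return interpreter


def reset_model(checkpoint_path, create_interpreter, restore):
    """Return to baseline behavior: delete the checkpoint, then start fresh."""
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
    return load_model(checkpoint_path, create_interpreter, restore)


restored_from = []


def create_stub():
    return {"weights": "baseline"}


def restore_stub(interpreter, path):
    interpreter["weights"] = "trained"
    restored_from.append(path)


ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")
before = load_model(ckpt, create_stub, restore_stub)  # no checkpoint yet
open(ckpt, "w").close()  # pretend training saved a checkpoint
after = load_model(ckpt, create_stub, restore_stub)
after_reset = reset_model(ckpt, create_stub, restore_stub)
print(before, after, after_reset)
```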
another_interpreter = tf.lite.Interpreter(model_content=tflite_model)
another_interpreter.allocate_tensors()
infer = another_interpreter.get_signature_runner("infer")
restore = another_interpreter.get_signature_runner("restore")
logits_before = infer(x=train_images[:1])['logits'][0]
# Restore the trained weights from /tmp/model.ckpt
restore(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.bytes_))
logits_after = infer(x=train_images[:1])['logits'][0]
compare_logits({'Before': logits_before, 'After': logits_after})
The checkpoint was generated by training and saving with TFLite. Above you can see that applying the checkpoint updates the behavior of the model.
In your Android app, you can restore the serialized, trained weights from the checkpoint file you stored earlier.
try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Load the trained weights from the checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    anotherInterpreter.runSignature(inputs, outputs, "restore");
}
Run inference using the trained weights
Once you have loaded previously saved weights from a checkpoint file, the infer method runs your original model with those weights, so its predictions reflect the additional training. After loading the saved weights, you can use the infer signature method as shown below.
infer = another_interpreter.get_signature_runner("infer")
result = infer(x=test_images)
predictions = np.argmax(result["output"], axis=1)
true_labels = np.argmax(test_labels, axis=1)
result['output'].shape
(10000, 10)
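The predictions and true_labels arrays computed above can be summarized as a single accuracy number; with NumPy, that is np.mean(predictions == true_labels). The plain-Python sketch below spells out the same argmax-and-compare steps on a few illustrative probability rows.

```python
def argmax(row):
    # Index of the largest value; ties go to the first occurrence, like np.argmax.
    return max(range(len(row)), key=lambda j: row[j])


def accuracy(probabilities, one_hot_labels):
    predicted = [argmax(row) for row in probabilities]
    expected = [argmax(row) for row in one_hot_labels]
    correct = sum(p == t for p, t in zip(predicted, expected))
    return correct / len(expected)


probs = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1], [0.2, 0.2, 0.6], [0.5, 0.4, 0.1]]
labels = [[0, 1, 0], [1, 0, 0], [0, 1, 0], [1, 0, 0]]
print(accuracy(probs, labels))  # 0.75
```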
Plot the predicted labels.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

def plot(images, predictions, true_labels):
    plt.figure(figsize=(10, 10))
    for i in range(25):
        plt.subplot(5, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[i], cmap=plt.cm.binary)
        color = 'b' if predictions[i] == true_labels[i] else 'r'
        plt.xlabel(class_names[predictions[i]], color=color)
    plt.show()

plot(test_images, predictions, true_labels)
predictions.shape
(10000,)
In your Android application, after restoring the trained weights, run inference on your input data.
// Requires java.nio.ByteBuffer, java.nio.ByteOrder, java.nio.FloatBuffer,
// java.util.HashMap, and java.util.Map.
try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Restore the weights from the checkpoint file.

    int NUM_TESTS = 10;
    // Allocate direct ByteBuffers (4 bytes per float) and view them as FloatBuffers.
    FloatBuffer testImages = ByteBuffer.allocateDirect(NUM_TESTS * 28 * 28 * 4)
        .order(ByteOrder.nativeOrder()).asFloatBuffer();
    FloatBuffer output = ByteBuffer.allocateDirect(NUM_TESTS * 10 * 4)
        .order(ByteOrder.nativeOrder()).asFloatBuffer();

    // Fill the test data.

    // Run the inference.
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("x", testImages.rewind());
    Map<String, Object> outputs = new HashMap<>();
    outputs.put("output", output);
    anotherInterpreter.runSignature(inputs, outputs, "infer");
    output.rewind();

    // Process the result to get the final category values (argmax over 10 class scores).
    int[] testLabels = new int[NUM_TESTS];
    for (int i = 0; i < NUM_TESTS; ++i) {
        int index = 0;
        for (int j = 1; j < 10; ++j) {
            if (output.get(i * 10 + index) < output.get(i * 10 + j)) index = j;
        }
        testLabels[i] = index;
    }
}
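The Java loop reads a flat output buffer in row-major order, so the score for test image i and class j sits at index i * 10 + j. The following Python sketch mirrors that indexing on a plain list, using illustrative buffer contents, and can be useful for checking the argmax logic off-device.

```python
NUM_CLASSES = 10


def argmax_rows(flat_scores, num_rows, num_classes=NUM_CLASSES):
    labels = []
    for i in range(num_rows):
        index = 0
        for j in range(1, num_classes):
            # Same comparison as the Java loop: keep the index of the larger score.
            if flat_scores[i * num_classes + index] < flat_scores[i * num_classes + j]:
                index = j
        labels.append(index)
    return labels


# Two illustrative rows: the first peaks at class 7, the second at class 0.
row0 = [0.0] * NUM_CLASSES
row0[7] = 0.9
row1 = [0.05] * NUM_CLASSES
row1[0] = 0.55
print(argmax_rows(row0 + row1, num_rows=2))  # [7, 0]
```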
Congratulations! You now have built a LiteRT model that supports on-device training. For more coding details, check out the example implementation in the model personalization demo app.
If you are interested in learning more about image classification, check out the Keras classification tutorial in the TensorFlow official guide. This tutorial is based on that exercise, and the Keras tutorial covers classification in more depth.