# Copyright 2024 The AI Edge Torch Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This Colab demonstrates how to convert a PyTorch model to the LiteRT format using the AI Edge Torch package. In this example, we convert ResNet18, a popular image recognition model, into a LiteRT model that you can later use in a LiteRT or MediaPipe app.
pip install ai-edge-torch-nightly torchvision
Import packages
The PyTorch converter is available in the AI Edge Torch GitHub repository. We also require the PyTorch library, as well as numpy and torchvision, which includes the ResNet18 model.
import ai_edge_torch
import numpy
import torch
import torchvision
Instantiate the PyTorch model
Let's instantiate resnet18 as a sample model from PyTorch's torchvision package. We'll also provide it with a sample input and execute the model through PyTorch.
resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)
torch_output = resnet18(*sample_inputs)
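Random tensors are all the converter needs for tracing, but if you want to sanity-check the model on a real image first, the following sketch (not part of the original notebook) uses the preprocessing transforms bundled with the pretrained weights. The file name cat.jpg is a placeholder you would replace with your own image.
from PIL import Image

weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
preprocess = weights.transforms()  # resize, center-crop, and ImageNet normalization

image = Image.open('cat.jpg')  # placeholder path for a local image
image_tensor = preprocess(image).unsqueeze(0)  # shape (1, 3, 224, 224)

with torch.no_grad():
  logits = resnet18(image_tensor)
# Look up the human-readable label of the top-1 prediction.
print(weights.meta['categories'][int(logits.argmax())])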
Convert the model to LiteRT
Use the convert function from the ai_edge_torch package, which converts PyTorch models to the LiteRT format. This will turn the PyTorch model into an on-device model, ready to use with LiteRT and MediaPipe. The conversion process requires a model's sample input for tracing and shape inference.
edge_model = ai_edge_torch.convert(resnet18.eval(), sample_inputs)
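The same pattern applies to any torch.nn.Module, not just torchvision models: put the model in eval mode and pass a tuple of sample inputs. As a rough illustration, here is a tiny made-up network (TinyNet is invented for this example and is not part of the original notebook) converted the same way.
class TinyNet(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
    self.fc = torch.nn.Linear(8, 10)

  def forward(self, x):
    x = torch.relu(self.conv(x))
    x = x.mean(dim=(2, 3))  # global average pooling
    return self.fc(x)

tiny_sample_inputs = (torch.randn(1, 3, 32, 32),)
tiny_edge_model = ai_edge_torch.convert(TinyNet().eval(), tiny_sample_inputs)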
Inference
Get outputs from inference with the TFLite runtime by directly calling the edge_model with the inputs. Many of the details of TFLite inference in Python are abstracted away with this API.
edge_output = edge_model(*sample_inputs)
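As a quick sanity check (not in the original notebook), you can also compare the top-1 class index from the two outputs; for a faithful conversion they should agree.
# Compare the argmax of the PyTorch logits and the converted model's logits.
torch_top1 = int(torch_output.detach().numpy().argmax())
edge_top1 = int(numpy.asarray(edge_output).argmax())
print(torch_top1, edge_top1)  # the two indices are expected to match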
Validate the model
Make sure that the output generated by the newly converted model matches the output generated by the original PyTorch model.
if (numpy.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
)):
  print("Inference result with Pytorch and TfLite was within tolerance")
else:
  print("Something wrong with Pytorch --> TfLite")
Serialization
The converted model includes an export method, which you can use to serialize the model. This exports the model as a TFLite FlatBuffer file.
from google.colab import files
edge_model.export('resnet.tflite')
# Download the tflite flatbuffer which can be used with the existing TfLite APIs.
# files.download('resnet.tflite')
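For reference, the exported file also runs with the standard TensorFlow Lite Interpreter API, which is roughly what calling edge_model(*sample_inputs) abstracts away. The sketch below assumes TensorFlow is installed (it is available by default in Colab) and reads tensor indices from the interpreter rather than assuming an ordering.
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='resnet.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed the same sample input used above and read back the result.
interpreter.set_tensor(input_details[0]['index'], sample_inputs[0].numpy())
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])
print(tflite_output.shape)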
Visualization
The export function creates a TFLite file, which you can visualize with the Google AI Edge Model Explorer.
pip install ai-edge-model-explorer
import model_explorer
model_explorer.visualize('resnet.tflite')