PyTorch to TFLite quickstart

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This Colab demonstrates how to convert a PyTorch model to the TFLite format using the AI Edge Torch package. In this example, we convert ResNet18, a popular image recognition model, into a TFLite model that can later be used in a TFLite or MediaPipe app.

pip install -r https://github.com/google-ai-edge/ai-edge-torch/releases/download/v0.1.1/requirements.txt
pip install ai-edge-torch==0.1.1
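The install step pins `ai-edge-torch` to 0.1.1. A mismatched version can change converter behavior, so it can be worth confirming what actually got installed. The following is a minimal sketch that uses only the standard library's `importlib.metadata`. The helper name `check_pinned_versions` is hypothetical and is not part of AI Edge Torch.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as published on PyPI; the pin mirrors the install step above.
PINNED = {"ai-edge-torch": "0.1.1"}


def check_pinned_versions(pinned):
    """Return {package: (installed_version_or_None, matches_pin)} (hypothetical helper)."""
    report = {}
    for package, expected in pinned.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # the package is not installed in this environment
        report[package] = (installed, installed == expected)
    return report


for package, (installed, ok) in check_pinned_versions(PINNED).items():
    print(f"{package}: installed={installed}, matches pin={ok}")
```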

Import packages

The PyTorch converter is provided by the ai_edge_torch package from the AI Edge Torch GitHub repository. We also need PyTorch itself, NumPy, and torchvision, which provides the ResNet18 model.

import ai_edge_torch
import numpy
import torch
import torchvision

Instantiate the PyTorch model

Let's instantiate ResNet18 with pretrained ImageNet weights as a sample model from PyTorch's torchvision package. We'll also create a random sample input and run the model through PyTorch to get a reference output.

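# Load ResNet18 with pretrained ImageNet weights and switch it to inference mode.
resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()
# A random batch of one 3-channel 224x224 image (NCHW layout), used as the sample input.
sample_inputs = (torch.randn(1, 3, 224, 224),)
torch_output = resnet18(*sample_inputs)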
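The random tensor above only fixes the input shape: a batch of 1 image with 3 color channels at 224×224 pixels, in NCHW order. To feed a real photo instead, it would typically be normalized the way the ImageNet weights expect. The following is a minimal NumPy sketch. It assumes the image is already a 224×224 RGB `uint8` array (resizing and cropping are omitted) and uses the standard torchvision ImageNet mean and standard deviation. The helper `to_resnet_input` is hypothetical.

```python
import numpy as np

# Standard ImageNet per-channel statistics used by torchvision's pretrained weights.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_resnet_input(image_hwc_uint8):
    """Convert a 224x224x3 uint8 RGB image into a 1x3x224x224 float32 batch (hypothetical helper)."""
    if image_hwc_uint8.shape != (224, 224, 3):
        raise ValueError(f"expected (224, 224, 3), got {image_hwc_uint8.shape}")
    scaled = image_hwc_uint8.astype(np.float32) / 255.0    # [0, 255] -> [0, 1]
    normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD   # per channel, broadcast over H and W
    chw = normalized.transpose(2, 0, 1)                    # HWC -> CHW
    return np.ascontiguousarray(chw[np.newaxis, ...])      # add batch dimension -> NCHW


# Usage with the notebook's model (`image` is a hypothetical 224x224x3 uint8 array):
# sample_inputs = (torch.from_numpy(to_resnet_input(image)),)
```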

Convert the model to TF Lite

Use the convert function from the ai_edge_torch package to convert the PyTorch model to the TFLite format. This turns the PyTorch model into an on-device model, ready to use with TFLite and MediaPipe. The conversion requires sample inputs for tracing and shape inference.

edge_model = ai_edge_torch.convert(resnet18.eval(), sample_inputs)
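The converter traces the model using `sample_inputs`. It is therefore reasonable to assume that the resulting `edge_model` is specialized to those input shapes and dtypes. Under that assumption, calling it with a different batch size may fail. The following sketch is a small guard that compares new inputs against the traced ones before inference. The helper `check_matches_traced` is hypothetical and is not part of ai_edge_torch.

```python
import numpy as np


def check_matches_traced(traced_inputs, new_inputs):
    """Raise ValueError if new_inputs differ in count, shape, or dtype from traced_inputs.

    Works for NumPy arrays or torch tensors (both expose .shape and .dtype),
    as long as both sides are the same kind of array.
    """
    if len(traced_inputs) != len(new_inputs):
        raise ValueError(f"expected {len(traced_inputs)} inputs, got {len(new_inputs)}")
    for i, (traced, new) in enumerate(zip(traced_inputs, new_inputs)):
        if tuple(traced.shape) != tuple(new.shape):
            raise ValueError(f"input {i}: expected shape {tuple(traced.shape)}, got {tuple(new.shape)}")
        if traced.dtype != new.dtype:
            raise ValueError(f"input {i}: expected dtype {traced.dtype}, got {new.dtype}")


# Example with NumPy stand-ins for the traced and new inputs.
traced = (np.zeros((1, 3, 224, 224), dtype=np.float32),)
check_matches_traced(traced, (np.ones((1, 3, 224, 224), dtype=np.float32),))  # same shape and dtype: no error
```

In the notebook, the call would be `check_matches_traced(sample_inputs, new_inputs)` right before `edge_model(*new_inputs)`.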

Inference

Run inference with the TFLite runtime by calling edge_model directly with the sample inputs. This API hides most of the details of running TFLite inference in Python, such as creating an interpreter and copying tensors in and out.

edge_output = edge_model(*sample_inputs)
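Calling `edge_model` directly hides the usual TFLite interpreter steps. For reference, the following is a minimal sketch of those steps using the standard `tf.lite.Interpreter` methods (`allocate_tensors`, `get_input_details`, `set_tensor`, `invoke`, `get_output_details`, `get_tensor`). It assumes TensorFlow is installed and that a `.tflite` file exists, such as the one exported in the Serialization step. This is not necessarily how ai_edge_torch implements the call internally. The helper `run_tflite` is hypothetical.

```python
import numpy as np


def run_tflite(interpreter, inputs):
    """Run one inference on a tf.lite.Interpreter-like object and return all outputs.

    `inputs` are array-likes given in the same order as interpreter.get_input_details().
    """
    interpreter.allocate_tensors()
    for detail, value in zip(interpreter.get_input_details(), inputs):
        # Cast to the dtype the model expects before copying into the input tensor.
        interpreter.set_tensor(detail["index"], np.asarray(value, dtype=detail["dtype"]))
    interpreter.invoke()
    return [interpreter.get_tensor(d["index"]) for d in interpreter.get_output_details()]


# Usage, assuming TensorFlow is installed and 'resnet.tflite' has been exported:
# import tensorflow as tf
# interpreter = tf.lite.Interpreter(model_path="resnet.tflite")
# (logits,) = run_tflite(interpreter, [sample_inputs[0].numpy()])
```

For models with several inputs, the order of `get_input_details()` is not guaranteed to match the original argument order. ResNet18 has a single input, so this does not matter here.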

Validate the model

Check that the converted model's output matches the original PyTorch model's output within a small tolerance.

# Compare the PyTorch and TFLite outputs elementwise within the given tolerances.
if numpy.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
):
    print("Inference results from PyTorch and TFLite are within tolerance")
else:
    print("Something went wrong with the PyTorch --> TFLite conversion")
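`numpy.allclose(a, b)` passes when `|a - b| <= atol + rtol * |b|` holds for every element, where `b` is the second argument (here `edge_output`). When the check fails, a single boolean says little. The following sketch reports the largest absolute difference and whether both models agree on the top-1 class, which is what ultimately matters for a classifier. The helper `compare_outputs` is hypothetical.

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-5, rtol=1e-5):
    """Summarize how closely candidate logits match reference logits (shape [batch, classes])."""
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    abs_diff = np.abs(reference - candidate)
    return {
        # Same elementwise rule numpy.allclose applies to finite values, with `candidate` as its second argument.
        "allclose": bool(np.all(abs_diff <= atol + rtol * np.abs(candidate))),
        "max_abs_diff": float(abs_diff.max()),
        # Fraction of batch items for which both models pick the same top-1 class.
        "top1_agreement": float(np.mean(reference.argmax(axis=-1) == candidate.argmax(axis=-1))),
    }


# Usage with the notebook's outputs:
# print(compare_outputs(torch_output.detach().numpy(), edge_output))
```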

Serialization

The converted model has an export method, which serializes the model to a TFLite FlatBuffer file.

edge_model.export('resnet.tflite')

# In Colab, download the TFLite flatbuffer, which can be used with the existing TFLite APIs.
# from google.colab import files
# files.download('resnet.tflite')
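Before downloading or visualizing the exported file, you can run a quick sanity check without any ML dependency. TFLite models are FlatBuffers whose schema declares the file identifier `TFL3`, which is stored in bytes 4 to 8 of the file. The following is a minimal sketch of that check. The helper `looks_like_tflite` is hypothetical and only confirms the header, not that the model is valid.

```python
TFLITE_FILE_IDENTIFIER = b"TFL3"  # file_identifier declared by the TFLite FlatBuffers schema


def looks_like_tflite(path):
    """Return True if the file carries the TFLite FlatBuffers identifier at bytes 4..8."""
    with open(path, "rb") as f:
        header = f.read(8)
    return len(header) == 8 and header[4:8] == TFLITE_FILE_IDENTIFIER


# Usage after the export above:
# assert looks_like_tflite("resnet.tflite")
```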

Visualization

The export method creates a TFLite file that you can visualize with Google AI Edge Model Explorer.

pip install ai-edge-model-explorer
import model_explorer
model_explorer.visualize('resnet.tflite')