Convert PyTorch models to TF Lite

AI Edge Torch is a library that lets you convert PyTorch models into the .tflite format, enabling you to run those models with TensorFlow Lite and MediaPipe. This is especially helpful for developers creating mobile apps that run models completely on-device. AI Edge Torch offers broad CPU coverage, with initial GPU and NPU support.

To get started converting PyTorch models to TF Lite, use the PyTorch converter quickstart. For more information, see the AI Edge Torch GitHub repo.

If you are specifically converting Large Language Models (LLMs) or transformer-based models, use the Generative Torch API, which handles transformer-specific conversion details like model authoring and quantization.

Conversion workflow

The following steps demonstrate a simple end-to-end conversion of a PyTorch model to TensorFlow Lite.

Import AI Edge Torch

Start by importing the AI Edge Torch (ai-edge-torch) pip package, along with PyTorch.

import ai_edge_torch
import torch

This example also requires the following packages:

import numpy
import torchvision

Initialize and convert the model

We will convert ResNet18, a popular image recognition model.

resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()

Use the convert method from the AI Edge Torch library to convert the PyTorch model. The sample inputs, provided as a tuple of tensors, define the input signature of the converted model.

sample_inputs = (torch.randn(1, 3, 224, 224),)
edge_model = ai_edge_torch.convert(resnet18.eval(), sample_inputs)
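
The same call works for any torch.nn.Module. The following is a minimal sketch using a hypothetical toy module (AddTen is not part of the library or this example) to show that the sample input tuple is what defines the converted model's input signature.

class AddTen(torch.nn.Module):
    # Toy module used only to illustrate the conversion call.
    def forward(self, x):
        return x + 10.0

# The shape and dtype of the sample inputs become the TF Lite input signature.
toy_sample_inputs = (torch.randn(4),)
toy_edge_model = ai_edge_torch.convert(AddTen().eval(), toy_sample_inputs)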

Use the model

After converting the PyTorch model, you can run inference with the new TF Lite model.

output = edge_model(*sample_inputs)
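
To sanity-check the conversion, you can compare the TF Lite output against the output of the original PyTorch model. This is a minimal sketch; it assumes the converted model returns a NumPy array, and the tolerance value is an arbitrary choice for illustration.

torch_output = resnet18(*sample_inputs)

# A small tolerance absorbs minor numerical differences between the runtimes.
if numpy.allclose(torch_output.detach().numpy(), output, atol=1e-5):
    print('PyTorch and TF Lite outputs match within tolerance')
else:
    print('Outputs diverge beyond the tolerance')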

You can export and save the converted model in the .tflite format for future use.

edge_model.export('resnet.tflite')
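
Later, the exported file can be loaded and run with the standard TF Lite interpreter, independent of AI Edge Torch. This is a minimal sketch that assumes TensorFlow is installed and reuses the resnet.tflite file written above.

import tensorflow as tf

# Load the exported model and allocate its tensors.
interpreter = tf.lite.Interpreter(model_path='resnet.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Build a random input that matches the signature recorded at conversion time.
input_shape = input_details[0]['shape']
dummy_input = numpy.random.rand(*input_shape).astype(numpy.float32)

interpreter.set_tensor(input_details[0]['index'], dummy_input)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])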