将 PyTorch 模型转换为 TF Lite

AI Edge Torch 是一个库,可让您将 PyTorch 模型转换为 .tflite 格式,使您能够使用 TensorFlow Lite 和 MediaPipe 运行这些模型。这对于创建完全在设备端运行模型的移动应用特别有帮助。AI Edge Torch 提供广泛的 CPU 覆盖率,并初步支持 GPU 和 NPU。

如需开始将 PyTorch 模型转换为 TF Lite,请参阅 Pytorch 转换器快速入门。如需了解详情,请参阅 AI Edge Torch GitHub 代码库

如果您要专门转换大型语言模型 (LLM) 或基于转换器的模型,请使用 Generative Torch API,该 API 可以处理特定于转换器的转换细节,如模型编写和量化。

转化工作流程

以下步骤演示了 PyTorch 模型到 TensorFlow Lite 的简单端到端转换。

导入 AI Edge Torch

首先导入 AI Edge Torch (ai-edge-torch) pip 软件包和 PyTorch。

import ai_edge_torch
import torch

在此示例中,我们还需要以下软件包:

import numpy
import torchvision

初始化和转换模型

我们将转换热门图像识别模型 ResNet18

resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()

使用 AI Edge Torch 库中的 convert 方法转换 PyTorch 模型。

sample_input = (torch.randn(1, 3, 224, 224),)
edge_model = ai_edge_torch.convert(resnet18.eval(), sample_input)

使用模型

转换 Pytorch 模型后,您可以使用转换后的新 TF Lite 模型运行推断。

output = edge_model(*sample_inputs)

您能够以 .tflite 格式导出并保存转换后的模型,以备将来使用。

edge_model.export('resnet.tflite')