将 PyTorch 模型转换为 LiteRT

AI Edge Torch 是一个库,可让您将 PyTorch 模型转换为 .tflite 格式,以便您使用 LiteRT 和 MediaPipe 运行这些模型。对于开发完全在设备端运行模型的移动应用的开发者而言,这尤其有用。AI Edge Torch 提供广泛的 CPU 覆盖率,并提供初始 GPU 和 NPU 支持。

如需开始将 PyTorch 模型转换为 LiteRT,请使用 Pytorch 转换器快速入门。如需了解详情,请参阅 AI Edge Torch GitHub 代码库

如果您要专门转换大语言模型 (LLM) 或基于转换器的模型,请使用 Generative Torch API,它可处理转换器专用转换详细信息,例如模型创作和量化。

转化工作流

以下步骤演示了将 PyTorch 模型简单地端到端转换为 LiteRT。

导入 AI Edge Torch

首先,导入 AI Edge Torch (ai-edge-torch) pip 软件包以及 PyTorch。

import ai_edge_torch
import torch

对于本例,我们还需要以下软件包:

import numpy
import torchvision

初始化和转换模型

我们将转换ResNet18,这是一个广受欢迎的图片识别模型。

resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()

使用 AI Edge Torch 库中的 convert 方法转换 PyTorch 模型。

sample_input = (torch.randn(1, 3, 224, 224),)
edge_model = ai_edge_torch.convert(resnet18.eval(), sample_input)

使用模型

转换 Pytorch 模型后,您可以使用新转换的 LiteRT 模型运行推理。

output = edge_model(*sample_inputs)

您可以以 .tflite 格式导出并保存转换后的模型,以供日后使用。

edge_model.export('resnet.tflite')