PyTorch-Modelle in TF Lite konvertieren

AI Edge Torch ist eine Bibliothek, mit der Sie PyTorch-Modelle in ein .tflite-Format konvertieren können, um diese Modelle mit TensorFlow Lite und MediaPipe auszuführen. Dies ist besonders hilfreich für Entwickler, die mobile Apps erstellen, bei denen Modelle vollständig auf dem Gerät ausgeführt werden. AI Edge Torch bietet eine breite CPU-Abdeckung mit anfänglicher Unterstützung für GPU und NPU.

Informationen zum Konvertieren von PyTorch-Modellen in TF Lite finden Sie in der Kurzanleitung zum Pytorch-Konverter. Weitere Informationen finden Sie im GitHub-Repository zu AI Edge Torch.

Wenn Sie speziell Large Language Models (LLMs) oder Transformer-basierte Modelle konvertieren, sollten Sie die Generative Torch API verwenden, die transformatorspezifische Konvertierungsdetails wie Modellerstellung und Quantisierung verarbeitet.

Konvertierungs-Workflow

Mit den folgenden Schritten wird eine einfache End-to-End-Konvertierung eines PyTorch-Modells in TensorFlow Lite veranschaulicht.

AI Edge Torch importieren

Importieren Sie zuerst das pip-Paket von AI Edge Torch (ai-edge-torch) zusammen mit PyTorch.

import ai_edge_torch
import torch

Für dieses Beispiel sind außerdem die folgenden Pakete erforderlich:

import numpy
import torchvision

Modell initialisieren und konvertieren

Wir konvertieren das beliebte Bilderkennungsmodell ResNet18.

resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()

Verwenden Sie die Methode convert aus der AI Edge Torch-Bibliothek, um das PyTorch-Modell zu konvertieren.

sample_input = (torch.randn(1, 3, 224, 224),)
edge_model = ai_edge_torch.convert(resnet18.eval(), sample_input)

Modell verwenden

Nach der Konvertierung des Pytorch-Modells können Sie Inferenzen mit dem neuen konvertierten TF Lite-Modell ausführen.

output = edge_model(*sample_inputs)

Sie können das konvertierte Modell zur späteren Verwendung im .tflite-Format exportieren und speichern.

edge_model.export('resnet.tflite')