Convertir des modèles PyTorch en modèles TF Lite

AI Edge Torch est une bibliothèque qui vous permet de convertir des modèles PyTorch au format .tflite, ce qui vous permet de les exécuter avec TensorFlow Lite et MediaPipe. Cela est particulièrement utile pour les développeurs qui créent des applications mobiles qui exécutent des modèles entièrement sur l'appareil. AI Edge Torch offre une couverture de processeur étendue, avec une compatibilité initiale avec les GPU et les NPU.

Pour commencer à convertir des modèles PyTorch au format TF Lite, suivez le guide de démarrage rapide du convertisseur PyTorch. Pour en savoir plus, consultez le dépôt GitHub d'AI Edge Torch.

Si vous convertissez spécifiquement des grands modèles de langage (LLM) ou des modèles basés sur des transformateurs, utilisez l'API Generative Torch, qui gère les détails de conversion spécifiques aux transformateurs, tels que la création et la quantification du modèle.

Workflow de conversion

Les étapes suivantes illustrent une conversion simple de bout en bout d'un modèle PyTorch en TensorFlow Lite.

Importer AI Edge Torch

Commencez par importer le package pip AI Edge Torch (ai-edge-torch) avec PyTorch.

import ai_edge_torch
import torch

Pour cet exemple, nous avons également besoin des packages suivants:

import numpy
import torchvision

Initialiser et convertir le modèle

Nous allons convertir ResNet18, un modèle de reconnaissance d'image populaire.

resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()

Utilisez la méthode convert de la bibliothèque AI Edge Torch pour convertir le modèle PyTorch.

sample_input = (torch.randn(1, 3, 224, 224),)
edge_model = ai_edge_torch.convert(resnet18.eval(), sample_input)

Utiliser le modèle

Après avoir converti le modèle Pytorch, vous pouvez exécuter des inférences avec le nouveau modèle TF Lite converti.

output = edge_model(*sample_inputs)

Vous pouvez exporter et enregistrer le modèle converti au format .tflite pour une utilisation ultérieure.

edge_model.export('resnet.tflite')