Guide d'inférence LLM pour Android

L'API LLM Inference vous permet d'exécuter de grands modèles de langage (LLM) entièrement sur l'appareil pour les applications Android, que vous pouvez utiliser pour effectuer un large éventail de tâches, telles que la génération de texte, la récupération d'informations sous forme de langage naturel et la synthèse de documents. Cette tâche est compatible avec plusieurs grands modèles de langage texte-vers-texte, ce qui vous permet d'appliquer les derniers modèles d'IA générative sur l'appareil à vos applications Android.

Cette tâche est compatible avec Gemma 2B, qui fait partie d'une famille de modèles ouverts légers et de pointe, construits à partir des mêmes recherches et technologies que celles utilisées pour créer les modèles Gemini. Il est également compatible avec les modèles externes suivants : Phi-2, Falcon-RW-1B et StableLM-3B, ainsi que tous les modèles exportés via AI Edge.

Pour en savoir plus sur les fonctionnalités, les modèles et les options de configuration de cette tâche, consultez la présentation.

Exemple de code

Ce guide fait référence à un exemple d'application de génération de texte de base pour Android. Vous pouvez utiliser l'application comme point de départ pour votre propre application Android ou vous y référer lorsque vous modifiez une application existante. L'exemple de code est hébergé sur GitHub.

Télécharger le code

Les instructions suivantes vous expliquent comment créer une copie locale de l'exemple de code à l'aide de l'outil de ligne de commande git.

Pour télécharger l'exemple de code, procédez comme suit:

  1. Clonez le dépôt git à l'aide de la commande suivante :
    git clone https://github.com/google-ai-edge/mediapipe-samples
    
  2. Vous pouvez éventuellement configurer votre instance Git pour utiliser le paiement creux afin de n'avoir que les fichiers de l'exemple d'application de l'API LLM Inference :
    cd mediapipe
    git sparse-checkout init --cone
    git sparse-checkout set examples/llm_inference/android
    

Après avoir créé une version locale de l'exemple de code, vous pouvez importer le projet dans Android Studio et exécuter l'application. Pour obtenir des instructions, consultez le guide de configuration pour Android.

Préparation

Cette section décrit les étapes clés de la configuration de votre environnement de développement et de vos projets de code spécifiquement pour l'utilisation de l'API LLM Inference. Pour obtenir des informations générales sur la configuration de votre environnement de développement pour l'utilisation des tâches MediaPipe, y compris sur les exigences de version de la plate-forme, consultez le guide de configuration pour Android.

Dépendances

L'API LLM Inference utilise la bibliothèque com.google.mediapipe:tasks-genai. Ajoutez la dépendance suivante au fichier build.gradle de votre application Android:

dependencies {
    implementation 'com.google.mediapipe:tasks-genai:0.10.14'
}

Modèle

L'API LLM Inference de MediaPipe nécessite un modèle de langage texte-vers-texte entraîné compatible avec cette tâche. Après avoir téléchargé un modèle, installez les dépendances requises, puis transférez-le sur l'appareil Android. Si vous utilisez un autre modèle que Gemma, vous devez le convertir dans un format compatible avec MediaPipe.

Pour en savoir plus sur les modèles entraînés disponibles pour l'API LLM Inference, consultez la section Modèles de la présentation des tâches.

Télécharger un modèle

Avant d'initialiser l'API LLM Inference, téléchargez l'un des modèles compatibles et stockez le fichier dans le répertoire de votre projet:

  • Gemma 2B : faisant partie d'une famille de modèles ouverts légers et de pointe, construits à partir des mêmes recherches et technologies que celles utilisées pour créer les modèles Gemini. Convient à diverses tâches de génération de texte, y compris la réponse à des questions, la synthèse et le raisonnement.
  • Phi-2: modèle Transformer de 2, 7 milliards de paramètres, adapté au format Question-Réponse, Chat et Code.
  • Falcon-RW-1B: modèle causal de 1 milliard de paramètres, basé uniquement sur un décodeur, entraîné sur 350 milliards de jetons RefinedWeb.
  • StableLM-3B: modèle de langage basé sur un décodeur comprenant trois milliards de paramètres, pré-entraîné sur 1 000 milliards de jetons issus de divers ensembles de données de code et en anglais.

Vous pouvez également utiliser des modèles mappés et exportés via AI Edge Troch.

Nous vous recommandons d'utiliser Gemma 2B, disponible sur Kaggle Models et disponible dans un format déjà compatible avec l'API LLM Inference. Si vous utilisez un autre LLM, vous devez convertir le modèle dans un format compatible avec MediaPipe. Pour en savoir plus sur Gemma 2B, consultez le site de Gemma. Pour en savoir plus sur les autres modèles disponibles, consultez la section Modèles de la présentation des tâches.

Convertir le modèle au format MediaPipe

Conversion par modèle natif

Si vous utilisez un LLM externe (Phi-2, Falcon ou StableLM) ou une version de Gemma autre que Kaggle, servez-vous de nos scripts de conversion pour mettre en forme le modèle afin qu'il soit compatible avec MediaPipe.

Le processus de conversion du modèle nécessite le package PyPI MediaPipe. Le script de conversion est disponible dans tous les packages MediaPipe après 0.10.11.

Installez et importez les dépendances avec la commande suivante:

$ python3 -m pip install mediapipe

Utilisez la bibliothèque genai.converter pour convertir le modèle:

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  input_ckpt=INPUT_CKPT,
  ckpt_format=CKPT_FORMAT,
  model_type=MODEL_TYPE,
  backend=BACKEND,
  output_dir=OUTPUT_DIR,
  combine_file_only=False,
  vocab_model_file=VOCAB_MODEL_FILE,
  output_tflite_file=OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

Pour convertir le modèle de LoRA, ConversionConfig doit spécifier les options du modèle de base ainsi que des options de LoRA supplémentaires. Notez que, comme l'API n'est compatible qu'avec l'inférence LoRA avec GPU, le backend doit être défini sur 'gpu'.

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

Le convertisseur affiche deux fichiers de tampon plat TFLite, l'un pour le modèle de base et l'autre pour le modèle LoRA.

Paramètres Description Valeurs acceptées
input_ckpt Chemin d'accès au fichier model.safetensors ou pytorch.bin. Notez que parfois, le format safetensors du modèle est segmenté en plusieurs fichiers, par exemple model-00001-of-00003.safetensors ou model-00001-of-00003.safetensors. Vous pouvez spécifier un format de fichier, par exemple model*.safetensors. PATH
ckpt_format Format de fichier du modèle. {"safetensors", "pytorch"}
model_type Le LLM en cours de conversion. {"PHI_2", "FALCON_RW_1B", "STABLELM_4E1T_3B", "GEMMA_2B"}
backend Processeur (délégué) utilisé pour exécuter le modèle. {"cpu", "gpu"}
output_dir Chemin d'accès au répertoire de sortie qui héberge les fichiers de pondération par couche. PATH
output_tflite_file Chemin d'accès au fichier de sortie. Par exemple, "model_cpu.bin" ou "model_gpu.bin". Ce fichier n'est compatible qu'avec l'API LLM Inference et ne peut pas être utilisé comme fichier "tflite" général. PATH
vocab_model_file Chemin d'accès au répertoire qui stocke les fichiers tokenizer.json et tokenizer_config.json. Pour Gemma, pointez sur le seul fichier tokenizer.model. PATH
lora_ckpt Chemin d'accès au ckpt LoRA des safetensors qui stocke le poids de l'adaptateur LoRA. PATH
lora_rank Entier représentant le rang de ckpt LoRA. Obligatoire pour convertir les pondérations Lora. S'il n'est pas fourni, le convertisseur suppose qu'il n'y a pas de pondérations LoRA. Remarque: Seul le backend de GPU est compatible avec la fonctionnalité LoRA. Entier
lora_output_tflite_file Nom de fichier tflite de sortie pour les pondérations LoRA. PATH

Conversion de modèles AI Edge

Si vous utilisez un LLM mappé sur un modèle TFLite via AI Edge, créez un groupe de tâches à l'aide de notre script de regroupement. Le processus de regroupement empaquette le modèle mappé avec des métadonnées supplémentaires Tokenizer) nécessaires à l'exécution de l'inférence de bout en bout.

Le processus de regroupement des modèles nécessite le package PyPI MediaPipe. Le script de conversion est disponible dans tous les packages MediaPipe après 0.10.14.

Installez et importez les dépendances avec la commande suivante:

$ python3 -m pip install mediapipe

Utilisez la bibliothèque genai.bundler pour regrouper le modèle:

import mediapipe as mp
from mediapipe.tasks.python.genai import bundler

config = bundler.BundleConfig(
    tflite_model=TFLITE_MODEL,
    tokenizer_model=TOKENIZER_MODEL,
    start_token=START_TOKEN,
    stop_tokens=STOP_TOKENS,
    output_filename=OUTPUT_FILENAME,
    enable_bytes_to_unicode_mapping=ENABLE_BYTES_TO_UNICODE_MAPPING,
)
bundler.create_bundle(config)
Paramètres Description Valeurs acceptées
tflite_model Chemin d'accès au modèle TFLite exporté par AI Edge. PATH
tokenizer_model Chemin d'accès au modèle de tokenisation SentencePiece. PATH
start_token Jeton de démarrage spécifique au modèle. Le jeton de début doit être présent dans le modèle de tokenisation fourni. STRING
stop_tokens Modélisez des jetons d'arrêt spécifiques. Les jetons d'arrêt doivent être présents dans le modèle de tokenisation fourni. LISTE[CHAÎNE]
output_filename Nom du fichier du groupe de tâches en sortie. PATH

Transférer le modèle sur l'appareil

Transférez le contenu du dossier output_path sur l'appareil Android.

$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.bin

Créer la tâche

L'API LLM Inference de MediaPipe utilise la fonction createFromOptions() pour configurer la tâche. La fonction createFromOptions() accepte des valeurs pour les options de configuration. Pour en savoir plus sur les options de configuration, consultez la section Options de configuration.

Le code suivant initialise la tâche à l'aide des options de configuration de base:

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPATH('/data/local/.../')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

Options de configuration

Utilisez les options de configuration suivantes pour configurer une application Android:

Nom de l'option Description Plage de valeurs Valeur par défaut
modelPath Chemin d'accès au stockage du modèle dans le répertoire du projet. PATH N/A
maxTokens Nombre maximal de jetons (jetons d'entrée + jetons de sortie) traités par le modèle. Entier 512
topK Nombre de jetons que le modèle considère à chaque étape de génération. Limite les prédictions aux k jetons les plus probables. Lorsque vous définissez topK, vous devez également définir une valeur pour randomSeed. Entier 40
temperature Quantité de hasard introduit lors de la génération. Une température plus élevée accroît la créativité dans le texte généré, tandis qu'une température plus basse produit une génération plus prévisible. Lorsque vous définissez temperature, vous devez également définir une valeur pour randomSeed. Nombre à virgule flottante 0,8
randomSeed Valeur initiale aléatoire utilisée lors de la génération de texte. Entier 0
loraPath Chemin absolu vers le modèle LoRA localement sur l'appareil. Remarque: Ceci n'est compatible qu'avec les modèles de GPU. PATH N/A
resultListener Définit l'écouteur de résultats pour recevoir les résultats de manière asynchrone. Ne s'applique que lors de l'utilisation de la méthode de génération asynchrone. N/A N/A
errorListener Définit un écouteur d'erreurs facultatif. N/A N/A

Préparation des données

L'API LLM Inference accepte les entrées suivantes:

  • prompt (chaîne): question ou invite.
val inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday."

Exécuter la tâche

Utilisez la méthode generateResponse() pour générer une réponse textuelle au texte d'entrée fourni dans la section précédente (inputPrompt). Cela produit une seule réponse générée.

val result = llmInference.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")

Pour diffuser la réponse, utilisez la méthode generateResponseAsync().

val options = LlmInference.LlmInferenceOptions.builder()
  ...
  .setResultListener { partialResult, done ->
    logger.atInfo().log("partial result: $partialResult")
  }
  .build()

llmInference.generateResponseAsync(inputPrompt)

Gérer et afficher les résultats

L'API LLM Inference renvoie un LlmInferenceResult, qui inclut le texte de réponse généré.

Here's a draft you can use:

Subject: Lunch on Saturday Reminder

Hi Brett,

Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.

Looking forward to it!

Best,
[Your Name]

Personnalisation du modèle LoRA

L'API d'inférence LLM de Mediapipe peut être configurée pour prendre en charge l'adaptation de rang faible (LoRA) pour les grands modèles de langage. À l'aide de modèles LoRA affinés, les développeurs peuvent personnaliser le comportement des LLM via un processus d'entraînement économique.

La prise en charge de la bibliothèque LoRA de l'API LLM Inference fonctionne pour les modèles Gemma-2B et Phi-2 pour le backend GPU, avec des pondérations LoRA applicables uniquement aux couches d'attention. Cette implémentation initiale sert d'API expérimentale pour les développements futurs. Nous prévoyons de prendre en charge davantage de modèles et différents types de couches dans les prochaines mises à jour.

Préparer des modèles de LoRA

Suivez les instructions sur HuggingFace pour entraîner un modèle LoRA affiné sur votre propre ensemble de données avec les types de modèles compatibles, Gemma-2B ou Phi-2. Les modèles Gemma-2B et Phi-2 sont tous deux disponibles sur HuggingFace au format Safetensors. Étant donné que l'API LLM Inference n'est compatible qu'avec les couches d'attention (LLM Inference), vous ne devez spécifier que les couches d'attention lors de la création d'une LoraConfig, comme suit:

# For Gemma-2B
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

Pour les tests, il existe des modèles LoRA réglés et affinés accessibles publiquement, qui s'adaptent à l'API LLM Inference disponible sur HuggingFace. Par exemple, monsterapi/gemma-2b-lora-maths-orca-200k pour Gemma-2B et lole25/phi-2-sft-ultrachat-lora pour Phi-2.

Après l'entraînement sur l'ensemble de données préparé et l'enregistrement du modèle, vous obtenez un fichier adapter_model.safetensors contenant les pondérations du modèle LoRA affinées. Le fichier safetensors correspond au point de contrôle LoRA utilisé dans la conversion du modèle.

À l'étape suivante, vous devez convertir les pondérations du modèle en un Flatbuffer TensorFlow Lite à l'aide du package Python MediaPipe. L'élément ConversionConfig doit spécifier les options du modèle de base ainsi que d'autres options de LoRA. Notez que, comme l'API n'accepte que l'inférence LoRA avec GPU, le backend doit être défini sur 'gpu'.

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

Le convertisseur affiche deux fichiers de tampon plat TFLite, l'un pour le modèle de base et l'autre pour le modèle LoRA.

Inférence de modèle LoRA

L'API LLM Inference Web, Android et iOS a été mise à jour pour assurer la compatibilité avec l'inférence de modèle LoRA. Le Web prend en charge des modèles de LoRA dynamiques, qui peuvent changer de modèle de LoRA pendant l'exécution. Android et iOS prennent en charge la fonctionnalité LoRA statique, qui utilise les mêmes pondérations LoRA pendant la durée de vie de la tâche.

Android prend en charge la fonctionnalité LoRA statique pendant l'initialisation. Pour charger un modèle LoRA, les utilisateurs doivent spécifier le chemin d'accès au modèle LoRA ainsi que le LLM de base.

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPath('<path to base model>')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .setLoraPath('<path to LoRA model>')
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

Pour exécuter l'inférence LLM avec LoRA, utilisez les mêmes méthodes generateResponse() ou generateResponseAsync() que le modèle de base.