Guida all'inferenza LLM per Android

L'API LLM Inference ti consente di eseguire modelli linguistici di grandi dimensioni (LLM) completamente on-device per le applicazioni Android, che puoi utilizzare per svolgere un'ampia gamma di attività, come la generazione di testo, il recupero di informazioni in forma di linguaggio naturale e il riepilogo dei documenti. L'attività fornisce il supporto integrato per più modelli linguistici di grandi dimensioni text-to-text, in modo da poter applicare i modelli di IA generativa on-device più recenti alle tue app per Android.

L'attività supporta le seguenti varianti di Gemma: Gemma-3 1B, Gemma-2 2B, Gemma 2B e Gemma 7B. Gemma è una famiglia di modelli aperti leggeri e all'avanguardia creati sulla base della stessa ricerca e tecnologia utilizzata per creare i modelli Gemini. Supporta anche i seguenti modelli esterni: Phi-2, Falcon-RW-1B e StableLM-3B.

Oltre ai modelli supportati, gli utenti possono utilizzare AI Edge Torch di Google per esportare i modelli PyTorch in modelli LiteRT (tflite) con più firme, che sono raggruppati con i parametri del tokenizer per creare Task Bundle compatibili con l'API di inferenza LLM.

Puoi vedere questa operazione in azione con la demo di MediaPipe Studio. Per saperne di più sulle funzionalità, sui modelli e sulle opzioni di configurazione di questa attività, consulta la Panoramica.

Esempio di codice

Questa guida fa riferimento a un esempio di app di generazione di testo di base per Android. Puoi utilizzare l'app come punto di partenza per la tua app per Android o farvi riferimento quando modifichi un'app esistente. Il codice di esempio è ospitato su GitHub.

Scarica il codice

Le istruzioni riportate di seguito mostrano come creare una copia locale del codice di esempio utilizzando lo strumento a riga di comando git.

Per scaricare il codice di esempio:

  1. Clona il repository git utilizzando il seguente comando:
    git clone https://github.com/google-ai-edge/mediapipe-samples
    
  2. Facoltativamente, configura l'istanza Git in modo da utilizzare il controllo sparse, in modo da avere solo i file per l'app di esempio dell'API LLM Inference:
    cd mediapipe-samples
    git sparse-checkout init --cone
    git sparse-checkout set examples/llm_inference/android
    

Dopo aver creato una versione locale del codice di esempio, puoi importare il progetto in Android Studio ed eseguire l'app. Per istruzioni, consulta la Guida alla configurazione per Android.

Configurazione

Questa sezione descrive i passaggi chiave per configurare l'ambiente di sviluppo e i progetti di codice in modo specifico per utilizzare l'API LLM Inference. Per informazioni generali sulla configurazione dell'ambiente di sviluppo per l'utilizzo delle attività MediaPipe, inclusi i requisiti della versione della piattaforma, consulta la guida alla configurazione per Android.

Dipendenze

L'API LLM Inference utilizza la libreria com.google.mediapipe:tasks-genai. Aggiungi questa dipendenza al file build.gradle della tua app per Android:

dependencies {
    implementation 'com.google.mediapipe:tasks-genai:0.10.22'
}

Per i dispositivi con Android 12 (API 31) o versioni successive, aggiungi la dipendenza dalla libreria OpenCL nativa. Per ulteriori informazioni, consulta la documentazione relativa al tag uses-native-library.

Aggiungi i seguenti tag uses-native-library al file AndroidManifest.xml:

<uses-native-library android:name="libOpenCL.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-car.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-pixel.so" android:required="false"/>

Modello

L'API MediaPipe LLM Inference richiede un modello linguistico di conversione da testo a testo addestrato compatibile con questa attività. Dopo aver scaricato un modello, installa le dipendenze richieste e invia il modello al dispositivo Android. Se utilizzi un modello diverso da Gemma, dovrai convertirlo in un formato compatibile con MediaPipe.

Per ulteriori informazioni sui modelli addestrati disponibili per l'API LLM Inference, consulta la sezione Modelli della panoramica dell'attività.

Scaricare un modello

Prima di inizializzare l'API LLM Inference, scarica uno dei modelli supportati e memorizza il file nella directory del progetto. Ti consigliamo di utilizzare Gemma-3 1B in un formato quantizzato a 4 bit, disponibile su Hugging Face.

Puoi anche scaricare uno degli altri modelli supportati:

  • Gemma-2 2B: la seconda generazione di modelli Gemma. Fa parte di una famiglia di modelli aperti leggeri e all'avanguardia creati sulla base della stessa ricerca e tecnologia utilizzata per creare i modelli Gemini.
  • Gemma 2B: appartiene a una famiglia di modelli aperti leggeri e all'avanguardia creati sulla base della stessa ricerca e tecnologia utilizzata per creare i modelli Gemini. Molto adatto per una serie di attività di generazione di testo, tra cui risposta a domande, riassunto e ragionamento.
  • Phi-2: modello Transformer con 2, 7 miliardi di parametri, più adatto per il formato di domande e risposte, chat e codice.
  • Falcon-RW-1B: modello causale solo decoder con 1 miliardo di parametri addestrato su 350 miliardi di token di RefinedWeb.
  • StableLM-3B: modello linguistico solo decodificatore con 3 miliardi di parametri preaddestrato su 1 trilione di token di diversi set di dati in inglese e codice.

Oltre ai modelli supportati, puoi utilizzare AI Edge Torch di Google per esportare i modelli PyTorch in modelli LiteRT (tflite) con più firme. Per ulteriori informazioni, consulta Torch Generatore di modelli per i modelli PyTorch.

Per ulteriori informazioni sui modelli disponibili, consulta la sezione Modelli nella panoramica dell'attività.

Converti il modello in formato MediaPipe

L'API LLM Inference è compatibile con i seguenti tipi di modelli, alcuni dei quali richiedono la conversione del modello. Utilizza la tabella per identificare il metodo di procedura richiesto per il tuo modello.

Modelli Metodo di conversione Piattaforme compatibili Tipo di file
Gemma-3 1B Nessuna conversione richiesta Android, web .task
Gemma 2B, Gemma 7B, Gemma-2 2B Nessuna conversione richiesta Android, iOS, web .bin
Phi-2, StableLM, Falcon Script di conversione MediaPipe Android, iOS, web .bin
Tutti i modelli LLM di PyTorch AI Edge Torch Generative library Android, iOS .task

Per scoprire come convertire altri modelli, consulta la sezione Conversione del modello.

Invia il modello al dispositivo

Invia i contenuti della cartella output_path al dispositivo Android.

$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.task

Crea l'attività

L'API MediaPipe LLM Inference utilizza la funzione createFromOptions() per configurare il compito. La funzione createFromOptions() accetta valori per le opzioni di configurazione. Per ulteriori informazioni sulle opzioni di configurazione, consulta Opzioni di configurazione.

Il seguente codice inizializza l'attività utilizzando le opzioni di configurazione di base:

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPath("/data/local/.../")
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

Opzioni di configurazione

Utilizza le seguenti opzioni di configurazione per configurare un'app per Android:

Nome opzione Descrizione Intervallo di valori Valore predefinito
modelPath Il percorso in cui è archiviato il modello all'interno della directory del progetto. PERCORSO N/D
maxTokens Il numero massimo di token (token di input + token di output) gestiti dal modello. Numero intero 512
topK Il numero di token presi in considerazione dal modello in ogni fase di generazione. Limita le previsioni ai token più probabili tra i primi k. Numero intero 40
temperature L'entità della casualità introdotta durante la generazione. Una temperatura più alta consente di ottenere un testo generato più creativo, mentre una temperatura più bassa produce una generazione più prevedibile. Float 0,8
randomSeed Il seed casuale utilizzato durante la generazione del testo. Numero intero 0
loraPath Il percorso assoluto del modello LoRA localmente sul dispositivo. Nota: questa opzione è compatibile solo con i modelli GPU. PERCORSO N/D
resultListener Imposta l'ascoltatore dei risultati in modo che riceva i risultati in modo asincrono. Applicabile solo quando si utilizza il metodo di generazione asincrona. N/D N/D
errorListener Imposta un listener di errore facoltativo. N/D N/D

Preparazione dei dati

L'API LLM Inference accetta i seguenti input:

  • prompt (stringa): una domanda o un prompt.
val inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday."

Esegui l'attività

Utilizza il metodo generateResponse() per generare una risposta di testo al testo di input fornito nella sezione precedente (inputPrompt). Viene prodotta una singola risposta generata.

val result = llmInference.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")

Per trasmettere la risposta in streaming, utilizza il metodo generateResponseAsync().

val options = LlmInference.LlmInferenceOptions.builder()
  ...
  .setResultListener { partialResult, done ->
    logger.atInfo().log("partial result: $partialResult")
  }
  .build()

llmInference.generateResponseAsync(inputPrompt)

Gestire e visualizzare i risultati

L'API LLM Inference restituisce un LlmInferenceResult, che include il testo della risposta generata.

Here's a draft you can use:

Subject: Lunch on Saturday Reminder

Hi Brett,

Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.

Looking forward to it!

Best,
[Your Name]

Personalizzazione del modello LoRA

L'API di inferenza LLM di Mediapipe può essere configurata per supportare l'adattamento a basso ranking (LoRA) per i modelli linguistici di grandi dimensioni. Utilizzando modelli LoRA ottimizzati, gli sviluppatori possono personalizzare il comportamento degli LLM tramite un processo di addestramento conveniente.

Il supporto LoRA dell'API LLM Inference funziona per tutte le varianti di Gemma e per i modelli Phi-2 per il backend GPU, con i pesi LoRA applicabili solo ai livelli di attenzione. Questa implementazione iniziale funge da API sperimentale per sviluppi futuri, con piani per supportare più modelli e vari tipi di livelli nei prossimi aggiornamenti.

Prepara i modelli LoRA

Segui le istruzioni su HuggingFace per addestrare un modello LoRA ottimizzato sul tuo set di dati con i tipi di modelli supportati, Gemma o Phi-2. I modelli Gemma-2 2B, Gemma 2B e Phi-2 sono entrambi disponibili su HuggingFace nel formato safetensors. Poiché l'API LLM Inference supporta solo LoRA nei livelli di attenzione, specifica solo i livelli di attenzione durante la creazione di LoraConfig come segue:

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

Per i test, sono disponibili su HuggingFace modelli LoRA ottimizzati e accessibili pubblicamente che si adattano all'API di inferenza LLM. Ad esempio, monsterapi/gemma-2b-lora-maths-orca-200k per Gemma-2B e lole25/phi-2-sft-ultrachat-lora per Phi-2.

Dopo l'addestramento sul set di dati preparato e il salvataggio del modello, ottieni un file adapter_model.safetensors contenente i pesi del modello LoRA perfezionato. Il file safetensors è il checkpoint LoRA utilizzato nella conversione del modello.

Come passaggio successivo, devi convertire i pesi del modello in un Flatbuffer di TensorFlow Lite utilizzando il pacchetto Python MediaPipe. ConversionConfig deve specificare le opzioni del modello di base e altre opzioni LoRa. Tieni presente che, poiché l'API supporta l'inferenza LoRa solo con GPU, il backend deve essere impostato su 'gpu'.

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

Il convertitore produrrà due file flatbuffer TFLite, uno per il modello di base e l'altro per il modello LoRA.

Inferenza del modello LoRA

L'API di inferenza LLM per web, Android e iOS è stata aggiornata per supportare l'inferenza del modello LoRA.

Android supporta LoRa statico durante l'inizializzazione. Per caricare un modello LoRA, gli utenti devono specificare il percorso del modello LoRA e l'LLM di base.

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPath('<path to base model>')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .setLoraPath('<path to LoRA model>')
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

Per eseguire l'inferenza LLM con LoRA, utilizza gli stessi metodi generateResponse() o generateResponseAsync() del modello di base.