Guida all'inferenza LLM per Android

L'API LLM Inference ti consente di eseguire modelli linguistici di grandi dimensioni (LLM) completamente on-device per le applicazioni Android, che puoi utilizzare per svolgere un'ampia gamma di attività, come la generazione di testo, il recupero di informazioni in forma di linguaggio naturale e il riepilogo dei documenti. L'attività fornisce il supporto integrato per più modelli linguistici di grandi dimensioni text-to-text, in modo da poter applicare i modelli di IA generativa on-device più recenti alle tue app per Android.

Per aggiungere rapidamente l'API LLM Inference alla tua applicazione Android, segui la guida rapida. Per un esempio di base di un'applicazione Android che esegue l'API LLM Inference, consulta l'applicazione di esempio. Per una conoscenza più approfondita del funzionamento dell'API LLM Inference, consulta le sezioni Opzioni di configurazione, Conversione del modello e Tuning LoRa.

Puoi vedere questa operazione in azione con la demo di MediaPipe Studio. Per saperne di più sulle funzionalità, sui modelli e sulle opzioni di configurazione di questa attività, consulta la Panoramica.

Guida rapida

Segui questa procedura per aggiungere l'API LLM Inference alla tua applicazione Android. L'API LLM Inference è ottimizzata per i dispositivi Android di fascia alta, come Pixel 8 e Samsung S23 o modelli successivi, e non supporta in modo affidabile gli emulatori di dispositivi.

Aggiungi dipendenze

L'API LLM Inference utilizza la libreria com.google.mediapipe:tasks-genai. Aggiungi questa dipendenza al file build.gradle della tua app per Android:

dependencies {
    implementation 'com.google.mediapipe:tasks-genai:0.10.24'
}

Scaricare un modello

Scarica Gemma-3 1B in un formato quantizzato a 4 bit da Hugging Face. Per ulteriori informazioni sui modelli disponibili, consulta la documentazione relativa ai modelli.

Invia i contenuti della cartella output_path al dispositivo Android.

$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.task

Inizializza l'attività

Inizializza l'attività con le opzioni di configurazione di base:

// Set the configuration options for the LLM Inference task
val taskOptions = LlmInferenceOptions.builder()
        .setModelPath('/data/local/tmp/llm/model_version.task')
        .setMaxTopK(64)
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, taskOptions)

Esegui l'attività

Utilizza il metodo generateResponse() per generare una risposta di testo. Viene generata una singola risposta.

val result = llmInference.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")

Per trasmettere la risposta in streaming, utilizza il metodo generateResponseAsync().

val options = LlmInference.LlmInferenceOptions.builder()
  ...
  .setResultListener { partialResult, done ->
    logger.atInfo().log("partial result: $partialResult")
  }
  .build()

llmInference.generateResponseAsync(inputPrompt)

Applicazione di esempio

Per vedere le API di inferenza LLM in azione ed esplorare una gamma completa di funzionalità di IA generativa on-device, dai un'occhiata all'app Galleria Edge di Google AI.

La galleria di Google AI Edge è un'applicazione Android open source che funge da area di sperimentazione interattiva per gli sviluppatori. Mette in evidenza:

  • Esempi pratici di utilizzo dell'API LLM Inference per varie attività, tra cui:
    • Chiedi Immagine: carica un'immagine e fai domande al riguardo. Ricevi descrizioni, risolvi problemi o identifica oggetti.
    • Prompt Lab: riassumere, riscrivere, generare codice o utilizzare prompt in formato libero per esplorare i casi d'uso degli LLM con un solo turno.
    • Chat AI: partecipa a conversazioni multi-turno.
  • La possibilità di scoprire, scaricare ed eseguire esperimenti su una serie di modelli ottimizzati per LiteRT della community Hugging Face LiteRT e delle release ufficiali di Google (ad es. Gemma 3N).
  • Benchmark delle prestazioni on-device in tempo reale per diversi modelli (tempo per il primo token, velocità di decodifica e così via).
  • Come importare e testare i tuoi modelli .task personalizzati.

Questa app è una risorsa per comprendere l'implementazione pratica dell'API di inferenza LLM e il potenziale dell'IA generativa on-device. Esplora il codice sorgente e scarica l'app dal repository GitHub della galleria di Google AI Edge.

Opzioni di configurazione

Utilizza le seguenti opzioni di configurazione per configurare un'app per Android:

Nome opzione Descrizione Intervallo di valori Valore predefinito
modelPath Il percorso in cui è archiviato il modello all'interno della directory del progetto. PERCORSO N/D
maxTokens Il numero massimo di token (token di input + token di output) gestiti dal modello. Numero intero 512
topK Il numero di token presi in considerazione dal modello in ogni fase di generazione. Limita le previsioni ai token più probabili tra i primi k. Numero intero 40
temperature La quantità di casualità introdotta durante la generazione. Una temperatura più alta consente di ottenere un testo generato più creativo, mentre una temperatura più bassa produce una generazione più prevedibile. Float 0,8
randomSeed Il seed casuale utilizzato durante la generazione del testo. Numero intero 0
loraPath Il percorso assoluto del modello LoRA localmente sul dispositivo. Nota: questa opzione è compatibile solo con i modelli GPU. PERCORSO N/D
resultListener Imposta l'ascoltatore dei risultati in modo che riceva i risultati in modo asincrono. Applicabile solo quando si utilizza il metodo di generazione asincrona. N/D N/D
errorListener Imposta un listener di errore facoltativo. N/D N/D

Prompt multimodale

Le API Android dell'API di inferenza LLM supportano i prompt multimodali con modelli che accettano input di testo e immagini. Con la multimodalità abilitata, gli utenti possono includere una combinazione di immagini e testo nei prompt e l'LLM fornisce una risposta di testo.

Per iniziare, utilizza una variante di Gemma 3n compatibile con MediaPipe:

Per ulteriori informazioni, consulta la documentazione di Gemma-3n.

Per fornire immagini all'interno di un prompt, converti le immagini o i frame di input in un oggetto com.google.mediapipe.framework.image.MPImage prima di passarlo all'API LLM Inference:

import com.google.mediapipe.framework.image.BitmapImageBuilder
import com.google.mediapipe.framework.image.MPImage

// Convert the input Bitmap object to an MPImage object to run inference
val mpImage = BitmapImageBuilder(image).build()

Per attivare il supporto della visione per l'API LLM Inference, imposta l'opzione di configurazione EnableVisionModality su true nelle opzioni del grafico:

LlmInferenceSession.LlmInferenceSessionOptions sessionOptions =
  LlmInferenceSession.LlmInferenceSessionOptions.builder()
    ...
    .setGraphOptions(GraphOptions.builder().setEnableVisionModality(true).build())
    .build();

Gemma-3n accetta un massimo di un'immagine per sessione, quindi imposta MaxNumImages su 1.

LlmInferenceOptions options = LlmInferenceOptions.builder()
  ...
  .setMaxNumImages(1)
  .build();

Di seguito è riportato un esempio di implementazione dell'API di inferenza LLM configurata per gestire input di visione e testo:

MPImage image = getImageFromAsset(BURGER_IMAGE);

LlmInferenceSession.LlmInferenceSessionOptions sessionOptions =
  LlmInferenceSession.LlmInferenceSessionOptions.builder()
    .setTopK(10)
    .setTemperature(0.4f)
    .setGraphOptions(GraphOptions.builder().setEnableVisionModality(true).build())
    .build();

try (LlmInference llmInference =
    LlmInference.createFromOptions(ApplicationProvider.getApplicationContext(), options);
  LlmInferenceSession session =
    LlmInferenceSession.createFromOptions(llmInference, sessionOptions)) {
  session.addQueryChunk("Describe the objects in the image.");
  session.addImage(image);
  String result = session.generateResponse();
}

Personalizzazione LoRA

L'API di inferenza LLM supporta l'ottimizzazione LoRA (Low-Rank Adaptation) utilizzando la libreria PEFT (Parameter-Efficient Fine-Tuning). La regolazione LoRA personalizza il comportamento degli LLM tramite un procedura di addestramento economica, creando un piccolo insieme di pesi addestrabili basati su nuovi dati di addestramento anziché addestrare nuovamente l'intero modello.

L'API LLM Inference supporta l'aggiunta di pesi LoRA ai livelli di attenzione dei modelli Gemma-2 2B, Gemma 2B e Phi-2. Scarica il modello nel formato safetensors.

Il modello di base deve essere nel formato safetensors per creare i pesi LoRA. Dopo l'addestramento LoRA, puoi convertire i modelli nel formato FlatBuffers per eseguirli su MediaPipe.

Prepara i pesi LoRA

Utilizza la guida Metodi LoRA di PEFT per addestrare un modello LoRA ottimizzato sul tuo set di dati.

L'API LLM Inference supporta solo LoRA nei livelli di attenzione, quindi specifica solo i livelli di attenzione in LoraConfig:

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

Dopo l'addestramento sul set di dati preparato e il salvataggio del modello, i pesi del modello LoRA ottimizzato sono disponibili in adapter_model.safetensors. Il file safetensors è il checkpoint LoRA utilizzato durante la conversione del modello.

Conversione del modello

Utilizza il pacchetto Python MediaPipe per convertire i pesi del modello nel formato Flatbuffer. ConversionConfig specifica le opzioni del modello di base insieme alle opzioni LoRa aggiuntive.

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_FILE,
)

converter.convert_checkpoint(config)

Il convertitore produrrà due file Flatbuffer, uno per il modello di base e un altro per il modello LoRA.

Inferenza del modello LoRA

Android supporta LoRa statico durante l'inizializzazione. Per caricare un modello LoRA, specifica il percorso del modello LoRA e l'LLM di base.

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPath(BASE_MODEL_PATH)
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .setLoraPath(LORA_MODEL_PATH)
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

Per eseguire l'inferenza LLM con LoRA, utilizza gli stessi metodi generateResponse() o generateResponseAsync() del modello di base.