Guida all'inferenza LLM per il web

L'API LLM Inference ti consente di eseguire modelli linguistici di grandi dimensioni (LLM) completamente nel browser per le applicazioni web, che puoi utilizzare per eseguire un'ampia gamma di attività, come la generazione di testo, il recupero di informazioni in forma di linguaggio naturale e il riassunto dei documenti. L'attività fornisce il supporto integrato per più modelli linguistici di grandi dimensioni text-to-text, in modo da poter applicare i modelli di AI generativa on-device più recenti alle tue app web.

L'attività supporta le seguenti varianti di Gemma: Gemma-2 2B, Gemma 2B e Gemma 7B. Gemma è una famiglia di modelli aperti leggeri e all'avanguardia creati sulla base della stessa ricerca e tecnologia utilizzata per creare i modelli Gemini. Supporta inoltre i seguenti modelli esterni: Phi-2, Falcon-RW-1B e StableLM-3B.

Puoi vedere questa operazione in azione con la demo di MediaPipe Studio. Per ulteriori informazioni sulle funzionalità, sui modelli e sulle opzioni di configurazione di questa attività, consulta la Panoramica.

Esempio di codice

L'applicazione di esempio per l'API LLM Inference fornisce un'implementazione di base di questa attività in JavaScript come riferimento. Puoi utilizzare questa app di esempio per iniziare a creare la tua app di generazione di testo.

Puoi accedere all'app di esempio dell'API LLM Inference su GitHub.

Configurazione

Questa sezione descrive i passaggi chiave per configurare l'ambiente di sviluppo e i progetti di codice in modo specifico per utilizzare l'API LLM Inference. Per informazioni generali sulla configurazione dell'ambiente di sviluppo per l'utilizzo di MediaPipe Tasks, inclusi i requisiti della versione della piattaforma, consulta la guida alla configurazione per il web.

Compatibilità del browser

L'API LLM Inference richiede un browser web compatibile con WebGPU. Per un elenco completo dei browser compatibili, vedi Compatibilità del browser GPU.

Pacchetti JavaScript

Il codice dell'API di inferenza LLM è disponibile tramite il package @mediapipe/tasks-genai. Puoi trovare e scaricare queste librerie dai link forniti nella guida alla configurazione della piattaforma.

Installa i pacchetti necessari per l'implementazione locale:

npm install @mediapipe/tasks-genai

Per eseguire il deployment su un server, utilizza un servizio CDN (Content Delivery Network) come jsDelivr per aggiungere il codice direttamente alla pagina HTML:

<head>
  <script src="https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/genai_bundle.cjs"
    crossorigin="anonymous"></script>
</head>

Modello

L'API MediaPipe LLM Inference richiede un modello addestrato compatibile con questa attività. Per le applicazioni web, il modello deve essere compatibile con le GPU.

Per ulteriori informazioni sui modelli addestrati disponibili per l'API LLM Inference, consulta la sezione Modelli della panoramica dell'attività.

Scaricare un modello

Prima di inizializzare l'API di inferenza LLM, scarica uno dei modelli supportati e memorizza il file nella directory del progetto:

  • Gemma-2 2B: la versione più recente della famiglia di modelli Gemma. Fa parte di una famiglia di modelli aperti leggeri e all'avanguardia creati sulla base della stessa ricerca e tecnologia utilizzata per creare i modelli Gemini.
  • Gemma 2B: appartiene a una famiglia di modelli aperti leggeri e all'avanguardia creati sulla base della stessa ricerca e tecnologia utilizzata per creare i modelli Gemini. Molto adatto per una serie di attività di generazione di testo, tra cui risposta a domande, riassunto e ragionamento.
  • Phi-2: modello Transformer con 2, 7 miliardi di parametri, più adatto per il formato di domande e risposte, chat e codice.
  • Falcon-RW-1B: modello causale solo decoder con 1 miliardo di parametri addestrato su 350 miliardi di token di RefinedWeb.
  • StableLM-3B: modello linguistico solo decodificatore con 3 miliardi di parametri preaddestrato su 1 triliardo di token di diversi set di dati di codice e in inglese.

Oltre ai modelli supportati, puoi utilizzare AI Edge Torch di Google per esportare i modelli PyTorch in modelli LiteRT (tflite) con più firme. Per ulteriori informazioni, consulta Torch Generative Converter per i modelli PyTorch.

Ti consigliamo di utilizzare Gemma-2 2B, disponibile su Kaggle Models. Per ulteriori informazioni sugli altri modelli disponibili, consulta la panoramica dell'attività sezione Modelli.

Converti il modello in formato MediaPipe

L'API LLM Inference è compatibile con due tipi di categorie di modelli, alcuni dei quali richiedono la conversione del modello. Utilizza la tabella per identificare il metodo di procedura richiesto per il tuo modello.

Modelli Metodo di conversione Piattaforme compatibili Tipo di file
Modelli supportati Gemma 2B, Gemma 7B, Gemma-2 2B, Phi-2, StableLM, Falcon MediaPipe Android, iOS, web .bin
Altri modelli PyTorch Tutti i modelli LLM di PyTorch Libreria Torch Generative di AI Edge Android, iOS .task

Stiamo ospitando i file .bin convertiti per Gemma 2B, Gemma 7B e Gemma-2 2B su Kaggle. Questi modelli possono essere implementati direttamente utilizzando la nostra API di inferenza LLM. Per scoprire come convertire altri modelli, consulta la sezione Conversione di modelli.

Aggiungere il modello alla directory del progetto

Archivia il modello nella directory del progetto:

<dev-project-root>/assets/gemma-2b-it-gpu-int4.bin

Specifica il percorso del modello con il parametro baseOptions dell'oggetto modelAssetPath:

baseOptions: { modelAssetPath: `/assets/gemma-2b-it-gpu-int4.bin`}

Crea l'attività

Utilizza una delle funzioni createFrom...() dell'API LLM Inference per preparare l'attività per l'esecuzione delle inferenze. Puoi utilizzare la funzione createFromModelPath() con un percorso relativo o assoluto al file del modello addestrato. L'esempio di codice utilizza la funzione createFromOptions(). Per ulteriori informazioni sulle opzioni di configurazione disponibili, consulta Opzioni di configurazione.

Il codice seguente mostra come creare e configurare questa attività:

const genai = await FilesetResolver.forGenAiTasks(
    // path/to/wasm/root
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm"
);
llmInference = await LlmInference.createFromOptions(genai, {
    baseOptions: {
        modelAssetPath: '/assets/gemma-2b-it-gpu-int4.bin'
    },
    maxTokens: 1000,
    topK: 40,
    temperature: 0.8,
    randomSeed: 101
});

Opzioni di configurazione

Questa attività offre le seguenti opzioni di configurazione per le app web e JavaScript:

Nome opzione Descrizione Intervallo di valori Valore predefinito
modelPath Il percorso in cui è archiviato il modello all'interno della directory del progetto. PERCORSO N/D
maxTokens Il numero massimo di token (token di input + token di output) gestiti dal modello. Numero intero 512
topK Il numero di token presi in considerazione dal modello in ogni fase di generazione. Limita le previsioni ai token più probabili tra i primi k. Numero intero 40
temperature La quantità di casualità introdotta durante la generazione. Una temperatura più alta consente di ottenere un testo generato più creativo, mentre una temperatura più bassa produce una generazione più prevedibile. Float 0,8
randomSeed Il seed casuale utilizzato durante la generazione del testo. Numero intero 0
loraRanks I ranking LoRA da utilizzare dai modelli LoRA durante l'esecuzione. Nota: questa opzione è compatibile solo con i modelli GPU. Array di numeri interi N/D

Preparazione dei dati

L'API LLM Inference accetta dati di testo (string). L'attività gestisce la preelaborazione dei dati di input, inclusa la tokenizzazione e la preelaborazione dei tensori.

Tutta la pre-elaborazione viene gestita all'interno della funzione generateResponse(). Non è necessaria un'ulteriore preelaborazione del testo inserito.

const inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday.";

Esegui l'attività

L'API di inferenza LLM utilizza la funzione generateResponse() per attivare le inferenze. Per la classificazione del testo, significa restituire le possibili categorie per il testo di input.

Il seguente codice mostra come eseguire l'elaborazione con il modello di attività.

const response = await llmInference.generateResponse(inputPrompt);
document.getElementById('output').textContent = response;

Per riprodurre in streaming la risposta, utilizza quanto segue:

llmInference.generateResponse(
  inputPrompt,
  (partialResult, done) => {
        document.getElementById('output').textContent += partialResult;
});

Gestire e visualizzare i risultati

L'API LLM Inference restituisce una stringa che include il testo della risposta generata.

Here's a draft you can use:

Subject: Lunch on Saturday Reminder

Hi Brett,

Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.

Looking forward to it!

Best,
[Your Name]

Personalizzazione del modello LoRA

L'API di inferenza LLM di Mediapipe può essere configurata per supportare l'adattamento a basso ranking (LoRA) per i modelli linguistici di grandi dimensioni. Utilizzando modelli LoRA ottimizzati, gli sviluppatori possono personalizzare il comportamento degli LLM tramite un processo di addestramento conveniente.

Il supporto LoRA dell'API LLM Inference funziona per tutte le varianti di Gemma e per i modelli Phi-2 per il backend GPU, con i pesi LoRA applicabili solo ai livelli di attenzione. Questa implementazione iniziale funge da API sperimentale per sviluppi futuri, con piani per supportare più modelli e vari tipi di livelli nei prossimi aggiornamenti.

Prepara i modelli LoRA

Segui le istruzioni su HuggingFace per addestrare un modello LoRA ottimizzato sul tuo set di dati con i tipi di modelli supportati, Gemma o Phi-2. I modelli Gemma-2 2B, Gemma 2B e Phi-2 sono entrambi disponibili su HuggingFace nel formato safetensors. Poiché l'API LLM Inference supporta solo LoRA nei livelli di attenzione, specifica solo i livelli di attenzione durante la creazione di LoraConfig come segue:

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

Per i test, sono disponibili modelli LoRA ottimizzati e accessibili al pubblico che si adattano all'API di inferenza LLM su HuggingFace. Ad esempio, monsterapi/gemma-2b-lora-maths-orca-200k per Gemma-2B e lole25/phi-2-sft-ultrachat-lora per Phi-2.

Dopo l'addestramento sul set di dati preparato e il salvataggio del modello, ottieni un file adapter_model.safetensors contenente i pesi del modello LoRA perfezionato. Il file safetensors è il checkpoint LoRA utilizzato nella conversione del modello.

Come passaggio successivo, devi convertire i pesi del modello in un Flatbuffer di TensorFlow Lite utilizzando il pacchetto Python MediaPipe. ConversionConfig deve specificare le opzioni del modello di base e altre opzioni LoRa. Tieni presente che, poiché l'API supporta l'inferenza LoRa solo con GPU, il backend deve essere impostato su 'gpu'.

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

Il convertitore produrrà due file flatbuffer TFLite, uno per il modello di base e l'altro per il modello LoRA.

Inferenza del modello LoRA

L'API di inferenza LLM per web, Android e iOS è stata aggiornata per supportare l'inferenza del modello LoRA.

Il web supporta LoRa dinamico durante l'esecuzione. In altre parole, gli utenti dichiarano i ranghi LoRA che verranno utilizzati durante l'inizializzazione e possono scambiare diversi modelli LoRA durante l'esecuzione.

const genai = await FilesetResolver.forGenAiTasks(
    // path/to/wasm/root
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm"
);
const llmInference = await LlmInference.createFromOptions(genai, {
    // options for the base model
    ...
    // LoRA ranks to be used by the LoRA models during runtime
    loraRanks: [4, 8, 16]
});

Durante l'esecuzione, dopo l'inizializzazione del modello di base, carica i modelli LoRa da utilizzare. Inoltre, attiva il modello LoRA passando il riferimento del modello LoRA durante la generazione della risposta LLM.

// Load several LoRA models. The returned LoRA model reference is used to specify
// which LoRA model to be used for inference.
loraModelRank4 = await llmInference.loadLoraModel(loraModelRank4Url);
loraModelRank8 = await llmInference.loadLoraModel(loraModelRank8Url);

// Specify LoRA model to be used during inference
llmInference.generateResponse(
  inputPrompt,
  loraModelRank4,
  (partialResult, done) => {
        document.getElementById('output').textContent += partialResult;
});