Guida all'inferenza LLM per il web

L'API LLM Inference consente di eseguire modelli linguistici di grandi dimensioni (LLM) completamente on-device per le applicazioni web, che puoi utilizzare per svolgere un'ampia gamma di attività, come la generazione di testo, il recupero di informazioni in forma di linguaggio naturale e il riepilogo dei documenti. L'attività fornisce il supporto integrato per più modelli linguistici di grandi dimensioni text-to-text, in modo da poter applicare i modelli di AI generativa on-device più recenti alle tue app web.

Per aggiungere rapidamente l'API LLM Inference alla tua applicazione web, segui la guida rapida. Per un esempio di base di un'applicazione web che esegue l'API LLM Inference, consulta l'applicazione di esempio. Per una conoscenza più approfondita del funzionamento dell'API LLM Inference, consulta le sezioni Opzioni di configurazione, Conversione del modello e Tuning LoRa.

Puoi vedere questa operazione in azione con la demo di MediaPipe Studio. Per saperne di più sulle funzionalità, sui modelli e sulle opzioni di configurazione di questa attività, consulta la Panoramica.

Guida rapida

Segui questa procedura per aggiungere l'API LLM Inference alla tua applicazione web. L'API LLM Inference richiede un browser web con compatibilità WebGPU. Per un elenco completo di browser compatibili, consulta Compatibilità del browser con GPU.

Aggiungi dipendenze

L'API LLM Inference utilizza il pacchetto @mediapipe/tasks-genai.

Installa i pacchetti richiesti per l'implementazione locale:

npm install @mediapipe/tasks-genai

Per eseguire il deployment su un server, utilizza un servizio CDN (Content Delivery Network) come jsDelivr per aggiungere il codice direttamente alla pagina HTML:

<head>
  <script src="https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/genai_bundle.cjs"
    crossorigin="anonymous"></script>
</head>

Scaricare un modello

Scarica Gemma-2 2B in un formato quantizzato a 8 bit da Kaggle Models. Per ulteriori informazioni sui modelli disponibili, consulta la documentazione relativa ai modelli.

Archivia il modello nella directory del progetto:

<dev-project-root>/assets/gemma-2b-it-gpu-int8.bin

Specifica il percorso del modello con il parametro baseOptions dell'oggetto modelAssetPath:

baseOptions: { modelAssetPath: `/assets/gemma-2b-it-gpu-int8.bin`}

Inizializza l'attività

Inizializza l'attività con le opzioni di configurazione di base:

const genai = await FilesetResolver.forGenAiTasks(
    // path/to/wasm/root
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm"
);
llmInference = await LlmInference.createFromOptions(genai, {
    baseOptions: {
        modelAssetPath: '/assets/gemma-2b-it-gpu-int8.bin'
    },
    maxTokens: 1000,
    topK: 40,
    temperature: 0.8,
    randomSeed: 101
});

Esegui l'attività

Utilizza la funzione generateResponse() per attivare le inferenze.

const response = await llmInference.generateResponse(inputPrompt);
document.getElementById('output').textContent = response;

Per riprodurre in streaming la risposta, utilizza quanto segue:

llmInference.generateResponse(
  inputPrompt,
  (partialResult, done) => {
        document.getElementById('output').textContent += partialResult;
});

Applicazione di esempio

L'applicazione di esempio è un esempio di app di generazione di testo di base per il web che utilizza l'API LLM Inference. Puoi utilizzare l'app come punto di partenza per la tua app web o farvi riferimento quando modifichi un'app esistente. Il codice di esempio è ospitato su GitHub.

Clona il repository git utilizzando il seguente comando:

git clone https://github.com/google-ai-edge/mediapipe-samples

Per ulteriori informazioni, consulta la guida alla configurazione per il web.

Opzioni di configurazione

Utilizza le seguenti opzioni di configurazione per configurare un'app web:

Nome opzione Descrizione Intervallo di valori Valore predefinito
modelPath Il percorso in cui è archiviato il modello all'interno della directory del progetto. PERCORSO N/D
maxTokens Il numero massimo di token (token di input + token di output) gestiti dal modello. Numero intero 512
topK Il numero di token presi in considerazione dal modello in ogni fase di generazione. Limita le previsioni ai token più probabili tra i primi k. Numero intero 40
temperature La quantità di casualità introdotta durante la generazione. Una temperatura più alta consente di ottenere un testo generato più creativo, mentre una temperatura più bassa produce una generazione più prevedibile. Float 0,8
randomSeed Il seed casuale utilizzato durante la generazione del testo. Numero intero 0
loraRanks I ranking LoRA da utilizzare dai modelli LoRA durante l'esecuzione. Nota: questa opzione è compatibile solo con i modelli GPU. Array di numeri interi N/D

Conversione del modello

L'API LLM Inference è compatibile con i seguenti tipi di modelli, alcuni dei quali richiedono la conversione del modello. Utilizza la tabella per identificare il metodo di procedura richiesto per il tuo modello.

Modelli Metodo di conversione Piattaforme compatibili Tipo di file
Gemma-3 1B Nessuna conversione richiesta Android, web .task
Gemma 2B, Gemma 7B, Gemma-2 2B Nessuna conversione richiesta Android, iOS, web .bin
Phi-2, StableLM, Falcon Script di conversione MediaPipe Android, iOS, web .bin
Tutti i modelli LLM di PyTorch AI Edge Torch Generative library Android, iOS .task

Per scoprire come convertire altri modelli, consulta la sezione Conversione del modello.

Personalizzazione LoRA

L'API Inference LLM supporta l'ottimizzazione LoRA (Low-Rank Adaptation) utilizzando la libreria PEFT (Parameter-Efficient Fine-Tuning). La regolazione LoRA personalizza il comportamento degli LLM tramite un procedura di addestramento economica, creando un piccolo insieme di pesi addestrabili basati su nuovi dati di addestramento anziché addestrare nuovamente l'intero modello.

L'API LLM Inference supporta l'aggiunta di pesi LoRA ai livelli di attenzione dei modelli Gemma-2 2B, Gemma 2B e Phi-2. Scarica il modello nel formato safetensors.

Il modello di base deve essere in formato safetensors per creare i pesi LoRA. Dopo l'addestramento LoRA, puoi convertire i modelli nel formato FlatBuffers per eseguirli su MediaPipe.

Prepara i pesi LoRA

Utilizza la guida Metodi LoRA di PEFT per addestrare un modello LoRA ottimizzato sul tuo set di dati.

L'API LLM Inference supporta LoRA solo nei livelli di attenzione, quindi specifica solo i livelli di attenzione in LoraConfig:

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

Dopo l'addestramento sul set di dati preparato e il salvataggio del modello, i pesi del modello LoRA ottimizzato sono disponibili in adapter_model.safetensors. Il file safetensors è il checkpoint LoRA utilizzato durante la conversione del modello.

Conversione del modello

Utilizza il pacchetto Python MediaPipe per convertire i pesi del modello nel formato Flatbuffer. ConversionConfig specifica le opzioni del modello di base insieme alle opzioni LoRa aggiuntive.

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_FILE,
)

converter.convert_checkpoint(config)

Il convertitore produrrà due file compatibili con MediaPipe, uno per il modello di base e un altro per il modello LoRA.

Inferenza del modello LoRA

Il web supporta LoRA dinamico durante l'esecuzione, il che significa che gli utenti dichiarano i ranghi LoRA durante l'inizializzazione. Ciò significa che puoi sostituire diversi modelli LoRa durante il runtime.

const genai = await FilesetResolver.forGenAiTasks(
    // path/to/wasm/root
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm"
);
const llmInference = await LlmInference.createFromOptions(genai, {
    // options for the base model
    ...
    // LoRA ranks to be used by the LoRA models during runtime
    loraRanks: [4, 8, 16]
});

Carica i modelli LoRa durante l'esecuzione, dopo aver inizializzato il modello di base. Attiva il modello LoRA passando il riferimento del modello durante la generazione della risposta LLM.

// Load several LoRA models. The returned LoRA model reference is used to specify
// which LoRA model to be used for inference.
loraModelRank4 = await llmInference.loadLoraModel(loraModelRank4Url);
loraModelRank8 = await llmInference.loadLoraModel(loraModelRank8Url);

// Specify LoRA model to be used during inference
llmInference.generateResponse(
  inputPrompt,
  loraModelRank4,
  (partialResult, done) => {
        document.getElementById('output').textContent += partialResult;
});