Guía de inferencia de LLM para la Web

La API de inferencia de LLM te permite ejecutar modelos grandes de lenguaje (LLM) por completo en el navegador para aplicaciones web, que puedes usar para realizar una amplia variedad de tareas, como generar texto, recuperar información en forma de lenguaje natural y resumir documentos. La tarea proporciona compatibilidad integrada con varios modelos grandes de lenguaje de texto a texto, de modo que puedas aplicar los modelos de IA generativa más recientes en el dispositivo a tus apps web.

La tarea admite las siguientes variantes de Gemma: Gemma-2 2B, Gemma 2B y Gemma 7B. Gemma es una familia de modelos abiertos, livianos y de vanguardia creados a partir de la misma investigación y tecnología que se utilizaron para crear los modelos de Gemini. También admite los siguientes modelos externos: Phi-2, Falcon-RW-1B y StableLM-3B.

Puedes ver esta tarea en acción con la demo de MediaPipe Studio. Para obtener más información sobre las funciones, los modelos y las opciones de configuración de esta tarea, consulta la descripción general.

Ejemplo de código

La aplicación de ejemplo para la API de inferencia de LLM proporciona una implementación básica de esta tarea en JavaScript como referencia. Puedes usar esta app de ejemplo para comenzar a compilar tu propia app de generación de texto.

Puedes acceder a la app de ejemplo de la API de LLM Inference en GitHub.

Configuración

En esta sección, se describen los pasos clave para configurar tu entorno de desarrollo y proyectos de código específicamente para usar la API de LLM Inference. Si deseas obtener información general sobre cómo configurar tu entorno de desarrollo para usar MediaPipe Tasks, incluidos los requisitos de versión de la plataforma, consulta la Guía de configuración para la Web.

Compatibilidad del navegador

La API de LLM Inference requiere un navegador web con compatibilidad con WebGPU. Para obtener una lista completa de los navegadores compatibles, consulta Compatibilidad de navegadores con GPUs.

Paquetes de JavaScript

El código de la API de inferencia de LLM está disponible a través del paquete @mediapipe/tasks-genai. Puedes encontrar y descargar estas bibliotecas desde los vínculos que se proporcionan en la guía de configuración de la plataforma.

Instala los paquetes necesarios para la etapa de pruebas local:

npm install @mediapipe/tasks-genai

Para implementarlo en un servidor, usa un servicio de red de distribución de contenidos (CDN) como jsDelivr para agregar código directamente a tu página HTML:

<head>
  <script src="https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/genai_bundle.cjs"
    crossorigin="anonymous"></script>
</head>

Modelo

La API de MediaPipe LLM Inference requiere un modelo entrenado que sea compatible con esta tarea. En el caso de las aplicaciones web, el modelo debe ser compatible con la GPU.

Para obtener más información sobre los modelos entrenados disponibles para la API de LLM Inference, consulta la sección Modelos de la descripción general de la tarea.

Descarga un modelo

Antes de inicializar la API de LLM Inference, descarga uno de los modelos compatibles y almacena el archivo en el directorio de tu proyecto:

  • Gemma-2 2B: Esta es la versión más reciente de la familia de modelos Gemma. Forma parte de una familia de modelos abiertos ligeros y de vanguardia creados a partir de la misma investigación y tecnología que se usaron para crear los modelos de Gemini.
  • Gemma 2B: Forma parte de una familia de modelos abiertos, ligeros y de vanguardia creados a partir de la misma investigación y tecnología que se usaron para crear los modelos de Gemini. Es adecuado para una variedad de tareas de generación de texto, como la respuesta a preguntas, el resumen y el razonamiento.
  • Phi-2: Es un modelo Transformer de 2, 700 millones de parámetros, que es más adecuado para el formato de pregunta y respuesta, chat y código.
  • Falcon-RW-1B: Es un modelo de solo decodificador causal de 1,000 millones de parámetros entrenado en 350,000 millones de tokens de RefinedWeb.
  • StableLM-3B: Es un modelo de lenguaje de solo decodificador de 3,000 millones de parámetros que se entrenó previamente en 1 billón de tokens de diversos conjuntos de datos de código y en inglés.

Además de los modelos compatibles, puedes usar AI Edge Torch de Google para exportar modelos de PyTorch a modelos LiteRT (tflite) de varias firmas. Para obtener más información, consulta Convertidor generativo de Torch para modelos de PyTorch.

Te recomendamos que uses Gemma-2 2B, que está disponible en Kaggle Models. Para obtener más información sobre los otros modelos disponibles, consulta la sección Modelos de la descripción general de la tarea.

Convierte el modelo al formato MediaPipe

La API de inferencia de LLM es compatible con dos tipos de categorías de modelos, algunos de los cuales requieren conversión de modelos. Usa la tabla para identificar el método de pasos necesarios para tu modelo.

Modelos Método de conversión Plataformas compatibles File type
Modelos compatibles Gemma 2B, Gemma 7B, Gemma-2 2B, Phi-2, StableLM y Falcon MediaPipe Android, iOS y Web .bin
Otros modelos de PyTorch Todos los modelos de LLM de PyTorch Biblioteca generativa de AI Edge Torch Android, iOS .task

Almacenamos los archivos .bin convertidos para Gemma 2B, Gemma 7B y Gemma-2 2B en Kaggle. Estos modelos se pueden implementar directamente con nuestra API de inferencia de LLM. Para obtener información sobre cómo convertir otros modelos, consulta la sección Conversión de modelos.

Agrega el modelo al directorio del proyecto

Almacena el modelo en el directorio de tu proyecto:

<dev-project-root>/assets/gemma-2b-it-gpu-int4.bin

Especifica la ruta del modelo con el parámetro modelAssetPath del objeto baseOptions:

baseOptions: { modelAssetPath: `/assets/gemma-2b-it-gpu-int4.bin`}

Crea la tarea

Usa una de las funciones createFrom...() de la API de LLM Inference para preparar la tarea para ejecutar inferencias. Puedes usar la función createFromModelPath() con una ruta de acceso relativa o absoluta al archivo del modelo entrenado. En el ejemplo de código, se usa la función createFromOptions(). Para obtener más información sobre las opciones de configuración disponibles, consulta Opciones de configuración.

En el siguiente código, se muestra cómo compilar y configurar esta tarea:

const genai = await FilesetResolver.forGenAiTasks(
    // path/to/wasm/root
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm"
);
llmInference = await LlmInference.createFromOptions(genai, {
    baseOptions: {
        modelAssetPath: '/assets/gemma-2b-it-gpu-int4.bin'
    },
    maxTokens: 1000,
    topK: 40,
    temperature: 0.8,
    randomSeed: 101
});

Opciones de configuración

Esta tarea tiene las siguientes opciones de configuración para apps web y de JavaScript:

Nombre de la opción Descripción Rango de valores Valor predeterminado
modelPath Es la ruta de acceso a la ubicación en la que se almacena el modelo dentro del directorio del proyecto. PATH N/A
maxTokens Es la cantidad máxima de tokens (tokens de entrada + tokens de salida) que controla el modelo. Número entero 512
topK Es la cantidad de tokens que considera el modelo en cada paso de generación. Limita las predicciones a los tokens más probables de Top-K. Número entero 40
temperature Es la cantidad de aleatoriedad que se introduce durante la generación. Una temperatura más alta genera más creatividad en el texto generado, mientras que una temperatura más baja produce una generación más predecible. Número de punto flotante 0.8
randomSeed Es la semilla aleatoria que se usa durante la generación de texto. Número entero 0
loraRanks Son las clasificaciones de LoRA que usarán los modelos de LoRA durante el tiempo de ejecución. Nota: Esta opción solo es compatible con modelos de GPU. Array de números enteros N/A

Preparar los datos

La API de LLM Inference acepta datos de texto (string). La tarea controla el procesamiento previo de la entrada de datos, incluida la tokenización y el procesamiento previo de tensores.

Todo el procesamiento previo se controla dentro de la función generateResponse(). No es necesario realizar un procesamiento previo adicional del texto de entrada.

const inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday.";

Ejecuta la tarea

La API de LLM Inference usa la función generateResponse() para activar inferencias. En el caso de la clasificación de texto, esto significa mostrar las categorías posibles para el texto de entrada.

En el siguiente código, se muestra cómo ejecutar el procesamiento con el modelo de tareas.

const response = await llmInference.generateResponse(inputPrompt);
document.getElementById('output').textContent = response;

Para transmitir la respuesta, usa lo siguiente:

llmInference.generateResponse(
  inputPrompt,
  (partialResult, done) => {
        document.getElementById('output').textContent += partialResult;
});

Cómo controlar y mostrar los resultados

La API de LLM Inference devuelve una cadena, que incluye el texto de la respuesta generada.

Here's a draft you can use:

Subject: Lunch on Saturday Reminder

Hi Brett,

Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.

Looking forward to it!

Best,
[Your Name]

Personalización de modelos LoRA

La API de inferencia de LLM de Mediapipe se puede configurar para admitir la adaptación de clasificación baja (LoRA) para modelos de lenguaje grande. Con modelos LoRA ajustados, los desarrolladores pueden personalizar el comportamiento de los LLM a través de un proceso de entrenamiento rentable.

La compatibilidad con LoRA de la API de LLM Inference funciona para todas las variantes de Gemma y los modelos Phi-2 para el backend de GPU, con pesos de LoRA aplicables solo a las capas de atención. Esta implementación inicial funciona como una API experimental para futuros desarrollos con planes para admitir más modelos y varios tipos de capas en las próximas actualizaciones.

Prepara modelos de LoRA

Sigue las instrucciones en HuggingFace para entrenar un modelo LoRA ajustado en tu propio conjunto de datos con los tipos de modelos compatibles, Gemma o Phi-2. Los modelos Gemma-2 2B, Gemma 2B y Phi-2 están disponibles en HuggingFace en el formato safetensors. Dado que la API de LLM Inference solo admite LoRA en capas de atención, especifica solo capas de atención cuando crees el LoraConfig de la siguiente manera:

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

Para las pruebas, hay modelos LoRA ajustados de acceso público que se ajustan a la API de inferencia de LLM disponible en HuggingFace. Por ejemplo, monsterapi/gemma-2b-lora-maths-orca-200k para Gemma-2B y lole25/phi-2-sft-ultrachat-lora para Phi-2.

Después de entrenar el conjunto de datos preparado y guardar el modelo, obtienes un archivo adapter_model.safetensors que contiene los pesos del modelo LoRA ajustados. El archivo safetensors es el punto de control de LoRA que se usa en la conversión de modelos.

Como siguiente paso, debes convertir los pesos del modelo en un Flatbuffer de TensorFlow Lite con el paquete de Python de MediaPipe. ConversionConfig debe especificar las opciones del modelo base, así como las opciones adicionales de LoRA. Ten en cuenta que, como la API solo admite la inferencia de LoRA con GPU, el backend debe establecerse en 'gpu'.

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

El convertidor generará dos archivos FlatBuffer de TFLite, uno para el modelo base y el otro para el modelo LoRA.

Inferencia de modelos de LoRA

Se actualizaron las APIs de inferencia de LLM de la Web, Android y iOS para admitir la inferencia de modelos LoRA.

La Web admite LoRA dinámico durante el tiempo de ejecución. Es decir, los usuarios declaran las clasificaciones de LoRA que se usarán durante la inicialización y pueden intercambiar diferentes modelos de LoRA durante el tiempo de ejecución.

const genai = await FilesetResolver.forGenAiTasks(
    // path/to/wasm/root
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm"
);
const llmInference = await LlmInference.createFromOptions(genai, {
    // options for the base model
    ...
    // LoRA ranks to be used by the LoRA models during runtime
    loraRanks: [4, 8, 16]
});

Durante el tiempo de ejecución, después de inicializar el modelo base, carga los modelos de LoRA que se usarán. Además, pasa la referencia del modelo LoRA mientras generas la respuesta de LLM para activarlo.

// Load several LoRA models. The returned LoRA model reference is used to specify
// which LoRA model to be used for inference.
loraModelRank4 = await llmInference.loadLoraModel(loraModelRank4Url);
loraModelRank8 = await llmInference.loadLoraModel(loraModelRank8Url);

// Specify LoRA model to be used during inference
llmInference.generateResponse(
  inputPrompt,
  loraModelRank4,
  (partialResult, done) => {
        document.getElementById('output').textContent += partialResult;
});