La API de LLM Inference te permite ejecutar modelos grandes de lenguaje (LLM) completamente en el dispositivo, que puedes usar para realizar una amplia variedad de tareas, como generar texto, recuperar información en forma de lenguaje natural y resumir documentos. La tarea proporciona compatibilidad integrada con varios modelos grandes de lenguaje de texto a texto, de modo que puedas aplicar los modelos de IA generativa más recientes en el dispositivo a tus apps y productos.
La tarea proporciona compatibilidad integrada para una variedad de LLM. Los modelos alojados en la página LiteRT Community están disponibles en un formato compatible con MediaPipe y no requieren ningún paso adicional de conversión o compilación.
Puedes usar AI Edge Torch para exportar modelos de PyTorch a modelos LiteRT (tflite
) de varias firmas, que se agrupan con parámetros de tokenizer para crear paquetes de tareas. Los modelos convertidos con AI Edge Torch son compatibles con la API de LLM Inference y pueden ejecutarse en el backend de la CPU, lo que los hace adecuados para aplicaciones para Android y iOS.
Comenzar
Para comenzar a usar esta tarea, sigue una de estas guías de implementación para tu plataforma de destino. En estas guías específicas de la plataforma, se explica una implementación básica de esta tarea, con ejemplos de código que usan un modelo disponible y las opciones de configuración recomendadas:
Web:
Android:
iOS
Detalles de la tarea
En esta sección, se describen las capacidades, las entradas, las salidas y las opciones de configuración de esta tarea.
Funciones
La API de LLM Inference contiene las siguientes funciones clave:
- Generación de texto a texto: Genera texto a partir de una instrucción de texto de entrada.
- Selección de LLM: Aplica varios modelos para adaptar la app a tus casos de uso específicos. También puedes volver a entrenar y aplicar pesos personalizados al modelo.
- Compatibilidad con LoRA: Extiende y personaliza la capacidad del LLM con el modelo LoRA entrenando todo tu conjunto de datos o tomando modelos LoRA precompilados preparados de la comunidad de código abierto (no es compatible con los modelos convertidos con la API generativa de AI Edge Torch).
Entradas de tareas | Resultados de las tareas |
---|---|
La API de LLM Inference acepta las siguientes entradas:
|
La API de LLM Inference genera los siguientes resultados:
|
Opciones de configuración
Esta tarea tiene las siguientes opciones de configuración:
Nombre de la opción | Descripción | Rango de valores | Valor predeterminado |
---|---|---|---|
modelPath |
Es la ruta de acceso a la ubicación en la que se almacena el modelo dentro del directorio del proyecto. | PATH | N/A |
maxTokens |
Es la cantidad máxima de tokens (tokens de entrada + tokens de salida) que controla el modelo. | Número entero | 512 |
topK |
Es la cantidad de tokens que considera el modelo en cada paso de generación. Limita las predicciones a los tokens más probables de Top-K. | Número entero | 40 |
temperature |
Es la cantidad de aleatoriedad que se introduce durante la generación. Una temperatura más alta genera más creatividad en el texto generado, mientras que una temperatura más baja produce una generación más predecible. | Número de punto flotante | 0.8 |
randomSeed |
Es la semilla aleatoria que se usa durante la generación de texto. | Número entero | 0 |
loraPath |
Es la ruta de acceso absoluta al modelo de LoRA de forma local en el dispositivo. Nota: Solo es compatible con modelos de GPU. | PATH | N/A |
resultListener |
Establece el objeto de escucha de resultados para que reciba los resultados de forma asíncrona. Solo se aplica cuando se usa el método de generación asíncrona. | N/A | N/A |
errorListener |
Establece un objeto de escucha de errores opcional. | N/A | N/A |
Modelos
La API de LLM Inference admite muchos modelos grandes de lenguaje de texto a texto, incluida la compatibilidad integrada con varios modelos que están optimizados para ejecutarse en navegadores y dispositivos móviles. Estos modelos ligeros se pueden usar para ejecutar inferencias completamente en el dispositivo.
Antes de inicializar la API de LLM Inference, descarga un modelo y almacena el archivo dentro del directorio de tu proyecto. Puedes usar un modelo convertido previamente del repositorio de HuggingFace de la LiteRT Community o convertir un modelo a un formato compatible con MediaPipe con el Conversor generativo de Torch de AI Edge.
Si aún no tienes un LLM para usar con la API de LLM Inference, comienza con uno de los siguientes modelos.
Gemma-3 1B
Gemma-3 1B es el modelo más reciente de la familia de modelos abiertos, ligeros y de vanguardia de Gemma, compilados a partir de la misma investigación y tecnología que se utilizaron para crear los modelos de Gemini. El modelo contiene 1,000 millones de parámetros y ponderaciones abiertas. La variante 1B es el modelo más liviano de la familia de Gemma, lo que lo hace ideal para muchos casos de uso en el dispositivo.
El modelo Gemma-3 1B de HuggingFace está disponible en el formato .task
y listo para usar con la API de LLM Inference para aplicaciones web y para Android.
Cuando ejecutes Gemma-3 1B con la API de LLM Inference, configura las siguientes opciones según corresponda:
preferredBackend
: Usa esta opción para elegir entre un backendCPU
oGPU
. Esta opción solo está disponible para Android.supportedLoraRanks
: La API de inferencia de LLM no se puede configurar para admitir la adaptación de bajo rango (LoRA) con el modelo Gemma-3 1B. No uses las opcionessupportedLoraRanks
niloraRanks
.maxTokens
: El valor demaxTokens
debe coincidir con el tamaño del contexto integrado en el modelo. Esto también se puede denominar caché de par clave-valor (KV) o longitud del contexto.numResponses
: Siempre debe ser 1. Esta opción solo está disponible para la Web.
Cuando se ejecuta Gemma-3 1B en aplicaciones web, la inicialización puede causar un bloqueo prolongado en el subproceso actual. Si es posible, siempre ejecuta el modelo desde un subproceso de trabajo.
Gemma-2 2B
Gemma-2 2B es una variante de 2B de Gemma-2 y funciona en todas las plataformas.
El modelo contiene 2,000 millones de parámetros y ponderaciones abiertas. Gemma-2 2B se conoce por sus habilidades de razonamiento de vanguardia para los modelos de su clase.
Conversión de modelos de PyTorch
Los modelos generativos de PyTorch se pueden convertir a un formato compatible con MediaPipe con la API de AI Edge Torch Generative. Puedes usar la API para convertir modelos de PyTorch en modelos LiteRT (TensorFlow Lite) con varias firmas. Para obtener más detalles sobre la asignación y exportación de modelos, visita la página de GitHub de AI Edge Torch.
La conversión de un modelo de PyTorch con la API generativa de AI Edge Torch implica los siguientes pasos:
- Descarga los puntos de control del modelo de PyTorch.
- Usa la API generativa de AI Edge Torch para crear, convertir y cuantificar el
modelo en un formato de archivo compatible con MediaPipe (
.tflite
). - Crea un paquete de tareas (
.task
) a partir del archivo tflite y el tokenizador del modelo.
El convertidor generativo de Torch solo realiza conversiones para CPU y requiere una máquina Linux con al menos 64 GB de RAM.
Para crear un paquete de tareas, usa la secuencia de comandos de empaquetado para crear un paquete de tareas. El proceso de agrupación empaqueta el modelo asignado con metadatos adicionales (p.ej., Parámetros del analizador) necesarios para ejecutar la inferencia de extremo a extremo.
El proceso de agrupación de modelos requiere el paquete PyPI de MediaPipe. La secuencia de comandos de conversión está disponible en todos los paquetes de MediaPipe después de 0.10.14
.
Instala y, luego, importa las dependencias con lo siguiente:
$ python3 -m pip install mediapipe
Usa la biblioteca genai.bundler
para empaquetar el modelo:
import mediapipe as mp
from mediapipe.tasks.python.genai import bundler
config = bundler.BundleConfig(
tflite_model=TFLITE_MODEL,
tokenizer_model=TOKENIZER_MODEL,
start_token=START_TOKEN,
stop_tokens=STOP_TOKENS,
output_filename=OUTPUT_FILENAME,
enable_bytes_to_unicode_mapping=ENABLE_BYTES_TO_UNICODE_MAPPING,
)
bundler.create_bundle(config)
Parámetro | Descripción | Valores aceptados |
---|---|---|
tflite_model |
Es la ruta de acceso al modelo de TFLite exportado de AI Edge. | PATH |
tokenizer_model |
Es la ruta de acceso al modelo del tokenizador SentencePiece. | PATH |
start_token |
Es un token de inicio específico del modelo. El token de inicio debe estar presente en el modelo de analizador proporcionado. | STRING |
stop_tokens |
Tokens de parada específicos del modelo Los tokens de parada deben estar presentes en el modelo de analizador proporcionado. | LIST[STRING] |
output_filename |
Es el nombre del archivo del paquete de tareas de salida. | PATH |
Personalización de LoRA
La API de inferencia de LLM de Mediapipe se puede configurar para admitir la adaptación de clasificación baja (LoRA) para modelos de lenguaje grande. Con modelos LoRA ajustados, los desarrolladores pueden personalizar el comportamiento de los LLM a través de un proceso de entrenamiento rentable.La compatibilidad con LoRA de la API de inferencia de LLM funciona para todas las variantes de Gemma y los modelos Phi-2 para el backend de GPU, con pesos de LoRA aplicables solo a las capas de atención. Esta implementación inicial sirve como una API experimental para futuros desarrollos con planes para admitir más modelos y varios tipos de capas en las próximas actualizaciones.
Prepara modelos de LoRA
Sigue las instrucciones en
HuggingFace
para entrenar un modelo LoRA ajustado en tu propio conjunto de datos con los tipos de modelos compatibles,
Gemma o Phi-2. Los modelos Gemma-2 2B, Gemma 2B y Phi-2 están disponibles en HuggingFace en el formato safetensors. Dado que la API de LLM Inference solo admite LoRA en las capas de atención, especifica solo las capas de atención cuando crees el LoraConfig
de la siguiente manera:
# For Gemma
from peft import LoraConfig
config = LoraConfig(
r=LORA_RANK,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)
# For Phi-2
config = LoraConfig(
r=LORA_RANK,
target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)
Para las pruebas, hay modelos LoRA ajustados de acceso público que se ajustan a la API de inferencia de LLM disponible en HuggingFace. Por ejemplo, monsterapi/gemma-2b-lora-maths-orca-200k para Gemma-2B y lole25/phi-2-sft-ultrachat-lora para Phi-2.
Después de entrenar el conjunto de datos preparado y guardar el modelo, obtienes un archivo adapter_model.safetensors
que contiene los pesos del modelo LoRA ajustados.
El archivo safetensors es el punto de control de LoRA que se usa en la conversión de modelos.
Como siguiente paso, debes convertir los pesos del modelo en un Flatbuffer de TensorFlow Lite con el paquete de Python de MediaPipe. ConversionConfig
debe especificar las opciones del modelo base, así como las opciones adicionales de LoRA. Ten en cuenta que, como la API solo admite la inferencia de LoRA con GPU, el backend debe establecerse en 'gpu'
.
import mediapipe as mp
from mediapipe.tasks.python.genai import converter
config = converter.ConversionConfig(
# Other params related to base model
...
# Must use gpu backend for LoRA conversion
backend='gpu',
# LoRA related params
lora_ckpt=LORA_CKPT,
lora_rank=LORA_RANK,
lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)
converter.convert_checkpoint(config)
El convertidor generará dos archivos FlatBuffer de TFLite, uno para el modelo base y el otro para el modelo LoRA.
Inferencia de modelos de LoRA
Se actualizaron las APIs de inferencia de LLM de la Web, Android y iOS para admitir la inferencia de modelos LoRA.
Android admite LoRA estático durante la inicialización. Para cargar un modelo de LoRA, los usuarios especifican la ruta de acceso del modelo de LoRA y el LLM base.// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
.setModelPath('<path to base model>')
.setMaxTokens(1000)
.setTopK(40)
.setTemperature(0.8)
.setRandomSeed(101)
.setLoraPath('<path to LoRA model>')
.build()
// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)
Para ejecutar la inferencia de LLM con LoRA, usa los mismos métodos generateResponse()
o generateResponseAsync()
que el modelo base.