La API de LLM Inference te permite ejecutar modelos de lenguaje grandes (LLM) completamente en el dispositivo para aplicaciones para Android, que puedes usar para realizar una amplia variedad de tareas, como generar texto, recuperar información en forma de lenguaje natural y resumir documentos. La tarea proporciona compatibilidad integrada con varios modelos de lenguaje grande de texto a texto, de modo que puedas aplicar los modelos de IA generativa más recientes en el dispositivo a tus apps para Android.
La tarea admite las siguientes variantes de Gemma: Gemma-2 2B, Gemma 2B y Gemma 7B. Gemma es una familia de modelos abiertos, livianos y de vanguardia creados a partir de la misma investigación y tecnología que se utilizaron para crear los modelos de Gemini. También admite los siguientes modelos externos: Phi-2, Falcon-RW-1B y StableLM-3B.
Además de los modelos admitidos, los usuarios pueden usar AI Edge Torch de Google para exportar modelos de PyTorch a modelos LiteRT (tflite
) de varias firmas, que se agrupan con parámetros de tokenizador para crear paquetes de tareas compatibles con la API de inferencia de LLM.
Puedes ver esta tarea en acción con la demo de MediaPipe Studio. Para obtener más información sobre las funciones, los modelos y las opciones de configuración de esta tarea, consulta la descripción general.
Ejemplo de código
En esta guía, se hace referencia a un ejemplo de una app básica de generación de texto para Android. Puedes usar la app como punto de partida para tu propia app para Android o consultarla cuando modifiques una app existente. El código de ejemplo se aloja en GitHub.
Descarga el código
En las siguientes instrucciones, se muestra cómo crear una copia local del código de ejemplo con la herramienta de línea de comandos git.
Para descargar el código de ejemplo, sigue estos pasos:
- Clona el repositorio de git con el siguiente comando:
git clone https://github.com/google-ai-edge/mediapipe-samples
- De manera opcional, configura tu instancia de git para usar el control de revisión disperso, de modo que solo tengas los archivos de la app de ejemplo de la API de LLM Inference:
cd mediapipe git sparse-checkout init --cone git sparse-checkout set examples/llm_inference/android
Después de crear una versión local del código de ejemplo, puedes importar el proyecto a Android Studio y ejecutar la app. Para obtener instrucciones, consulta la Guía de configuración para Android.
Configuración
En esta sección, se describen los pasos clave para configurar tu entorno de desarrollo y proyectos de código específicamente para usar la API de LLM Inference. Si deseas obtener información general sobre cómo configurar tu entorno de desarrollo para usar tareas de MediaPipe, incluidos los requisitos de versión de la plataforma, consulta la Guía de configuración para Android.
Dependencias
La API de inferencia de LLM usa la biblioteca com.google.mediapipe:tasks-genai
. Agrega esta dependencia al archivo build.gradle
de tu app para Android:
dependencies {
implementation 'com.google.mediapipe:tasks-genai:0.10.14'
}
En el caso de los dispositivos con Android 12 (nivel de API 31) o versiones posteriores, agrega la dependencia de la biblioteca nativa de OpenCL. Para obtener más información, consulta la documentación sobre la etiqueta uses-native-library
.
Agrega las siguientes etiquetas uses-native-library
al archivo AndroidManifest.xml
:
<uses-native-library android:name="libOpenCL.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-car.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-pixel.so" android:required="false"/>
Modelo
La API de inferencia de LLM de MediaPipe requiere un modelo de lenguaje de texto a texto entrenado que sea compatible con esta tarea. Después de descargar un modelo, instala las dependencias requeridas y envía el modelo al dispositivo Android. Si usas un modelo que no sea Gemma, deberás convertirlo a un formato compatible con MediaPipe.
Para obtener más información sobre los modelos entrenados disponibles para la API de LLM Inference, consulta la sección Modelos de la descripción general de la tarea.
Descarga un modelo
Antes de inicializar la API de LLM Inference, descarga uno de los modelos compatibles y almacena el archivo en el directorio de tu proyecto:
- Gemma-2 2B: Esta es la versión más reciente de la familia de modelos Gemma. Forma parte de una familia de modelos abiertos ligeros y de vanguardia creados a partir de la misma investigación y tecnología que se usaron para crear los modelos de Gemini.
- Gemma 2B: Forma parte de una familia de modelos abiertos, ligeros y de vanguardia creados a partir de la misma investigación y tecnología que se usaron para crear los modelos de Gemini. Es adecuado para una variedad de tareas de generación de texto, como la respuesta a preguntas, el resumen y el razonamiento.
- Phi-2: Es un modelo Transformer de 2, 700 millones de parámetros, que es más adecuado para el formato de pregunta y respuesta, chat y código.
- Falcon-RW-1B: Es un modelo de solo decodificador causal de 1,000 millones de parámetros entrenado en 350,000 millones de tokens de RefinedWeb.
- StableLM-3B: Es un modelo de lenguaje de solo decodificador de 3,000 millones de parámetros que se entrenó previamente en 1 billón de tokens de diversos conjuntos de datos de código y en inglés.
Además de los modelos compatibles, puedes usar AI Edge Torch de Google para exportar modelos de PyTorch a modelos LiteRT (tflite
) de varias firmas. Para obtener más información, consulta
Convertidor generativo de Torch para modelos de PyTorch.
Te recomendamos que uses Gemma-2 2B, que está disponible en Kaggle Models. Para obtener más información sobre los otros modelos disponibles, consulta la sección Modelos de la descripción general de la tarea.
Convierte el modelo al formato MediaPipe
La API de inferencia de LLM es compatible con dos tipos de categorías de modelos, algunos de los cuales requieren conversión de modelos. Usa la tabla para identificar el método de pasos necesarios para tu modelo.
Modelos | Método de conversión | Plataformas compatibles | File type | |
---|---|---|---|---|
Modelos compatibles | Gemma 2B, Gemma 7B, Gemma-2 2B, Phi-2, StableLM y Falcon | MediaPipe | Android, iOS y Web | .bin |
Otros modelos de PyTorch | Todos los modelos de LLM de PyTorch | Biblioteca generativa de AI Edge Torch | Android, iOS | .task |
Almacenamos los archivos .bin
convertidos para Gemma 2B, Gemma 7B y Gemma-2 2B
en Kaggle. Estos modelos se pueden implementar directamente con nuestra API de inferencia de LLM. Para obtener información sobre cómo convertir otros modelos, consulta la sección Conversión de modelos.
Envía el modelo al dispositivo
Envía el contenido de la carpeta output_path al dispositivo Android.
$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.bin
Crea la tarea
La API de MediaPipe LLM Inference usa la función createFromOptions()
para configurar la tarea. La función createFromOptions()
acepta valores para las opciones de configuración. Para obtener más información sobre las opciones de configuración, consulta Opciones de configuración.
El siguiente código inicializa la tarea con opciones de configuración básicas:
// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
.setModelPATH('/data/local/.../')
.setMaxTokens(1000)
.setTopK(40)
.setTemperature(0.8)
.setRandomSeed(101)
.build()
// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)
Opciones de configuración
Usa las siguientes opciones de configuración para configurar una app para Android:
Nombre de la opción | Descripción | Rango de valores | Valor predeterminado |
---|---|---|---|
modelPath |
Es la ruta de acceso a la ubicación en la que se almacena el modelo dentro del directorio del proyecto. | PATH | N/A |
maxTokens |
Es la cantidad máxima de tokens (tokens de entrada + tokens de salida) que controla el modelo. | Número entero | 512 |
topK |
Es la cantidad de tokens que considera el modelo en cada paso de generación. Limita las predicciones a los tokens más probables de Top-K. | Número entero | 40 |
temperature |
Es la cantidad de aleatoriedad que se introduce durante la generación. Una temperatura más alta genera más creatividad en el texto generado, mientras que una temperatura más baja produce una generación más predecible. | Número de punto flotante | 0.8 |
randomSeed |
Es la semilla aleatoria que se usa durante la generación de texto. | Número entero | 0 |
loraPath |
Es la ruta de acceso absoluta al modelo de LoRA de forma local en el dispositivo. Nota: Esta opción solo es compatible con modelos de GPU. | PATH | N/A |
resultListener |
Establece el objeto de escucha de resultados para que reciba los resultados de forma asíncrona. Solo se aplica cuando se usa el método de generación asíncrona. | N/A | N/A |
errorListener |
Establece un objeto de escucha de errores opcional. | N/A | N/A |
Preparar los datos
La API de LLM Inference acepta las siguientes entradas:
- prompt (cadena): Es una pregunta o instrucción.
val inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday."
Ejecuta la tarea
Usa el método generateResponse()
para generar una respuesta de texto a la entrada de texto proporcionada en la sección anterior (inputPrompt
). Esto produce una sola respuesta generada.
val result = llmInference.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")
Para transmitir la respuesta, usa el método generateResponseAsync()
.
val options = LlmInference.LlmInferenceOptions.builder()
...
.setResultListener { partialResult, done ->
logger.atInfo().log("partial result: $partialResult")
}
.build()
llmInference.generateResponseAsync(inputPrompt)
Cómo controlar y mostrar los resultados
La API de LLM Inference muestra un LlmInferenceResult
, que incluye el texto de la respuesta generada.
Here's a draft you can use:
Subject: Lunch on Saturday Reminder
Hi Brett,
Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.
Looking forward to it!
Best,
[Your Name]
Personalización de modelos LoRA
La API de inferencia de LLM de Mediapipe se puede configurar para admitir la adaptación de clasificación baja (LoRA) para modelos de lenguaje grande. Con modelos LoRA ajustados, los desarrolladores pueden personalizar el comportamiento de los LLM a través de un proceso de entrenamiento rentable.
La compatibilidad con LoRA de la API de LLM Inference funciona para todas las variantes de Gemma y los modelos Phi-2 para el backend de GPU, con pesos de LoRA aplicables solo a las capas de atención. Esta implementación inicial funciona como una API experimental para futuros desarrollos con planes para admitir más modelos y varios tipos de capas en las próximas actualizaciones.
Prepara modelos de LoRA
Sigue las instrucciones en
HuggingFace
para entrenar un modelo LoRA ajustado en tu propio conjunto de datos con los tipos de modelos compatibles,
Gemma o Phi-2. Los modelos Gemma-2 2B, Gemma 2B y Phi-2 están disponibles en HuggingFace en el formato safetensors. Dado que la API de LLM Inference solo admite LoRA en capas de atención, especifica solo capas de atención cuando crees el LoraConfig
de la siguiente manera:
# For Gemma
from peft import LoraConfig
config = LoraConfig(
r=LORA_RANK,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)
# For Phi-2
config = LoraConfig(
r=LORA_RANK,
target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)
Para las pruebas, hay modelos LoRA ajustados de acceso público que se ajustan a la API de inferencia de LLM disponible en HuggingFace. Por ejemplo, monsterapi/gemma-2b-lora-maths-orca-200k para Gemma-2B y lole25/phi-2-sft-ultrachat-lora para Phi-2.
Después de entrenar el conjunto de datos preparado y guardar el modelo, obtienes un archivo adapter_model.safetensors
que contiene los pesos del modelo LoRA ajustados.
El archivo safetensors es el punto de control de LoRA que se usa en la conversión de modelos.
Como siguiente paso, debes convertir los pesos del modelo en un Flatbuffer de TensorFlow Lite con el paquete de Python de MediaPipe. ConversionConfig
debe especificar las opciones del modelo base, así como las opciones adicionales de LoRA. Ten en cuenta que, como la API solo admite la inferencia de LoRA con GPU, el backend debe establecerse en 'gpu'
.
import mediapipe as mp
from mediapipe.tasks.python.genai import converter
config = converter.ConversionConfig(
# Other params related to base model
...
# Must use gpu backend for LoRA conversion
backend='gpu',
# LoRA related params
lora_ckpt=LORA_CKPT,
lora_rank=LORA_RANK,
lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)
converter.convert_checkpoint(config)
El convertidor generará dos archivos FlatBuffer de TFLite, uno para el modelo base y el otro para el modelo LoRA.
Inferencia de modelos de LoRA
Se actualizaron las APIs de inferencia de LLM de la Web, Android y iOS para admitir la inferencia de modelos LoRA.
Android admite LoRA estático durante la inicialización. Para cargar un modelo de LoRA, los usuarios especifican la ruta de acceso del modelo de LoRA y el LLM base.// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
.setModelPath('<path to base model>')
.setMaxTokens(1000)
.setTopK(40)
.setTemperature(0.8)
.setRandomSeed(101)
.setLoraPath('<path to LoRA model>')
.build()
// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)
Para ejecutar la inferencia de LLM con LoRA, usa los mismos métodos generateResponse()
o generateResponseAsync()
que el modelo base.