Guia de inferência de LLM para Android

A API LLM Inference permite executar modelos de linguagem grandes (LLMs) totalmente no dispositivo para aplicativos Android, que podem ser usados para realizar uma ampla gama de tarefas, como gerar texto, recuperar informações em linguagem natural e resumir documentos. A tarefa oferece suporte integrado a vários modelos de linguagem grandes de texto para texto. Assim, você pode aplicar os modelos de IA generativa mais recentes no dispositivo aos seus apps Android.

A tarefa é compatível com as seguintes variantes do Gemma: Gemma 2 2B, Gemma 2B e Gemma 7B. O Gemma é uma família de modelos abertos leves e de última geração criados com base na mesma pesquisa e tecnologia usadas para criar os modelos do Gemini. Ele também oferece suporte aos seguintes modelos externos: Phi-2, Falcon-RW-1B e StableLM-3B.

Além dos modelos compatíveis, os usuários podem usar o AI Edge Torch do Google para exportar modelos PyTorch para modelos LiteRT (tflite) com várias assinaturas, que são agrupados com parâmetros de tokenizer para criar pacotes de tarefas compatíveis com a API Inference LLM.

Confira essa tarefa em ação com a demonstração do MediaPipe Studio. Para mais informações sobre os recursos, modelos e opções de configuração dessa tarefa, consulte a Visão geral.

Exemplo de código

Este guia se refere a um exemplo de app básico de geração de texto para Android. Você pode usar o app como ponto de partida para seu próprio app Android ou fazer referência a ele ao modificar um app existente. O código de exemplo está hospedado no GitHub.

Fazer o download do código

As instruções a seguir mostram como criar uma cópia local do código de exemplo usando a ferramenta de linha de comando git.

Para fazer o download do código de exemplo:

  1. Clone o repositório do Git usando o seguinte comando:
    git clone https://github.com/google-ai-edge/mediapipe-samples
    
  2. Opcionalmente, configure sua instância do Git para usar o checkout esparso, para que você tenha apenas os arquivos do app de exemplo da API LLM Inference:
    cd mediapipe
    git sparse-checkout init --cone
    git sparse-checkout set examples/llm_inference/android
    

Depois de criar uma versão local do código de exemplo, você pode importar o projeto para o Android Studio e executar o app. Para ver instruções, consulte o Guia de configuração para Android.

Configuração

Esta seção descreve as principais etapas para configurar seu ambiente de desenvolvimento e projetos de código especificamente para usar a API de inferência de LLM. Para informações gerais sobre como configurar seu ambiente de desenvolvimento para usar as tarefas do MediaPipe, incluindo os requisitos da versão da plataforma, consulte o Guia de configuração para Android.

Dependências

A API LLM Inference usa a biblioteca com.google.mediapipe:tasks-genai. Adicione essa dependência ao arquivo build.gradle do app Android:

dependencies {
    implementation 'com.google.mediapipe:tasks-genai:0.10.14'
}

Para dispositivos com o Android 12 (API 31) ou mais recente, adicione a dependência da biblioteca OpenCL nativa. Para mais informações, consulte a documentação da tag uses-native-library.

Adicione as seguintes tags uses-native-library ao arquivo AndroidManifest.xml:

<uses-native-library android:name="libOpenCL.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-car.so" android:required="false"/>
<uses-native-library android:name="libOpenCL-pixel.so" android:required="false"/>

Modelo

A API MediaPipe LLM Inference requer um modelo de linguagem de conversão de texto em texto treinado que seja compatível com essa tarefa. Depois de fazer o download de um modelo, instale as dependências necessárias e envie o modelo para o dispositivo Android. Se você estiver usando um modelo diferente do Gemma, será necessário convertê-lo para um formato compatível com o MediaPipe.

Para mais informações sobre os modelos treinados disponíveis para a API de inferência LLM, consulte a seção "Modelos" da visão geral da tarefa.

Fazer o download de um modelo

Antes de inicializar a API de inferência de LLM, faça o download de um dos modelos compatíveis e armazene o arquivo no diretório do projeto:

  • Gemma-2 2B: a versão mais recente da família de modelos Gemma. Faz parte de uma família de modelos abertos leves e de última geração criados com a mesma pesquisa e tecnologia usadas para criar os modelos do Gemini.
  • Gemma 2B: faz parte de uma família de modelos abertos leves e de última geração criados com base na mesma pesquisa e tecnologia usadas para criar os modelos do Gemini. É adequado para várias tarefas de geração de texto, incluindo respostas a perguntas, resumo e raciocínio.
  • Phi-2: modelo Transformer de 2, 7 bilhões de parâmetros, mais adequado para o formato de perguntas e respostas, chat e código.
  • Falcon-RW-1B: modelo de 1 bilhão de parâmetros somente para decodificador treinado com 350 bilhões de tokens do RefinedWeb.
  • StableLM-3B: modelo de linguagem de apenas decodificador de parâmetros de 3 bilhões pré-treinado em 1 trilhão de tokens de diversos conjuntos de dados de inglês e código.

Além dos modelos compatíveis, você pode usar o AI Edge Torch do Google para exportar modelos do PyTorch para modelos LiteRT (tflite) com várias assinaturas. Para mais informações, consulte Conversor generativo do Torch para modelos PyTorch.

Recomendamos usar o Gemma-2 2B, que está disponível nos modelos do Kaggle. Para mais informações sobre os outros modelos disponíveis, consulte a seção "Modelos" da visão geral da tarefa.

Converter o modelo para o formato MediaPipe

A API de inferência de LLM é compatível com dois tipos de modelos de categorias, alguns dos quais exigem conversão de modelos. Use a tabela para identificar o método de etapas necessário para seu modelo.

Modelos Método de conversão Plataformas compatíveis Tipo de arquivo
Modelos compatíveis Gemma 2B, Gemma 7B, Gemma-2 2B, Phi-2, StableLM, Falcon MediaPipe Android, iOS e Web .bin
Outros modelos do PyTorch Todos os modelos LLM do PyTorch Biblioteca generativa AI Edge Torch Android, iOS .task

Os arquivos .bin convertidos para Gemma 2B, Gemma 7B e Gemma-2 2B estão hospedados no Kaggle. Esses modelos podem ser implantados diretamente usando nossa API de inferência de LLM. Para aprender a converter outros modelos, consulte a seção Conversão de modelos.

Enviar o modelo para o dispositivo

Envie o conteúdo da pasta output_path para o dispositivo Android.

$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.bin

Criar a tarefa

A API de inferência LLM do MediaPipe usa a função createFromOptions() para configurar a tarefa. A função createFromOptions() aceita valores para as opções de configuração. Para mais informações sobre as opções de configuração, consulte Opções de configuração.

O código a seguir inicializa a tarefa usando opções de configuração básicas:

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPATH('/data/local/.../')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

Opções de configuração

Use as opções de configuração a seguir para configurar um app Android:

Nome da opção Descrição Intervalo de valor Valor padrão
modelPath O caminho para onde o modelo é armazenado no diretório do projeto. CAMINHO N/A
maxTokens O número máximo de tokens (tokens de entrada + tokens de saída) que o modelo processa. Número inteiro 512
topK O número de tokens que o modelo considera em cada etapa de geração. Limita as previsões aos k tokens mais prováveis. Número inteiro 40
temperature A quantidade de aleatoriedade introduzida durante a geração. Uma temperatura mais alta resulta em mais criatividade no texto gerado, enquanto uma temperatura mais baixa produz uma geração mais previsível. Ponto flutuante 0,8
randomSeed A semente aleatória usada durante a geração de texto. Número inteiro 0
loraPath O caminho absoluto para o modelo LoRA localmente no dispositivo. Observação: isso só é compatível com modelos de GPU. CAMINHO N/A
resultListener Define o listener de resultado para receber os resultados de forma assíncrona. Aplicável apenas ao usar o método de geração assíncrona. N/A N/A
errorListener Define um listener de erro opcional. N/A N/A

Preparar dados

A API LLM Inference aceita as seguintes entradas:

  • prompt (string): uma pergunta ou comando.
val inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday."

Executar a tarefa

Use o método generateResponse() para gerar uma resposta de texto ao texto de entrada fornecido na seção anterior (inputPrompt). Isso produz uma única resposta gerada.

val result = llmInference.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")

Para fazer streaming da resposta, use o método generateResponseAsync().

val options = LlmInference.LlmInferenceOptions.builder()
  ...
  .setResultListener { partialResult, done ->
    logger.atInfo().log("partial result: $partialResult")
  }
  .build()

llmInference.generateResponseAsync(inputPrompt)

Processar e mostrar resultados

A API de inferência de LLM retorna um LlmInferenceResult, que inclui o texto de resposta gerado.

Here's a draft you can use:

Subject: Lunch on Saturday Reminder

Hi Brett,

Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.

Looking forward to it!

Best,
[Your Name]

Personalização de modelo LoRA

A API de inferência de LLM do Mediapipe pode ser configurada para oferecer suporte à adaptação de baixa classificação (LoRA) para modelos de linguagem grandes. Usando modelos LoRA ajustados, os desenvolvedores podem personalizar o comportamento de LLMs com um processo de treinamento econômico.

O suporte da API de inferência de LLM à LoRA funciona para todas as variantes do Gemma e modelos Phi-2 para o back-end da GPU, com pesos da LoRA aplicáveis apenas a camadas de atenção. Essa implementação inicial serve como uma API experimental para futuros desenvolvimentos com planos de oferecer suporte a mais modelos e vários tipos de camadas nas próximas atualizações.

Preparar modelos LoRA

Siga as instruções no HuggingFace para treinar um modelo LoRA ajustado no seu próprio conjunto de dados com os tipos de modelo compatíveis, Gemma ou Phi-2. Os modelos Gemma-2 2B, Gemma 2B e Phi-2 estão disponíveis no HuggingFace no formato safetensors. Como a API de inferência LLM oferece suporte apenas ao LoRA em camadas de atenção, especifique apenas camadas de atenção ao criar o LoraConfig da seguinte maneira:

# For Gemma
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

Para testes, há modelos LoRA ajustados e acessíveis publicamente que se encaixam na API LLM Inference disponível no HuggingFace. Por exemplo, monsterapi/gemma-2b-lora-maths-orca-200k para Gemma-2B e lole25/phi-2-sft-ultrachat-lora para Phi-2.

Depois de treinar no conjunto de dados preparado e salvar o modelo, você vai receber um arquivo adapter_model.safetensors contendo os pesos do modelo LoRA ajustados. O arquivo safetensors é o ponto de verificação da LoRA usado na conversão do modelo.

Na próxima etapa, você precisa converter os pesos do modelo em um Flatbuffer do TensorFlow Lite usando o pacote MediaPipe Python. O ConversionConfig precisa especificar as opções de modelo básico e outras opções de LoRa. Como a API só oferece suporte à inferência LoRA com GPU, o back-end precisa ser definido como 'gpu'.

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

O conversor vai gerar dois arquivos flatbuffer do TFLite, um para o modelo base e outro para o modelo LoRA.

Inferência de modelo LoRA

A API de inferência de LLM da Web, do Android e do iOS foi atualizada para oferecer suporte à inferência de modelos da LoRA.

O Android oferece suporte a LoRA estático durante a inicialização. Para carregar um modelo LoRA, os usuários especificam o caminho do modelo LoRA e o LLM de base.

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPath('<path to base model>')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .setLoraPath('<path to LoRA model>')
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

Para executar a inferência do LLM com o LoRA, use os mesmos métodos generateResponse() ou generateResponseAsync() do modelo base.