Guia de inferência de LLM para Android

A API LLM Inference permite executar modelos de linguagem grandes (LLMs) totalmente no dispositivo para aplicativos Android, que podem ser usados para executar várias tarefas, como gerar texto, extrair informações em formato de linguagem natural e resumir documentos. A tarefa oferece suporte integrado a vários modelos de linguagem grande de texto para texto para que você possa aplicar os modelos de IA generativa mais recentes no dispositivo aos seus apps Android.

A tarefa oferece suporte à Gemma 2B, parte de uma família de modelos abertos leves e de última geração criados com a mesma pesquisa e tecnologia usadas para criar os modelos do Gemini. Ele também é compatível com os seguintes modelos externos: Phi-2, Falcon-RW-1B e StableLM-3B, além de todos os modelos exportados pelo AI Edge.

Para mais informações sobre os recursos, modelos e opções de configuração dessa tarefa, consulte a Visão geral.

Exemplo de código

Este guia se refere a um exemplo de um app básico de geração de texto para Android. Você pode usar o app como ponto de partida para seu próprio app Android ou consultá-lo ao modificar um app já existente. O código de exemplo está hospedado no GitHub (em inglês).

Fazer o download do código

As instruções a seguir mostram como criar uma cópia local do código de exemplo usando a ferramenta de linha de comando git.

Para fazer o download do código de exemplo:

  1. Clone o repositório git usando o seguinte comando:
    git clone https://github.com/google-ai-edge/mediapipe-samples
    
  2. Como opção, configure sua instância git para usar a finalização da compra esparsa para que você tenha apenas os arquivos do app de exemplo da API LLM Inference:
    cd mediapipe
    git sparse-checkout init --cone
    git sparse-checkout set examples/llm_inference/android
    

Depois de criar uma versão local do código de exemplo, você poderá importar o projeto para o Android Studio e executar o app. Para conferir instruções, consulte o Guia de configuração para Android.

Configuração

Esta seção descreve as principais etapas para configurar seu ambiente de desenvolvimento e projetos de código especificamente para usar a API LLM Inference. Para ter informações gerais sobre como configurar seu ambiente de desenvolvimento para usar tarefas do MediaPipe, incluindo requisitos de versão da plataforma, consulte o Guia de configuração para Android.

Dependências

A API LLM Inference usa a biblioteca com.google.mediapipe:tasks-genai. Adicione esta dependência ao arquivo build.gradle do app Android:

dependencies {
    implementation 'com.google.mediapipe:tasks-genai:0.10.14'
}

Modelo

A API MediaPipe LLM Inference requer um modelo de linguagem de texto para texto treinado que seja compatível com essa tarefa. Depois de fazer o download de um modelo, instale as dependências necessárias e envie o modelo ao dispositivo Android. Se você estiver usando um modelo diferente do Gemma, será preciso converter o modelo para um formato compatível com o MediaPipe.

Para mais informações sobre os modelos treinados disponíveis para a API LLM Inference, consulte a seção Modelos, uma visão geral da tarefa.

Fazer o download de um modelo

Antes de inicializar a API LLM Inference, faça o download de um dos modelos compatíveis e armazene o arquivo no diretório do projeto:

  • Gemma 2B: parte de uma família de modelos abertos leves e de última geração, criados a partir da mesma pesquisa e tecnologia usadas para criar os modelos do Gemini. Adequado para várias tarefas de geração de texto, incluindo respostas a perguntas, resumo e raciocínio.
  • Phi-2: modelo de transformador de 2, 7 bilhões de parâmetros, mais adequado para formatos de perguntas e respostas, chat e código.
  • Falcon-RW-1B: modelo somente decodificador causal de um parâmetro treinado com 350 bilhões de tokens do RefinedWeb.
  • StableLM-3B: modelo de linguagem de três bilhões de decodificadores de parâmetros pré-treinado em 1 trilhão de tokens de diversos conjuntos de dados em inglês e de código.

Como alternativa, é possível usar modelos mapeados e exportados por meio do AI Edge Troch.

Recomendamos o uso da Gemma 2B, que está disponível nos Modelos Kaggle e tem um formato compatível com a API LLM Inference. Se você usar outro LLM, será necessário converter o modelo em um formato compatível com o MediaPipe. Para saber mais, consulte o site da Gemma 2B. Para mais informações sobre os outros modelos disponíveis, consulte a seção Modelos de visão geral da tarefa.

Converter modelo para o formato MediaPipe

Conversão de modelo nativo

Se você estiver usando um LLM externo (Phi-2, Falcon ou StableLM) ou uma versão do Gemma que não seja Kaggle, utilize nossos scripts de conversão para formatar o modelo para que seja compatível com o MediaPipe.

O processo de conversão de modelo requer o pacote MediaPipe PyPI. O script de conversão está disponível em todos os pacotes do MediaPipe após 0.10.11.

Instale e importe as dependências com o seguinte código:

$ python3 -m pip install mediapipe

Use a biblioteca genai.converter para converter o modelo:

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  input_ckpt=INPUT_CKPT,
  ckpt_format=CKPT_FORMAT,
  model_type=MODEL_TYPE,
  backend=BACKEND,
  output_dir=OUTPUT_DIR,
  combine_file_only=False,
  vocab_model_file=VOCAB_MODEL_FILE,
  output_tflite_file=OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

Para converter o modelo LoRA, o ConversionConfig precisa especificar as opções do modelo base, bem como outras opções da LoRA. Como a API só oferece suporte à inferência de LoRA com GPU, o back-end precisa ser definido como 'gpu'.

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

O conversor gerará dois arquivos de flatbuffer TFLite, um para o modelo base e outro para o modelo LoRA.

Parâmetro Descrição Valores aceitos
input_ckpt O caminho para o arquivo model.safetensors ou pytorch.bin. Às vezes, o formato dos safetensor do modelo é fragmentado em vários arquivos, por exemplo, model-00001-of-00003.safetensors e model-00001-of-00003.safetensors. É possível especificar um padrão de arquivo, como model*.safetensors. PATH
ckpt_format O formato de arquivo do modelo. {"safetensors", "pytorch"}
model_type O LLM sendo convertido. {"PHI_2", "FALCON_RW_1B", "STABLELM_4E1T_3B", "GEMMA_2B"}
backend O processador (delegado) usado para executar o modelo. {"cpu", "gpu"}
output_dir O caminho para o diretório de saída que hospeda os arquivos de peso por camada. PATH
output_tflite_file O caminho para o arquivo de saída. Por exemplo, "model_cpu.bin" ou "model_gpu.bin". Esse arquivo é compatível apenas com a API LLM Inference e não pode ser usado como um arquivo "tflite" geral. PATH
vocab_model_file O caminho para o diretório que armazena os arquivos tokenizer.json e tokenizer_config.json. Para o Gemma, aponte para o único arquivo tokenizer.model. PATH
lora_ckpt O caminho para o arquivo de safetensors da LoRA que armazena o peso do adaptador da LoRA. PATH
lora_rank Um número inteiro que representa a classificação de ckpt da LoRA. Obrigatório para converter os pesos lora. Se esse valor não for informado, o conversor vai presumir que não há pesos da LoRA. Observação: apenas o back-end da GPU é compatível com LoRA. Número inteiro
lora_output_tflite_file Nome de arquivo tflite de saída para os pesos da LoRA. PATH

Conversão de modelo do AI Edge

Se você estiver usando um LLM mapeado para um modelo do TFLite pelo AI Edge, use nosso script de agrupamento para criar um pacote de tarefas. O processo de agrupamento empacota o modelo mapeado com outros metadados (por exemplo, tokenizadores) necessários para executar a inferência completa.

O processo de empacotamento de modelos exige o pacote PyPI do MediaPipe. O script de conversão está disponível em todos os pacotes do MediaPipe após 0.10.14.

Instale e importe as dependências com o seguinte código:

$ python3 -m pip install mediapipe

Use a biblioteca genai.bundler para agrupar o modelo:

import mediapipe as mp
from mediapipe.tasks.python.genai import bundler

config = bundler.BundleConfig(
    tflite_model=TFLITE_MODEL,
    tokenizer_model=TOKENIZER_MODEL,
    start_token=START_TOKEN,
    stop_tokens=STOP_TOKENS,
    output_filename=OUTPUT_FILENAME,
    enable_bytes_to_unicode_mapping=ENABLE_BYTES_TO_UNICODE_MAPPING,
)
bundler.create_bundle(config)
Parâmetro Descrição Valores aceitos
tflite_model O caminho para o modelo do TFLite exportado pelo AI Edge. PATH
tokenizer_model O caminho para o modelo do tokenizador do SentencePiece. PATH
start_token Token inicial específico do modelo. O token inicial precisa estar presente no modelo do tokenizador fornecido. STRING
stop_tokens Tokens de parada específicos do modelo. Os tokens de parada precisam estar presentes no modelo do tokenizador fornecido. LISTA[STRING]
output_filename O nome do arquivo do pacote de tarefas de saída. PATH

Enviar o modelo para o dispositivo

Envie o conteúdo da pasta output_path para o dispositivo Android.

$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.bin

Criar a tarefa

A API MediaPipe LLM Inference usa a função createFromOptions() para configurar a tarefa. A função createFromOptions() aceita valores para as opções de configuração. Para mais informações sobre as opções de configuração, consulte Opções de configuração.

O código a seguir inicializa a tarefa usando opções de configuração básicas:

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPATH('/data/local/.../')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

Opções de configuração

Use as seguintes opções para configurar um app Android:

Nome da opção Descrição Intervalo de valor Valor padrão
modelPath O caminho para onde o modelo está armazenado no diretório do projeto. PATH N/A
maxTokens O número máximo de tokens (de entrada + saída) que o modelo gerencia. Número inteiro 512
topK O número de tokens que o modelo considera em cada etapa da geração. Limita as previsões aos k principais tokens mais prováveis. Ao definir topK, você também precisa definir um valor para randomSeed. Número inteiro 40
temperature A quantidade de aleatoriedade gerada durante a geração. Uma temperatura mais alta resulta em mais criatividade no texto gerado, enquanto uma temperatura mais baixa produz uma geração mais previsível. Ao definir temperature, você também precisa definir um valor para randomSeed. Ponto flutuante 0,8
randomSeed A semente aleatória usada durante a geração de texto. Número inteiro 0
loraPath O caminho absoluto para o modelo LoRA localmente no dispositivo. Observação: isso só é compatível com modelos de GPU. PATH N/A
resultListener Define o listener de resultado para receber os resultados de forma assíncrona. Aplicável somente ao usar o método de geração assíncrona. N/A N/A
errorListener Define um listener de erro opcional. N/A N/A

preparar os dados

A API LLM Inference aceita as seguintes entradas:

  • prompt (string): uma pergunta ou comando.
val inputPrompt = "Compose an email to remind Brett of lunch plans at noon on Saturday."

Executar a tarefa

Use o método generateResponse() para gerar uma resposta de texto para o texto de entrada fornecido na seção anterior (inputPrompt). Isso produz uma única resposta gerada.

val result = llmInference.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")

Para transmitir a resposta, use o método generateResponseAsync().

val options = LlmInference.LlmInferenceOptions.builder()
  ...
  .setResultListener { partialResult, done ->
    logger.atInfo().log("partial result: $partialResult")
  }
  .build()

llmInference.generateResponseAsync(inputPrompt)

Gerenciar e mostrar resultados

A API LLM Inference retorna um LlmInferenceResult, que inclui o texto de resposta gerado.

Here's a draft you can use:

Subject: Lunch on Saturday Reminder

Hi Brett,

Just a quick reminder about our lunch plans this Saturday at noon.
Let me know if that still works for you.

Looking forward to it!

Best,
[Your Name]

Personalização do modelo LoRA

A API de inferência LLM da Mediapipe pode ser configurada para oferecer suporte à adaptação de baixa classificação (LoRA, na sigla em inglês) para modelos de linguagem grandes. Utilizando modelos LoRA ajustados, os desenvolvedores podem personalizar o comportamento dos LLMs com um processo de treinamento econômico.

O suporte da LoRA à API LLM Inference funciona com modelos Gemma-2B e Phi-2 para o back-end da GPU, com pesos de LoRA aplicáveis apenas a camadas de atenção. Essa implementação inicial serve como uma API experimental para desenvolvimentos futuros, com planos de oferecer suporte a mais modelos e vários tipos de camadas nas próximas atualizações.

Preparar modelos LoRA

Siga as instruções em HuggingFace para treinar um modelo LoRA ajustado no seu próprio conjunto de dados com tipos de modelo compatíveis, Gemma-2B ou Phi-2. Os modelos Gemma-2B e Phi-2 estão disponíveis no HuggingFace no formato safetensor. Como a API LLM Inference só oferece suporte a LoRA em camadas de atenção, especifique apenas elas ao criar a LoraConfig da seguinte maneira:

# For Gemma-2B
from peft import LoraConfig
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# For Phi-2
config = LoraConfig(
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)

Para testes, há modelos LoRA ajustados publicamente acessíveis que se encaixam na API LLM Inference disponível no HuggingFace. Por exemplo, monsterapi/gemma-2b-lora-maths-orca-200k para Gemma-2B e lole25/phi-2-sft-ultrachat-lora para Phi-2.

Depois de treinar no conjunto de dados preparado e salvar o modelo, você recebe um arquivo adapter_model.safetensors contendo os pesos do modelo LoRA ajustado. O arquivo safetensors é o checkpoint da LoRA usado na conversão do modelo.

Na próxima etapa, você vai precisar converter os pesos do modelo em um Flatbuffer do TensorFlow Lite usando o pacote MediaPipe Python. O ConversionConfig precisa especificar as opções do modelo base e outras opções da LoRA. Como a API só oferece suporte à inferência LoRA com a GPU, o back-end precisa ser definido como 'gpu'.

import mediapipe as mp
from mediapipe.tasks.python.genai import converter

config = converter.ConversionConfig(
  # Other params related to base model
  ...
  # Must use gpu backend for LoRA conversion
  backend='gpu',
  # LoRA related params
  lora_ckpt=LORA_CKPT,
  lora_rank=LORA_RANK,
  lora_output_tflite_file=LORA_OUTPUT_TFLITE_FILE,
)

converter.convert_checkpoint(config)

O conversor gerará dois arquivos de flatbuffer TFLite, um para o modelo base e outro para o modelo LoRA.

Inferência de modelo LoRA

A API LLM Inference para Web, Android e iOS foi atualizada para oferecer suporte à inferência de modelo LoRA. A Web oferece suporte à LoRA dinâmica, que pode alternar diferentes modelos LoRA durante o tempo de execução. O Android e o iOS são compatíveis com a LoRA estática, que usa os mesmos pesos da LoRA durante o ciclo de vida da tarefa.

O Android oferece suporte à LoRA estática durante a inicialização. Para carregar um modelo LoRA, os usuários especificam o caminho do modelo LoRA e o LLM base.

// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
        .setModelPath('<path to base model>')
        .setMaxTokens(1000)
        .setTopK(40)
        .setTemperature(0.8)
        .setRandomSeed(101)
        .setLoraPath('<path to LoRA model>')
        .build()

// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)

Para executar a inferência LLM com o LoRA, use os mesmos métodos generateResponse() ou generateResponseAsync() que o modelo base.