Inferência com RecurrentGemma usando JAX e Flax

Ver em ai.google.dev Executar no Google Colab Abrir na Vertex AI Veja o código-fonte no GitHub

Este tutorial demonstra como realizar a amostragem/inferência básica com o modelo de instrução 2B RecurrentGemma usando a biblioteca recurrentgemma do Google DeepMind que foi escrita com JAX (uma biblioteca de computação numérica de alto desempenho), Flax (a biblioteca de rede neural baseada em JAX), Orbax (uma biblioteca de token JAX1 (uma biblioteca baseada em token JAX1) e {1token1} de token de treinamento.SentencePiece Embora o Flax não seja usado diretamente neste notebook, ele foi usado para criar o Gemma e o RecurrentGemma (o modelo Griffin).

Esse notebook pode ser executado no Google Colab com a GPU T4 (acesse Editar > Configurações do notebook > em Acelerador de hardware, selecione GPU T4).

Configuração

As seções a seguir explicam as etapas de preparação de um notebook para usar um modelo RecurrentGemma, incluindo acesso ao modelo, obtenção de uma chave de API e configuração do ambiente de execução do notebook

Configure o acesso do Kaggle para o Gemma

Para concluir este tutorial, primeiro siga as instruções de configuração semelhantes às configuração do Gemma, com algumas exceções:

  • Acesse o RecurrentGemma (em vez do Gemma) em kaggle.com.
  • Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo RecurrentGemma.
  • Gere e configure um nome de usuário e uma chave de API do Kaggle.

Depois de concluir a configuração do RecurrentGemma, vá para a próxima seção, em que você definirá variáveis de ambiente para o ambiente do Colab.

Defina as variáveis de ambiente

Defina as variáveis de ambiente para KAGGLE_USERNAME e KAGGLE_KEY. Quando a pergunta "Conceder acesso?" for exibida concordam em fornecer acesso a secrets.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Instalar a biblioteca recurrentgemma

Este notebook se concentra no uso de uma GPU Colab sem custo financeiro. Para ativar a aceleração de hardware, clique em Editar > Configurações do notebook > Selecione GPU T4 > Clique em Salvar.

Em seguida, você precisa instalar a biblioteca recurrentgemma do Google DeepMind no github.com/google-deepmind/recurrentgemma. Se você receber um erro sobre "resolvedor de dependências do pip", geralmente ele poderá ser ignorado.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

Carregar e preparar o modelo RecurrentGemma

  1. Carregue o modelo RecurrentGemma com kagglehub.model_download, que usa três argumentos:
  • handle: o identificador do modelo do Kaggle
  • path: (string opcional) o caminho local
  • force_download (booleano opcional): força o download do modelo novamente.
.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Verifique o local dos pesos do modelo e do tokenizador e, em seguida, defina as variáveis do caminho. O diretório tokenizador estará no diretório principal em que você fez o download do modelo, enquanto os pesos do modelo estarão em um subdiretório. Exemplo:
  • O arquivo tokenizer.model estará em /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1.
  • O checkpoint do modelo estará em /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Realizar amostragem/inferência

  1. Carregue o checkpoint do modelo RecurrentGemma com o método recurrentgemma.jax.load_parameters. O argumento sharding definido como "single_device" carrega todos os parâmetros do modelo em um único dispositivo.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Carregue o tokenizador do modelo RecurrentGemma, construído usando sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Para carregar automaticamente a configuração correta do checkpoint do modelo RecurrentGemma, use recurrentgemma.GriffinConfig.from_flax_params_or_variables. Em seguida, instancie o modelo Griffin com recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Crie um sampler com recurrentgemma.jax.Sampler sobre o checkpoint/peso do modelo RecurrentGemma e o tokenizador:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Escreva um comando em prompt e faça inferências. Você pode ajustar o total_generation_steps, que é o número de etapas realizadas ao gerar uma resposta. Este exemplo usa 50 para preservar a memória do host.
.
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

Saiba mais