Ver em ai.google.dev | Executar no Google Colab | Abrir na Vertex AI | Veja o código-fonte no GitHub |
Este tutorial demonstra como realizar a amostragem/inferência básica com o modelo de instrução 2B RecurrentGemma usando a biblioteca recurrentgemma
do Google DeepMind que foi escrita com JAX (uma biblioteca de computação numérica de alto desempenho), Flax (a biblioteca de rede neural baseada em JAX), Orbax (uma biblioteca de token JAX1 (uma biblioteca baseada em token JAX1) e {1token1} de token de treinamento.SentencePiece Embora o Flax não seja usado diretamente neste notebook, ele foi usado para criar o Gemma e o RecurrentGemma (o modelo Griffin).
Esse notebook pode ser executado no Google Colab com a GPU T4 (acesse Editar > Configurações do notebook > em Acelerador de hardware, selecione GPU T4).
Configuração
As seções a seguir explicam as etapas de preparação de um notebook para usar um modelo RecurrentGemma, incluindo acesso ao modelo, obtenção de uma chave de API e configuração do ambiente de execução do notebook
Configure o acesso do Kaggle para o Gemma
Para concluir este tutorial, primeiro siga as instruções de configuração semelhantes às configuração do Gemma, com algumas exceções:
- Acesse o RecurrentGemma (em vez do Gemma) em kaggle.com.
- Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo RecurrentGemma.
- Gere e configure um nome de usuário e uma chave de API do Kaggle.
Depois de concluir a configuração do RecurrentGemma, vá para a próxima seção, em que você definirá variáveis de ambiente para o ambiente do Colab.
Defina as variáveis de ambiente
Defina as variáveis de ambiente para KAGGLE_USERNAME
e KAGGLE_KEY
. Quando a pergunta "Conceder acesso?" for exibida concordam em fornecer acesso a secrets.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Instalar a biblioteca recurrentgemma
Este notebook se concentra no uso de uma GPU Colab sem custo financeiro. Para ativar a aceleração de hardware, clique em Editar > Configurações do notebook > Selecione GPU T4 > Clique em Salvar.
Em seguida, você precisa instalar a biblioteca recurrentgemma
do Google DeepMind no github.com/google-deepmind/recurrentgemma
. Se você receber um erro sobre "resolvedor de dependências do pip", geralmente ele poderá ser ignorado.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
Carregar e preparar o modelo RecurrentGemma
- Carregue o modelo RecurrentGemma com
kagglehub.model_download
, que usa três argumentos:
handle
: o identificador do modelo do Kagglepath
: (string opcional) o caminho localforce_download
(booleano opcional): força o download do modelo novamente.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Verifique o local dos pesos do modelo e do tokenizador e, em seguida, defina as variáveis do caminho. O diretório tokenizador estará no diretório principal em que você fez o download do modelo, enquanto os pesos do modelo estarão em um subdiretório. Exemplo:
- O arquivo
tokenizer.model
estará em/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
. - O checkpoint do modelo estará em
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Realizar amostragem/inferência
- Carregue o checkpoint do modelo RecurrentGemma com o método
recurrentgemma.jax.load_parameters
. O argumentosharding
definido como"single_device"
carrega todos os parâmetros do modelo em um único dispositivo.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Carregue o tokenizador do modelo RecurrentGemma, construído usando
sentencepiece.SentencePieceProcessor
:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Para carregar automaticamente a configuração correta do checkpoint do modelo RecurrentGemma, use
recurrentgemma.GriffinConfig.from_flax_params_or_variables
. Em seguida, instancie o modelo Griffin comrecurrentgemma.jax.Griffin
.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Crie um
sampler
comrecurrentgemma.jax.Sampler
sobre o checkpoint/peso do modelo RecurrentGemma e o tokenizador:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Escreva um comando em
prompt
e faça inferências. Você pode ajustar ototal_generation_steps
, que é o número de etapas realizadas ao gerar uma resposta. Este exemplo usa50
para preservar a memória do host.
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
Saiba mais
- Saiba mais sobre a biblioteca
recurrentgemma
do Google DeepMind no GitHub (link em inglês), que contém docstrings de métodos e módulos usados neste tutorial, comorecurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
erecurrentgemma.jax.Sampler
. - As bibliotecas a seguir têm seus próprios sites de documentação: core JAX, Flax e Orbax.
- Para ver a documentação do tokenizador/destokenizador do
sentencepiece
, consulte o repositório do GitHub dosentencepiece
do Google (em inglês). - Para conferir a documentação do
kagglehub
, confiraREADME.md
no repositório do GitHub dokagglehub
do Kaggle (links em inglês). - Saiba como usar modelos do Gemma com a Vertex AI do Google Cloud.
- Confira o filme RecurrentGemma: Transição de Transformadores para modelos de linguagem abertos eficientes do Google DeepMind.
- Leia Griffin: como misturar recorrências lineares fechadas com Documento sobre a atenção local para modelos de linguagem eficientes do GoogleDeepMind para saber mais sobre a arquitetura de modelo usada pela RecurrentGemma.