|
|
Executar no Google Colab
|
|
Veja o código-fonte no GitHub
|
Este tutorial demonstra como realizar a amostragem/inferência básica com o modelo de instrução 2B RecurrentGemma usando a biblioteca recurrentgemma do Google DeepMind que foi escrita com JAX (uma biblioteca de computação numérica de alto desempenho), Flax (a biblioteca de rede neural baseada em JAX), Orbax (uma biblioteca de token JAX1 (uma biblioteca baseada em token JAX1) e {1token1} de token de treinamento.SentencePiece Embora o Flax não seja usado diretamente neste notebook, ele foi usado para criar o Gemma e o RecurrentGemma (o modelo Griffin).
Esse notebook pode ser executado no Google Colab com a GPU T4 (acesse Editar > Configurações do notebook > em Acelerador de hardware, selecione GPU T4).
Configuração
As seções a seguir explicam as etapas de preparação de um notebook para usar um modelo RecurrentGemma, incluindo acesso ao modelo, obtenção de uma chave de API e configuração do ambiente de execução do notebook
Configure o acesso do Kaggle para o Gemma
Para concluir este tutorial, primeiro siga as instruções de configuração semelhantes às configuração do Gemma, com algumas exceções:
- Acesse o RecurrentGemma (em vez do Gemma) em kaggle.com.
- Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo RecurrentGemma.
- Gere e configure um nome de usuário e uma chave de API do Kaggle.
Depois de concluir a configuração do RecurrentGemma, vá para a próxima seção, em que você definirá variáveis de ambiente para o ambiente do Colab.
Defina as variáveis de ambiente
Defina as variáveis de ambiente para KAGGLE_USERNAME e KAGGLE_KEY. Quando a pergunta "Conceder acesso?" for exibida concordam em fornecer acesso a secrets.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Instalar a biblioteca recurrentgemma
Este notebook se concentra no uso de uma GPU Colab sem custo financeiro. Para ativar a aceleração de hardware, clique em Editar > Configurações do notebook > Selecione GPU T4 > Clique em Salvar.
Em seguida, você precisa instalar a biblioteca recurrentgemma do Google DeepMind no github.com/google-deepmind/recurrentgemma. Se você receber um erro sobre "resolvedor de dependências do pip", geralmente ele poderá ser ignorado.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
Carregar e preparar o modelo RecurrentGemma
- Carregue o modelo RecurrentGemma com
kagglehub.model_download, que usa três argumentos:
handle: o identificador do modelo do Kagglepath: (string opcional) o caminho localforce_download(booleano opcional): força o download do modelo novamente.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Verifique o local dos pesos do modelo e do tokenizador e, em seguida, defina as variáveis do caminho. O diretório tokenizador estará no diretório principal em que você fez o download do modelo, enquanto os pesos do modelo estarão em um subdiretório. Exemplo:
- O arquivo
tokenizer.modelestará em/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1. - O checkpoint do modelo estará em
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Realizar amostragem/inferência
- Carregue o checkpoint do modelo RecurrentGemma com o método
recurrentgemma.jax.load_parameters. O argumentoshardingdefinido como"single_device"carrega todos os parâmetros do modelo em um único dispositivo.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Carregue o tokenizador do modelo RecurrentGemma, construído usando
sentencepiece.SentencePieceProcessor:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Para carregar automaticamente a configuração correta do checkpoint do modelo RecurrentGemma, use
recurrentgemma.GriffinConfig.from_flax_params_or_variables. Em seguida, instancie o modelo Griffin comrecurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Crie um
samplercomrecurrentgemma.jax.Samplersobre o checkpoint/peso do modelo RecurrentGemma e o tokenizador:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Escreva um comando em
prompte faça inferências. Você pode ajustar ototal_generation_steps, que é o número de etapas realizadas ao gerar uma resposta. Este exemplo usa50para preservar a memória do host.
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
warnings.warn("Some donated buffers were not usable:"
Prompt:
# 5+9=?
Output:
# Answer: 14
# Explanation: 5 + 9 = 14.
Saiba mais
- Saiba mais sobre a biblioteca
recurrentgemmado Google DeepMind no GitHub (link em inglês), que contém docstrings de métodos e módulos usados neste tutorial, comorecurrentgemma.jax.load_parameters,recurrentgemma.jax.Griffinerecurrentgemma.jax.Sampler. - As bibliotecas a seguir têm seus próprios sites de documentação: core JAX, Flax e Orbax.
- Para ver a documentação do tokenizador/destokenizador do
sentencepiece, consulte o repositório do GitHub dosentencepiecedo Google (em inglês). - Para conferir a documentação do
kagglehub, confiraREADME.mdno repositório do GitHub dokagglehubdo Kaggle (links em inglês). - Saiba como usar modelos do Gemma com a Vertex AI do Google Cloud.
- Confira o filme RecurrentGemma: Transição de Transformadores para modelos de linguagem abertos eficientes do Google DeepMind.
- Leia Griffin: como misturar recorrências lineares fechadas com Documento sobre a atenção local para modelos de linguagem eficientes do GoogleDeepMind para saber mais sobre a arquitetura de modelo usada pela RecurrentGemma.
Executar no Google Colab
Veja o código-fonte no GitHub