Inferência com o Gemma usando JAX e Flax

Ver em ai.google.dev Executar no Google Colab Abrir na Vertex AI Consulte o código-fonte no GitHub

Visão geral

A Gemma é uma família de modelos de linguagem grandes, leves e modernos, com base na pesquisa e tecnologia do Google DeepMind Gemini. Este tutorial demonstra como realizar amostragem/inferência básica com o modelo Gemma 2B Instruct usando a biblioteca gemma do Google DeepMind que foi escrita com JAX (uma biblioteca de computação numérica de alto desempenho), Flax (a biblioteca de rede neural baseada em JAX), Orbax (uma biblioteca baseada em JAX para utilitários de treinamento como token de checkpoint/tokena{/1deizer{/1deizer}).SentencePiece Embora o Flax não seja usado diretamente neste notebook, ele foi usado para criar o Gemma.

Este notebook pode ser executado no Google Colab com GPU T4 sem custo financeiro (acesse Editar > Configurações do notebook > em Acelerador de hardware selecione GPU T4).

Configuração

1. Configurar o acesso do Kaggle para o Gemma

Para concluir este tutorial, primeiro você precisa seguir as instruções de configuração em Gemma setup, que mostram como fazer o seguinte:

  • Acesse o Gemma em kaggle.com.
  • Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo Gemma.
  • Gere e configure um nome de usuário e uma chave de API do Kaggle.

Depois de concluir a configuração do Gemma, passe para a próxima seção, em que você definirá variáveis para seu ambiente do Colab.

2. Defina as variáveis de ambiente

Defina as variáveis de ambiente para KAGGLE_USERNAME e KAGGLE_KEY. Quando a mensagem "Permitir acesso?" for exibida, concorde em fornecer acesso secreto.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Instalar a biblioteca gemma

Este notebook se concentra no uso de uma GPU Colab sem custo financeiro. Para ativar a aceleração de hardware, clique em Editar > Configurações de notebook > Selecione GPU T4 > Salvar.

Em seguida, você precisa instalar a biblioteca gemma do Google DeepMind de github.com/google-deepmind/gemma. Geralmente, se você receber um erro sobre o "resolvedor de dependências do pip", ele poderá ser ignorado.

pip install -q git+https://github.com/google-deepmind/gemma.git

Carregar e preparar o modelo do Gemma

  1. Carregue o modelo Gemma com kagglehub.model_download, que usa três argumentos:
  • handle: o identificador de modelo do Kaggle
  • path: (string opcional) o caminho local
  • force_download: (booleano opcional) força o novo download do modelo
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. Verifique o local dos pesos do modelo e do tokenizador e, em seguida, defina as variáveis de caminho. O diretório do tokenizador estará no diretório principal onde você fez o download do modelo, enquanto os pesos do modelo estarão em um subdiretório. Exemplo:
  • O arquivo tokenizer.model estará em /LOCAL/PATH/TO/gemma/flax/2b-it/2).
  • O checkpoint do modelo estará em /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

Realizar amostragem/inferência

  1. Carregue e formate o checkpoint do modelo Gemma com o método gemma.params.load_and_format_params:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. Carregue o tokenizador Gemma, construído usando sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Para carregar automaticamente a configuração correta do checkpoint do modelo Gemma, use gemma.transformer.TransformerConfig. O argumento cache_size é o número de etapas de tempo no cache Transformer do Gemma. Em seguida, instancie o modelo Gemma como transformer com gemma.transformer.Transformer (herdado de flax.linen.Module).
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. Crie um sampler com gemma.sampler.Sampler sobre o checkpoint/pesos do modelo Gemma e o tokenizador:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. Escreva um comando em input_batch e faça inferências. É possível ajustar total_generation_steps, o número de etapas realizadas ao gerar uma resposta. Este exemplo usa 100 para preservar a memória do host.
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (Opcional) Execute esta célula para liberar memória se você tiver concluído o notebook e quiser testar outro comando. Em seguida, você pode instanciar o sampler novamente na etapa 3 e personalizar e executar o comando na etapa 4.
del sampler

Saiba mais