Executar o Gemma usando o PyTorch

Ver em ai.google.dev Executar no Google Colab Executar no Kaggle Abrir na Vertex AI Ver código-fonte no GitHub

Este guia mostra como executar o Gemma usando o framework PyTorch, incluindo como usar dados de imagem para solicitar modelos da versão 3 e mais recentes do Gemma. Para mais detalhes sobre a implementação do Gemma PyTorch, consulte o README (em inglês) do repositório do projeto.

Configuração

As seções a seguir explicam como configurar seu ambiente de desenvolvimento, incluindo como acessar os modelos da Gemma para fazer o download do Kaggle, definir variáveis de autenticação, instalar dependências e importar pacotes.

Requisitos do sistema

Essa biblioteca do Gemma Pytorch exige processadores de GPU ou TPU para executar o modelo do Gemma. O ambiente de execução padrão do Python na CPU do Colab e o ambiente de execução do Python na GPU T4 são suficientes para executar modelos de tamanho Gemma 1B, 2B e 4B. Para casos de uso avançados de outras GPUs ou TPUs, consulte o README no repositório do Gemma PyTorch.

Acessar a Gemma no Kaggle

Para concluir este tutorial, primeiro siga as instruções de configuração em Configuração da Gemma, que mostram como fazer o seguinte:

  • Acesse a Gemma no Kaggle (link em inglês).
  • Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo do Gemma.
  • Gere e configure um nome de usuário e uma chave de API do Kaggle.

Depois de concluir a configuração do Gemma, passe para a próxima seção, em que você vai definir variáveis de ambiente para seu ambiente do Colab.

Defina as variáveis de ambiente

Defina as variáveis de ambiente para KAGGLE_USERNAME e KAGGLE_KEY. Quando as mensagens "Conceder acesso?" aparecerem, concorde em fornecer acesso ao secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Instalar dependências

pip install -q -U torch immutabledict sentencepiece

Baixar pesos do modelo

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

Defina os caminhos do tokenizador e do checkpoint para o modelo.

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Configurar o ambiente de execução

As seções a seguir explicam como preparar um ambiente do PyTorch para executar a Gemma.

Preparar o ambiente de execução do PyTorch

Prepare o ambiente de execução do modelo PyTorch clonando o repositório Gemma Pytorch.

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

Definir a configuração do modelo

Antes de executar o modelo, é preciso definir alguns parâmetros de configuração, incluindo a variante do Gemma, o tokenizador e o nível de quantização.

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

Configurar o contexto do dispositivo

O código a seguir configura o contexto do dispositivo para executar o modelo:

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

Instanciar e carregar o modelo

Carregue o modelo com os pesos dele para se preparar para executar solicitações.

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

Executar inferência

Confira abaixo exemplos de geração no modo de chat e com várias solicitações.

Os modelos Gemma ajustados com instruções foram treinados com um formatador específico que anota exemplos de ajuste de instruções com informações extras durante o treinamento e a inferência. As anotações (1) indicam funções em uma conversa e (2) delineiam turnos em uma conversa.

Os tokens de anotação relevantes são:

  • user: vez do usuário
  • model: turno do modelo
  • <start_of_turn>: início da vez de falar
  • <start_of_image>: tag para entrada de dados de imagem
  • <end_of_turn><eos>: fim da vez da conversa

Para mais informações, leia sobre a formatação de comandos para modelos Gemma ajustados por instrução aqui.

Gerar texto com texto

Confira abaixo um exemplo de código que demonstra como formatar um comando para um modelo da Gemma ajustado por instruções usando modelos de chat do usuário e do modelo em uma conversa de várias rodadas.

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

Gerar texto com imagens

Com a versão 3 e mais recentes do Gemma, você pode usar imagens com seu comando. O exemplo a seguir mostra como incluir dados visuais no comando.

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image = read_image(
    'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)

print(model.generate(
    [
        [
            '<start_of_turn>user\n',
            image,
            'What animal is in this image?<end_of_turn>\n',
            '<start_of_turn>model\n'
        ]
    ],
    device=device,
    output_len=256,
))

Saiba mais

Agora que você aprendeu a usar o Gemma no Pytorch, confira as muitas outras coisas que ele pode fazer em ai.google.dev/gemma.

Confira também estes outros recursos relacionados: