Gemma in PyTorch


This is a quick demo of running Gemma inference in PyTorch. For more details, please check out the GitHub repo of the official PyTorch implementation.

Note that:

  • The free Colab CPU Python runtime and T4 GPU Python runtime are sufficient for running the Gemma 2B models and the int8-quantized 7B models.
  • For advanced usage with other GPUs or TPUs, please refer to the README.md in the official repo.

Kaggle access

To log in to Kaggle, you can either store your kaggle.json credentials file at ~/.kaggle/kaggle.json or run the following in a Colab environment. See the kagglehub package documentation for more details.

import kagglehub

kagglehub.login()
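
If you prefer a non-interactive setup, you can instead supply your credentials through environment variables before calling any kagglehub function. This is a minimal sketch, assuming you replace the placeholder values with your own Kaggle username and API key (these are the same KAGGLE_USERNAME and KAGGLE_KEY variables used by the official Kaggle API):

import os

# Assumption: kagglehub picks up the standard Kaggle API environment
# variables. Replace the placeholders with your own credentials.
os.environ['KAGGLE_USERNAME'] = 'your_username'  # placeholder
os.environ['KAGGLE_KEY'] = 'your_api_key'        # placeholder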

Install dependencies

pip install -q -U torch immutabledict sentencepiece
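
Optionally, you can verify the installation and check whether a GPU is visible before continuing. This is plain PyTorch, not specific to Gemma:

import torch

# Sanity check: report the installed version and GPU availability.
print('torch version:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())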

Download model weights

import os

# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
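
If either assertion fails, listing the directory contents can help you see what was actually downloaded. A one-line check, reusing weights_dir from above:

# Inspect the downloaded files (useful when an assertion above fails).
print(os.listdir(weights_dir))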

Download the model implementation

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
import sys

sys.path.append('gemma_pytorch')
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

Setup the model

import torch

# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()
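
As a quick sanity check after loading, you can confirm the parameter count, device, and dtype of the loaded weights. This is standard PyTorch and works for any nn.Module:

# Sanity check: parameter count and where/how the weights were loaded.
num_params = sum(p.numel() for p in model.parameters())
print(f'Parameters: {num_params / 1e9:.2f}B')
print(f'Device: {next(model.parameters()).device}')
print(f'Dtype: {next(model.parameters()).dtype}')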

Run inference

Below are examples of generating in chat mode and generating from a plain, single-turn prompt.

The instruction-tuned Gemma models were trained with a specific formatter that annotates instruction tuning examples with extra information, both during training and inference. The annotations (1) indicate roles in a conversation, and (2) delineate turns in a conversation. Below we show a sample code snippet for formatting the model prompt using the user and model chat templates in a multi-turn conversation. The relevant tokens are:

  • user: user turn
  • model: model turn
  • <start_of_turn>: beginning of dialogue turn
  • <end_of_turn>: end of dialogue turn

Read more about Gemma formatting for instruction tuning and system instructions in the Gemma documentation.

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    prompt,  # `prompt` is already fully formatted with the chat template
    device=device,
    output_len=100,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn>
<start_of_turn>model
California.<end_of_turn>
<start_of_turn>user
What can I do in California?<end_of_turn>
<start_of_turn>model
"* **Visit the Golden Gate Bridge and Alcatraz Island in San Francisco.**\n* **Head to Yosemite National Park and marvel at nature's beauty.**\n* **Explore the bustling metropolis of Los Angeles.**\n* **Relax on the pristine beaches of Santa Monica or Malibu.**\n* **Go whale watching in Monterey Bay.**\n* **Discover the charming coastal towns of Monterey Bay and Carmel-by-the-Sea.**\n* **Visit Disneyland and Disney California Adventure in Anaheim.**\n*"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=60,
)
['\n\nThe fingers dance on the keys,\nA symphony of thoughts and dreams.\nThe mind, a canvas yet uncouth,\nScribbling its secrets in the night.\n\nThe ink, a whispered voice from deep,\nA language ancient, never to sleep.\nEach stroke an echo of']
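
For longer conversations, it can be convenient to wrap the template bookkeeping in a small helper that accumulates turns. The sketch below is an illustration, not part of the Gemma API; it reuses USER_CHAT_TEMPLATE, MODEL_CHAT_TEMPLATE, and the model.generate call shown above, and assumes generate returns the completion as a string when given a single string prompt:

class ChatSession:
    """Minimal multi-turn chat helper built on the templates above."""

    def __init__(self, model, device, output_len=100):
        self.model = model
        self.device = device
        self.output_len = output_len
        self.history = ''

    def ask(self, user_message):
        # Append the user turn, then prompt the model for its turn.
        self.history += USER_CHAT_TEMPLATE.format(prompt=user_message)
        reply = self.model.generate(
            self.history + '<start_of_turn>model\n',
            device=self.device,
            output_len=self.output_len,
        )
        # Record the model turn so the next question keeps its context.
        self.history += MODEL_CHAT_TEMPLATE.format(prompt=reply)
        return reply

# Hypothetical usage:
# chat = ChatSession(model, device)
# print(chat.ask('What is a good place for travel in the US?'))
# print(chat.ask('What can I do there?'))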

Learn more

Now that you have learned how to use Gemma in PyTorch, you can explore the many other things that Gemma can do at ai.google.dev/gemma, along with other related resources.