Inference with Gemma using JAX and Flax



Gemma is a family of lightweight, state-of-the-art open large language models, based on the Google DeepMind Gemini research and technology. This tutorial demonstrates how to perform basic sampling/inference with the Gemma 2B Instruct model using Google DeepMind's gemma library that was written with JAX (a high-performance numerical computing library), Flax (the JAX-based neural network library), Orbax (a JAX-based library for training utilities like checkpointing), and SentencePiece (a tokenizer/detokenizer library). Although Flax is not used directly in this notebook, Flax was used to create Gemma.

This notebook can run on Google Colab with a free T4 GPU (go to Edit > Notebook settings, then under Hardware accelerator select T4 GPU).


1. Set up Kaggle access for Gemma

To complete this tutorial, you first need to follow the setup instructions at Gemma setup, which show you how to do the following:

  • Get access to Gemma on Kaggle.
  • Select a Colab runtime with sufficient resources to run the Gemma model.
  • Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

2. Set environment variables

Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY. When prompted with the "Grant access?" messages, agree to provide secret access.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
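
Outside Colab, the `userdata` API is not available, so the two variables have to be set some other way (for example, exported in the shell). The following is a minimal sketch, not part of the original notebook, of a check that reports which credentials are still missing before any download is attempted. The helper name `missing_env_vars` is hypothetical.

```python
import os

def missing_env_vars(names, environ=None):
    """Return the names that are unset or empty in the given environment."""
    env = os.environ if environ is None else environ
    return [name for name in names if not env.get(name)]

# The two variables this tutorial relies on.
required = ["KAGGLE_USERNAME", "KAGGLE_KEY"]
missing = missing_env_vars(required)
if missing:
    print("Missing Kaggle credentials:", ", ".join(missing))
else:
    print("Kaggle credentials found.")
```

Passing a plain dictionary as `environ` makes the helper easy to try without touching the real environment.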

3. Install the gemma library

This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click on Edit > Notebook settings > Select T4 GPU > Save.

Next, you need to install the Google DeepMind gemma library from its GitHub repository. If you get an error about "pip's dependency resolver", you can usually ignore it.

pip install -q git+

Load and prepare the Gemma model

  1. Load the Gemma model with kagglehub.model_download, which takes three arguments:
  • handle: The model handle from Kaggle
  • path: (Optional string) The local path
  • force_download: (Optional boolean) Forces a re-download of the model
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
print('GEMMA_PATH:', GEMMA_PATH)
Downloading from
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  2. Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer file will be in the main directory where you downloaded the model, while the model weights will be in a sub-directory. For example:
  • The tokenizer.model file will be in /LOCAL/PATH/TO/gemma/flax/2b-it/2.
  • The model checkpoint will be in /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
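# The checkpoint lives in a sub-directory named after the model variant.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)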
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model
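
A wrong path tends to surface later as a less obvious loading error, so it can help to verify the layout described above before going further. The following is a minimal sketch, assuming the directory structure shown in this section; the helper name `resolve_gemma_paths` is hypothetical, and the demonstration runs against a throwaway directory rather than a real download.

```python
import os
import tempfile

def resolve_gemma_paths(gemma_path, variant):
    """Build the checkpoint and tokenizer paths and verify that both exist."""
    ckpt_path = os.path.join(gemma_path, variant)
    tokenizer_path = os.path.join(gemma_path, "tokenizer.model")
    if not os.path.isdir(ckpt_path):
        raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_path}")
    if not os.path.isfile(tokenizer_path):
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
    return ckpt_path, tokenizer_path

# Demonstration on a temporary directory that mimics the layout above.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "2b-it"))
    open(os.path.join(root, "tokenizer.model"), "w").close()
    ckpt, tok = resolve_gemma_paths(root, "2b-it")
    print(os.path.basename(ckpt), os.path.basename(tok))
```

In the notebook itself, the same helper could be called as `resolve_gemma_paths(GEMMA_PATH, GEMMA_VARIANT)`.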

Perform sampling/inference

  1. Load and format the Gemma model checkpoint with the gemma.params.load_and_format_params method:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  2. Load the Gemma tokenizer, constructed using sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
# Load the tokenizer model file downloaded with the checkpoint.
vocab.Load(TOKENIZER_PATH)
  3. To automatically load the correct configuration from the Gemma model checkpoint, use gemma.transformer.TransformerConfig. The cache_size argument is the number of time steps in the Gemma Transformer cache. Afterwards, instantiate the Gemma model as transformer with gemma.transformer.Transformer (which inherits from flax.linen.Module).
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024,  # Number of time steps in the cache; an example value, adjust as needed.
)

transformer = transformer_lib.Transformer(transformer_config)
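
Because the cache holds one entry per time step, a prompt and its generated continuation presumably have to fit within cache_size together. The following is a minimal sketch of that budget check, assuming the constraint is "prompt tokens plus generation steps must not exceed the cache size"; the exact rule enforced by the library may differ. `fits_in_cache` is a hypothetical helper, and the numbers are illustrative only.

```python
def fits_in_cache(prompt_tokens, total_generation_steps, cache_size):
    """Assumed rule: prompt length plus generated steps must not exceed the cache."""
    if min(prompt_tokens, total_generation_steps, cache_size) < 0:
        raise ValueError("lengths must be non-negative")
    return prompt_tokens + total_generation_steps <= cache_size

# A short prompt with 100 generation steps fits comfortably in 1024 slots.
print(fits_in_cache(prompt_tokens=12, total_generation_steps=100, cache_size=1024))
# A very long prompt with the same number of steps does not.
print(fits_in_cache(prompt_tokens=1000, total_generation_steps=100, cache_size=1024))
```

A larger cache leaves room for longer prompts and responses at the cost of more accelerator memory.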
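  4. Create a sampler with gemma.sampler.Sampler on top of the Gemma model checkpoint/weights and the tokenizer:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],  # The model weights from the loaded checkpoint.
)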
  5. Write a prompt in prompt and perform inference. You can tweak total_generation_steps, the number of steps performed when generating a response (this example uses 100 to preserve host memory).
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

# Print each prompt followed by the text generated for it.
for input_string, out_string in zip(prompt, reply.text):
    print(input_string)
    print(out_string)
# What is the meaning of life?

The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
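
The response above stops mid-sentence because generation ends once the step budget is used up, whether or not the model has finished its thought. The following is a toy sketch in plain Python of a step-limited generation loop. It is not the gemma sampler's implementation; `greedy_generate`, `toy_next_token`, and the end-of-sequence value are hypothetical and only illustrate what a budget like total_generation_steps controls.

```python
def greedy_generate(next_token_fn, prompt_tokens, total_generation_steps, eos_token=None):
    """Append one token per step until the step budget or an end token is reached."""
    tokens = list(prompt_tokens)
    for _ in range(total_generation_steps):
        token = next_token_fn(tokens)
        tokens.append(token)
        if eos_token is not None and token == eos_token:
            break
    # Return only the newly generated tokens, not the prompt.
    return tokens[len(prompt_tokens):]

# Toy "model": always predicts the previous token plus one, modulo 10.
def toy_next_token(tokens):
    return (tokens[-1] + 1) % 10

# Budget of 4 steps: generation is cut off after exactly 4 tokens.
print(greedy_generate(toy_next_token, [3], total_generation_steps=4))
# Large budget: generation stops early when the end token (0 here) appears.
print(greedy_generate(toy_next_token, [7], total_generation_steps=100, eos_token=0))
```

Raising the step budget allows longer responses but takes more time and memory per call.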
  6. (Optional) Run this cell to free up memory if you have completed the notebook and want to try another prompt. Afterwards, you can instantiate the sampler again in step 4, then customize and run the prompt in step 5.
del sampler

Learn more