Prompt with images and text using the Gemma library

Using images for prompting Gemma models opens up a whole new range of possibilities for understanding your world and solving problems with visual data. Starting with Gemma 3 at 4B sizes and higher, you can use image data as part of your prompt for richer context and to solve more complex tasks.

This tutorial shows you how to prompt Gemma with images using the Gemma library for JAX. The Gemma library is a Python package built as an extension of JAX, letting you take advantage of the performance of the JAX framework with dramatically less code.

Setup

To complete this tutorial, you'll first need to complete the setup instructions at Gemma setup. The Gemma setup instructions show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the Gemma model.
  • Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

Install libraries

Install the Gemma library.

pip install -q gemma

Set environment variables

Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Set the JAX environment to use the full GPU memory space.

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Import packages

Import the Gemma library and additional support libraries.

# Common imports
import os
import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds

# Gemma imports
from gemma import gm

Configure a model

Select and configure a Gemma model for use, including a tokenizer, model architecture, and checkpoint weights. The Gemma library supports all of Google's official releases of the model. To process images as part of your prompt, you must use the Gemma3Tokenizer and a Gemma 3 or later model.

To configure the model, run the following code:

tokenizer = gm.text.Gemma3Tokenizer()

model = gm.nn.Gemma3_4B()

params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)

Generate text with text

Start by prompting with text. The Gemma library provides a Sampler class for simple prompting.

sampler = gm.text.Sampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

sampler.sample('Roses are red.', max_new_tokens=30)

Change the prompt and the maximum number of tokens to generate different output.
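
For example, the following call uses a different prompt and a larger token budget; both values are arbitrary choices for illustration:

# A different prompt with a larger token budget (values are illustrative).
sampler.sample(
    'Write a haiku about machine learning.',
    max_new_tokens=60,
)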

Generate text with images

Once you have a text prompt working, you can add images to your prompt. Make sure you have configured a Gemma 3 or later model that is 4B or larger, and that you are using the Gemma3Tokenizer.

Load an image

Load an image from a data source or a local file. The following code shows how to load an image from a TensorFlow Datasets data source:

ds = tfds.data_source('oxford_flowers102', split='train')
image = ds[0]['image']

# display the image
image
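
You can also load an image from a local file. The following is a minimal sketch, assuming a placeholder file name flower.jpg and that the model expects the image as an RGB pixel array:

import numpy as np
from PIL import Image

# Load a local image file as an RGB pixel array.
# 'flower.jpg' is a placeholder; substitute your own file path.
image = np.asarray(Image.open('flower.jpg').convert('RGB'))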

Prepare prompt with image data

When you prompt with image data, you include a specific tag, <start_of_image>, to mark where the image appears in the text of your prompt. You then encode the prompt together with the image data using the tokenizer object to prepare it to run with the model.

prompt = """<start_of_turn>user
Describe the contents of this image.

<start_of_image>

<end_of_turn>
<start_of_turn>model
"""

If you want to prompt with more than one image, you must include a <start_of_image> tag for each image included in your prompt.
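
For example, a two-image prompt might look like the following sketch. It assumes you have already loaded two images into the placeholder variables image_1 and image_2, and that the images are passed in the same order as their tags:

prompt = """<start_of_turn>user
Compare these two flowers.

<start_of_image>
<start_of_image>

<end_of_turn>
<start_of_turn>model
"""

# image_1 and image_2 are placeholders for images you've already loaded.
out = sampler.sample(prompt, images=[image_1, image_2], max_new_tokens=500)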

Run the prompt with image data

After you prepare your image data and the prompt with image tags, you can run the prompt and generate output. The following code shows how to use the Sampler class to run the prompt:

sampler = gm.text.Sampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

out = sampler.sample(prompt, images=[image], max_new_tokens=500)
print(out)

Alternatively, you can use the gm.text.ChatSampler() class to generate a response without requiring <start_of_turn> tags. For more details, see the Gemma library for JAX documentation.
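
Here's a minimal sketch of that approach. It assumes ChatSampler accepts the same constructor arguments as Sampler and that its chat method takes an images argument like Sampler.sample:

# ChatSampler applies the conversation turn formatting for you.
chat = gm.text.ChatSampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

# The <start_of_image> tag still marks where the image belongs in the text.
reply = chat.chat('Describe this image: <start_of_image>', images=[image])
print(reply)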

Next steps

The Gemma library provides much more functionality, including LoRA fine-tuning, sharding, quantization, and more. For more details, see the Gemma library documentation. If you have feedback or run into issues using the Gemma library, submit them through the repository Issues interface.