Inference with Keras


When your AI model produces a conclusion or a prediction, it goes through a process called inference. This tutorial goes over how to use PaliGemma with Keras to set up a simple model that can infer information about supplied images and answer questions about them.

What's in this notebook

This notebook uses PaliGemma with Keras and shows you how to:

  • Install Keras and the required dependencies
  • Download PaliGemmaCausalLM, a pre-trained PaliGemma variant for causal visual language modeling, and use it to create a model
  • Test the model's ability to infer information about supplied images

Before you begin

Before going through this notebook, you should be familiar with Python code, as well as how large language models (LLMs) are trained. You don't need to be familiar with Keras, but basic knowledge about Keras is helpful when reading through the example code.

Setup

The following sections explain the preliminary steps for getting a notebook to use a PaliGemma model, including model access, getting an API key, and configuring the notebook runtime.

Get access to PaliGemma

Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:

  1. Log in to Kaggle, or create a new Kaggle account if you don't already have one.
  2. Go to the PaliGemma model card and click Request Access.
  3. Complete the consent form and accept the terms and conditions.

Configure your API key

To use PaliGemma, you must provide your Kaggle username and a Kaggle API key.

To generate a Kaggle API key, open your Settings page in Kaggle and click Create New Token. This triggers the download of a kaggle.json file containing your API credentials.

Then, in Colab, select Secrets (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name KAGGLE_USERNAME and your API key under the name KAGGLE_KEY.

Select the runtime

To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the PaliGemma model. In this case, you can use a T4 GPU:

  1. In the upper-right of the Colab window, click the ▾ (Additional connection options) dropdown menu.
  2. Select Change runtime type.
  3. Under Hardware accelerator, select T4 GPU.

Set environment variables

Set the environment variables for KAGGLE_USERNAME, KAGGLE_KEY, and KERAS_BACKEND.

import os
from google.colab import userdata

# Set up environment variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"
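If you run this notebook outside Colab, google.colab.userdata isn't available. One alternative is to read the credentials from the kaggle.json file you downloaded when configuring your API key. The sketch below assumes that file is saved locally. The helper name set_kaggle_env_from_json and the default path ~/.kaggle/kaggle.json are choices made for this example:

```python
import json
import os
from pathlib import Path


def set_kaggle_env_from_json(path="~/.kaggle/kaggle.json"):
    """Copy Kaggle credentials from a kaggle.json file into environment variables.

    Assumes the {"username": ..., "key": ...} layout of the file that Kaggle's
    Create New Token button downloads. Returns the username for a quick check.
    """
    creds = json.loads(Path(path).expanduser().read_text())
    os.environ["KAGGLE_USERNAME"] = creds["username"]
    os.environ["KAGGLE_KEY"] = creds["key"]
    # The backend must be chosen before keras is imported anywhere.
    os.environ.setdefault("KERAS_BACKEND", "jax")
    return creds["username"]
```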

Install Keras

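Run the following cell to install KerasNLP, which provides the PaliGemma model used in this notebook. The leading ! runs the command in the notebook's shell.

!pip install -U -q keras-nlp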

Import dependencies and configure Keras

Import the dependencies needed for this notebook and configure Keras. You'll set Keras to use bfloat16 so that the framework uses less memory.

import keras
import keras_nlp
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")
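The memory saving from bfloat16 is easy to estimate: each bfloat16 value takes 2 bytes instead of float32's 4. The arithmetic below assumes roughly 3 billion parameters, as the "3b" in the preset name suggests, and counts model weights only:

```python
# Back-of-the-envelope weight memory for a model with ~3 billion parameters.
# Assumption: the "3b" in the preset name is used as the parameter count; the
# real count differs somewhat, and activations need memory on top of this.
NUM_PARAMS = 3_000_000_000
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def weight_memory_gib(num_params, dtype):
    """Memory, in GiB, needed to store num_params weights of the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


for dtype in BYTES_PER_PARAM:
    print(f"{dtype}: {weight_memory_gib(NUM_PARAMS, dtype):.1f} GiB")
# float32: 11.2 GiB
# bfloat16: 5.6 GiB
```

A T4 has 16 GB of GPU memory, so halving the weight footprint leaves considerably more room for activations during inference.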

Create your model

Now that you've set everything up, you can download the pre-trained model and create some utility methods to help your model generate its responses.

Download the model checkpoint

KerasNLP provides implementations of many popular model architectures. In this notebook, you'll create a model using PaliGemmaCausalLM, an end-to-end PaliGemma model for causal visual language modeling. A causal visual language model predicts the next token based on previous tokens.
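To make "predicts the next token based on previous tokens" concrete, here is a toy greedy decoding loop. The score table and function names are invented for illustration. PaliGemma computes its scores with a neural network over its full vocabulary, conditioned on the image and the prompt, but generation follows the same append-and-repeat pattern:

```python
# Toy illustration of causal (autoregressive) decoding. The score table is
# invented for this example; it is NOT how PaliGemma computes its scores.
TOY_SCORES = {
    "<bos>": {"a": 0.9, "cow": 0.1},
    "a": {"cow": 0.8, "beach": 0.2},
    "cow": {"on": 0.7, "<eos>": 0.3},
    "on": {"a": 0.4, "beach": 0.6},
    "beach": {"<eos>": 1.0},
}


def toy_next_token_scores(prefix):
    """Score candidate next tokens given every token generated so far.

    A real causal model looks at the whole prefix; this toy only looks at the
    last token to keep the table small.
    """
    return TOY_SCORES[prefix[-1]]


def greedy_decode(score_fn, start="<bos>", stop="<eos>", max_len=10):
    """Repeatedly append the highest-scoring next token until `stop` appears."""
    tokens = [start]
    while len(tokens) < max_len:
        scores = score_fn(tokens)
        next_token = max(scores, key=scores.get)
        tokens.append(next_token)
        if next_token == stop:
            break
    return tokens


print(greedy_decode(toy_next_token_scores))
# ['<bos>', 'a', 'cow', 'on', 'beach', '<eos>']
```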

Create the model using the from_preset method and print its summary. This process will take about a minute to complete.

paligemma = keras_nlp.models.PaliGemmaCausalLM.from_preset("pali_gemma_3b_mix_224")
paligemma.summary()

Create utility methods

To help you generate responses from your model and visualize them, create the following utility methods:

  • crop_and_resize: Helper method for read_image. This method center-crops the image to a square and resizes it to the passed-in size, so that the final image is resized without skewing its proportions.
  • read_image: Downloads an image from a valid URL, crops and resizes it with crop_and_resize so that it fits the model's constraints, and converts it into an array that the model can interpret, removing the alpha channel if there is one.
  • parse_bbox_and_labels: Parses the output of a detect prompt into bounding boxes (coordinates divided by 1024) and their labels.
  • display_boxes: Draws the parsed bounding boxes and labels over the image.
  • display_segment_output: Overlays a segmentation mask on the image.

You'll use read_image in the next step of this notebook.

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove the alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      r'<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      r' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, w = target_image_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      # Boxes are normalized (y0, x0, y1, x1): scale y by height and x by width.
      y, x, y2, x2 = boxes[i] * np.array([h, w, h, w])
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      ax.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, segment_mask, target_image_size):
  # Calculate scaling factors
  h, w = target_image_size
  x_scale = w / 64
  y_scale = h / 64

  # Create coordinate grids for the new image
  x_coords = np.arange(w)
  y_coords = np.arange(h)
  x_coords = (x_coords / x_scale).astype(int)
  y_coords = (y_coords / y_scale).astype(int)
  resized_array = segment_mask[y_coords[:, np.newaxis], x_coords]
  # Create a figure and axis
  fig, ax = plt.subplots()

  # Display the image
  ax.imshow(image)

  # Overlay the mask with transparency
  ax.imshow(resized_array, cmap='jet', alpha=0.5)
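crop_and_resize avoids skewing the image by first cropping the largest centered square, so the resize to target_size always goes from a 1:1 region to a 1:1 output. The hypothetical helper below isolates that box arithmetic so you can check it on a few sizes without downloading an image:

```python
def center_square_box(width, height):
    """Return the (left, top, right, bottom) box of the largest centered square.

    Mirrors the arithmetic in crop_and_resize; the result is the `box`
    argument that method passes to PIL's Image.resize.
    """
    side = min(width, height)
    left = width // 2 - side // 2
    top = height // 2 - side // 2
    return left, top, left + side, top + side


print(center_square_box(640, 480))  # (80, 0, 560, 480): trims 80 px per side
print(center_square_box(300, 500))  # (0, 100, 300, 400): trims 100 px top and bottom
```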

Test your model

Now you're ready to give an image and prompt to your model and have it infer the response.

Let's read the test image and display it.

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

Here's a generation call with a single image and prompt. The prompts have to end with a \n.

We've supplied you with several example prompts. Play around with them! Comment and uncomment the prompt variables to change the prompt you supply to the model.

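prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'  # Norwegian: "answer: where is the cow standing?"
# prompt = 'answer fr quelle couleur est le ciel?\n'  # French: "what color is the sky?"
# prompt = 'responda pt qual a cor do animal?\n'  # Portuguese: "answer: what color is the animal?"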
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

Here's a generation call with batched inputs.

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)
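In the batched call, images[i] is paired with prompts[i], so both lists need the same length, and every prompt still needs its trailing newline. A small helper can enforce both. It is a hypothetical convenience function, not part of KerasNLP, and it only builds the inputs dictionary passed to paligemma.generate:

```python
import numpy as np


def make_generate_inputs(images, prompts):
    """Pair images with prompts for a batched paligemma.generate call.

    Hypothetical helper: images[i] is used with prompts[i], so both lists must
    have the same length. Prompts missing the trailing newline get one added.
    """
    if len(images) != len(prompts):
        raise ValueError(f"got {len(images)} images but {len(prompts)} prompts")
    fixed_prompts = [p if p.endswith("\n") else p + "\n" for p in prompts]
    return {"images": [np.asarray(img) for img in images], "prompts": fixed_prompts}


# Usage in this notebook (not run here):
# prompts = ['answer en what color is the cow?', 'caption en\n']
# outputs = paligemma.generate(
#     inputs=make_generate_inputs([cow_image] * len(prompts), prompts))
```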


Other styles of prompts

You may have noticed in the previous step that the provided examples are in several different languages. PaliGemma supports prompts in 34 different languages. You can find the list of supported languages on GitHub.

PaliGemma can handle several other prompt styles:

  • "cap {lang}\n": Very raw short caption (from WebLI-alt)
  • "caption {lang}\n": Nice, COCO-like short captions
  • "describe {lang}\n": Somewhat longer, more descriptive captions
  • "ocr\n": Optical character recognition
  • "answer {lang} {question}\n": Question answering about the image contents
  • "question {lang} {answer}\n": Question generation for a given answer
  • "detect {object} ; {object}\n": Locate the listed objects in a scene and return a bounding box and label for each
  • "segment {object}\n": Do image segmentation of the object in the scene

Try them out!
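Each style follows a fixed template, so prompts can also be built programmatically. The build_prompt helper below is hypothetical; its templates come from the list above, and it always appends the trailing newline that PaliGemma prompts require:

```python
def build_prompt(task, lang="en", text="", objects=()):
    """Format a PaliGemma task prompt (hypothetical helper).

    task: one of cap, caption, describe, ocr, answer, question, detect, segment.
    text: the question (answer), the answer (question), or the object (segment).
    objects: the object names for detect, joined with " ; ".
    """
    if task in ("cap", "caption", "describe"):
        return f"{task} {lang}\n"
    if task == "ocr":
        return "ocr\n"
    if task in ("answer", "question"):
        return f"{task} {lang} {text}\n"
    if task == "detect":
        return "detect " + " ; ".join(objects) + "\n"
    if task == "segment":
        return f"segment {text}\n"
    raise ValueError(f"unknown task: {task}")


print(repr(build_prompt("detect", objects=["cow", "person"])))
# 'detect cow ; person\n'
```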

Parse detect output

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)
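The detect output encodes each box as four <locNNNN> tokens in y0, x0, y1, x1 order, which parse_bbox_and_labels divides by 1024, followed by the label. The sketch below (detect_to_pixel_boxes is a hypothetical name) converts such a string straight into pixel coordinates. The input string is synthetic, not real model output:

```python
import re

# Four location tokens (y0, x0, y1, x1), a space, then the label; multiple
# detections are separated by " ; ", matching parse_bbox_and_labels.
DETECT_PATTERN = re.compile(
    r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> (.+?)(?: ;|$)"
)


def detect_to_pixel_boxes(output, width, height):
    """Return [(label, x0, y0, x1, y1), ...] in pixel coordinates."""
    boxes = []
    for m in DETECT_PATTERN.finditer(output):
        y0, x0, y1, x1 = (int(v) / 1024 for v in m.groups()[:4])
        boxes.append((m.group(5), round(x0 * width), round(y0 * height),
                      round(x1 * width), round(y1 * height)))
    return boxes


print(detect_to_pixel_boxes("<loc0256><loc0128><loc0768><loc0896> cow", 224, 224))
# [('cow', 28, 56, 196, 168)]
```

Note that y comes first in the token order, which is why the x and y values swap places in the returned tuples.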

Parse segment output

Let's take a look at another example image.

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

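Here is a function to help parse the segment output from PaliGemma. It relies on the mask-decoding utilities in the big_vision repository (github.com/google-research/big_vision), so that repository must be available on the Python import path for the import to succeed.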

import big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval

# Decoder that turns the 16 <segNNN> codes of each instance into a mask.
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      r'<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(rf'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))
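For a segment prompt, each instance is written as four location tokens followed by 16 <segNNN> tokens, and reconstruct_masks decodes those 16 codes into a mask. If you only want to inspect the raw tokens, for example while big_vision isn't installed, a sketch like the following separates the box from the codes without decoding them. split_segment_tokens is a hypothetical helper:

```python
import re

import numpy as np

# Four location tokens, then 16 segmentation tokens, per segmented instance.
SEG_PATTERN = re.compile(
    r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>" + r"<seg(\d{3})>" * 16
)


def split_segment_tokens(output):
    """Return (boxes, codes) without decoding the masks.

    boxes: (N, 4) array of normalized (y0, x0, y1, x1), divided by 1024.
    codes: (N, 16) array of the integer values of the <segNNN> tokens.
    """
    boxes, codes = [], []
    for m in SEG_PATTERN.finditer(output):
        values = [int(v) for v in m.groups()]
        boxes.append([v / 1024 for v in values[:4]])
        codes.append(values[4:])
    return (np.array(boxes, dtype=np.float64).reshape(-1, 4),
            np.array(codes, dtype=np.int32).reshape(-1, 16))
```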

Query PaliGemma to segment the cat in the image

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

Visualize the generated mask from PaliGemma

_, seg_output = parse_segments(output)
display_segment_output(cat, seg_output[0], target_size)
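display_segment_output stretches the 64 x 64 mask over the entire image. As we understand PaliGemma's segmentation format, the decoded mask covers only the predicted bounding box, so the overlay should line up better when the mask is resized into that box instead. A minimal sketch under that assumption (paste_mask_in_box is a hypothetical helper):

```python
import numpy as np


def paste_mask_in_box(mask, box, image_size):
    """Resize a decoded mask into its bounding box on a full-size canvas.

    mask: (m, n) or (m, n, 1) array, e.g. one entry of parse_segments' masks.
    box: normalized (y0, x0, y1, x1), e.g. one row of parse_segments' boxes.
    image_size: (height, width) of the image the overlay is drawn on.
    Resizing is nearest-neighbour, the same approach display_segment_output uses.
    """
    mask = np.asarray(mask, dtype=np.float32)
    if mask.ndim == 3:
        mask = mask[..., 0]
    h, w = image_size
    y0, y1 = (int(np.clip(round(float(v) * h), 0, h)) for v in (box[0], box[2]))
    x0, x1 = (int(np.clip(round(float(v) * w), 0, w)) for v in (box[1], box[3]))
    canvas = np.zeros((h, w), dtype=np.float32)
    box_h, box_w = y1 - y0, x1 - x0
    if box_h <= 0 or box_w <= 0:
        return canvas  # Degenerate box: nothing to paste.
    # Map every output pixel inside the box back to a source mask pixel.
    rows = np.arange(box_h) * mask.shape[0] // box_h
    cols = np.arange(box_w) * mask.shape[1] // box_w
    canvas[y0:y1, x0:x1] = mask[rows[:, np.newaxis], cols]
    return canvas
```

To try it, pass seg_output[0] and the first row of the boxes that parse_segments returns (the first element of its result), together with target_size, then draw the result over the image with plt.imshow(..., cmap='jet', alpha=0.5), as display_segment_output does.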