使用 Keras 生成 PaliGemma 输出

在 ai.google.dev 上查看 在 Google Colab 中运行 在 Vertex AI 中打开 在 GitHub 上查看源代码

PaliGemma 模型具有多模态功能,可让您使用文本和图片输入数据生成输出。您可以将图片数据与这些模型搭配使用,为请求提供更多背景信息,也可以使用这些模型来分析图片的内容。本教程将介绍如何使用 PaliGemma 和 Keras 分析图片并回答有关图片的问题。

本笔记本中的内容

此笔记本使用 PaliGemma 和 Keras,并向您展示了如何:

  • 安装 Keras 和必需的依赖项
  • 下载 PaliGemmaCausalLM(一种预训练的 PaliGemma 变体,用于因果视觉语言建模),并使用它来创建模型
  • 测试模型推断所提供图片相关信息的能力

准备工作

在学习本笔记本之前,您应该熟悉 Python 代码以及大语言模型 (LLM) 的训练方式。您无需熟悉 Keras,但如果具备 Keras 的基本知识,在阅读示例代码时会很有帮助。

设置

以下部分介绍了让笔记本使用 PaliGemma 模型所需的初步步骤,包括模型访问权限、获取 API 密钥和配置笔记本运行时。

获取 PaliGemma 访问权限

首次使用 PaliGemma 之前,您必须通过 Kaggle 申请对该模型的访问权限,具体步骤如下:

  1. 登录 Kaggle;如果您还没有 Kaggle 账号,请创建一个新账号。
  2. 前往 PaliGemma 模型卡片,然后点击申请访问权限
  3. 填写同意书并接受条款及条件。

配置 API 密钥

如需使用 PaliGemma,您必须提供 Kaggle 用户名和 Kaggle API 密钥。

如需生成 Kaggle API 密钥,请在 Kaggle 中打开设置页面,然后点击 Create New Token(创建新令牌)。此操作会触发下载一个包含您的 API 凭据的 kaggle.json 文件。

然后,在 Colab 中,选择左侧窗格中的 Secrets (🔑),并添加您的 Kaggle 用户名和 Kaggle API 密钥。将您的用户名存储在名称为 KAGGLE_USERNAME 的变量中,并将您的 API 密钥存储在名称为 KAGGLE_KEY 的变量中。

选择运行时

如需完成本教程,您需要一个具有足够资源来运行 PaliGemma 模型的 Colab 运行时。在这种情况下,您可以使用 T4 GPU:

  1. 在 Colab 窗口的右上角,点击 ▾(其他连接选项)下拉菜单。
  2. 选择更改运行时类型
  3. 硬件加速器下,选择 T4 GPU

设置环境变量

KAGGLE_USERNAMEKAGGLE_KEYKERAS_BACKEND 设置环境变量。

import os
from google.colab import userdata

# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"

安装 Keras

运行以下单元格以安装 Keras。

pip install -U -q keras-nlp keras-hub kagglehub

导入依赖项并配置 Keras

安装此笔记本所需的依赖项并配置 Keras 的后端。您还将设置 Keras 使用 bfloat16,以便框架使用更少的内存。

import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")

加载模型

现在,您已完成所有设置,可以下载预训练模型并创建一些实用方法,以帮助模型生成回答。 在此步骤中,您将使用 PaliGemmaCausalLM 从 Keras Hub 下载模型。此类有助于管理和运行 PaliGemma 的因果视觉语言模型结构。因果视觉语言模型会根据之前的 token 预测下一个 token。Keras Hub 提供了许多热门模型架构的实现。

使用 from_preset 方法创建模型并打印其摘要。此过程大约需要一分钟才能完成。

paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()

创建实用程序方法

为了帮助您从模型生成回答,请创建两个实用方法:

  • crop_and_resizeread_img 的辅助方法。此方法会裁剪图片并将其调整为传入的大小,以便在不扭曲图片比例的情况下调整最终图片的大小。
  • read_imgread_img_from_url 的辅助方法。此方法实际上会打开图片,调整其大小以符合模型的限制,并将其放入可供模型解读的数组中。
  • read_img_from_url:通过有效网址获取图片。您需要使用此方法将图片传递给模型。

您将在本笔记本的下一步中使用 read_img_from_url

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, l = target_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      y, x, y2, x2 = (boxes[i]*h)
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      plt.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, bounding_box, segment_mask, target_image_size):
    # Initialize a full mask with the target size
    full_mask = np.zeros(target_image_size, dtype=np.uint8)
    target_width, target_height = target_image_size

    for bbox, mask in zip(bounding_box, segment_mask):
        y1, x1, y2, x2 = bbox
        x1 = int(x1 * target_width)
        y1 = int(y1 * target_height)
        x2 = int(x2 * target_width)
        y2 = int(y2 * target_height)

        # Ensure mask is 2D before converting to Image
        if mask.ndim == 3:
            mask = mask.squeeze(axis=-1)
        mask = Image.fromarray(mask)
        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
        mask = np.array(mask)
        binary_mask = (mask > 0.5).astype(np.uint8)


        # Place the binary mask onto the full mask
        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
    cmap = plt.get_cmap('jet')
    colored_mask = cmap(full_mask / 1.0)
    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
    if isinstance(image, Image.Image):
        image = np.array(image)
    blended_image = image.copy()
    mask_indices = full_mask > 0
    alpha = 0.5

    for c in range(3):
        blended_image[:, :, c] = np.where(mask_indices,
                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
                                          image[:, :, c])

    fig, ax = plt.subplots()
    ax.imshow(blended_image)
    plt.show()

生成输出

加载模型并创建实用程序方法后,您可以使用图片和文本数据提示模型生成回答。PaliGemma 模型经过训练,可使用特定提示语法来执行特定任务,例如 answercaptiondetect。如需详细了解 PaliGemma 提示任务语法,请参阅 PaliGemma 提示和系统指令

使用以下代码将测试图片加载到对象中,以便在生成提示中使用该图片:

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

以特定语言回答

以下示例代码展示了如何提示 PaliGemma 模型提供有关所提供图片中显示的对象的信息。此示例使用 answer {lang} 语法,并以其他语言显示更多问题:

prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'

output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

使用 detect 提示

以下示例代码使用 detect 提示语法在提供的图片中定位对象。该代码使用之前定义的 parse_bbox_and_labels()display_boxes() 函数来解读模型输出并显示生成的边界框。

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)

使用 segment 提示

以下示例代码使用 segment 提示语法来定位图片中对象所占的区域。它使用 Google big_vision 库来解读模型输出,并为分割的对象生成遮罩。

在开始之前,请安装 big_vision 库及其依赖项,如以下代码示例所示:

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")


# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

对于此细分示例,请加载并准备一张包含猫的其他图片。

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

以下函数有助于解析 PaliGemma 的分段输出

import  big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))

查询 PaliGemma 以分割图片中的猫

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

直观呈现 PaliGemma 生成的遮盖

bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)

批量提示

您可以在单个提示中提供多个提示命令,作为一批指令。以下示例展示了如何构建提示文本以提供多条指令。

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)