使用 Keras 生成 PaliGemma 输出

PaliGemma 模型具有多模态功能,可让您同时使用文本和图片输入数据生成输出。您可以将图片数据与这些模型搭配使用,为请求提供更多背景信息,或使用模型分析图片内容。本教程介绍了如何将 PaliGemma 与 Keras 搭配使用,以便分析图片并回答与图片相关的问题。

此手册的内容

此笔记本将 PaliGemma 与 Keras 搭配使用,并向您展示如何:

  • 安装 Keras 和所需的依赖项
  • 下载 PaliGemmaCausalLM,这是一个用于因果视觉语言建模的预训练 PaliGemma 变体,并使用它创建模型
  • 测试模型推断所提供图片的相关信息的能力

准备工作

在学习本记事之前,您应熟悉 Python 代码以及大语言模型 (LLM) 的训练方式。您无需熟悉 Keras,但在阅读示例代码时,了解 Keras 的基本知识会很有帮助。

设置

以下部分介绍了让笔记本使用 PaliGemma 模型的初始步骤,包括模型访问权限、获取 API 密钥和配置笔记本运行时。

获取 PaliGemma 的访问权限

在首次使用 PaliGemma 之前,您必须完成以下步骤,通过 Kaggle 申请访问该模型的权限:

  1. 登录 Kaggle,或者创建一个新的 Kaggle 账号(如果您还没有)。
  2. 前往 PaliGemma 模型卡片,然后点击申请访问权限
  3. 填写同意书并接受条款及条件。

配置 API 密钥

如需使用 PaliGemma,您必须提供 Kaggle 用户名和 Kaggle API 密钥。

如需生成 Kaggle API 密钥,请在 Kaggle 中打开设置页面,然后点击创建新令牌。这会触发下载包含您的 API 凭据的 kaggle.json 文件。

然后,在 Colab 中,选择左侧窗格中的 Secrets(密钥)图标 (🔑),然后添加您的 Kaggle 用户名和 Kaggle API 密钥。将您的用户名存储在名称为 KAGGLE_USERNAME 的变量下,将您的 API 密钥存储在名称为 KAGGLE_KEY 的变量下。

选择运行时

如需完成本教程,您需要拥有一个具有足够资源的 Colab 运行时,以便运行 PaliGemma 模型。在这种情况下,您可以使用 T4 GPU:

  1. 在 Colab 窗口的右上角,点击 ▾(其他连接选项)下拉菜单。
  2. 选择更改运行时类型
  3. 硬件加速器下,选择 T4 GPU

设置环境变量

KAGGLE_USERNAMEKAGGLE_KEYKERAS_BACKEND 设置环境变量。

import os
from google.colab import userdata

# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"

安装 Keras

运行以下单元格以安装 Keras。

pip install -U -q keras-nlp keras-hub kagglehub

导入依赖项并配置 Keras

安装此记事本所需的依赖项,并配置 Keras 的后端。您还将设置 Keras 以使用 bfloat16,以便该框架使用更少的内存。

import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")

加载模型

现在,您已完成所有设置,可以下载预训练模型并创建一些实用方法来帮助模型生成回答。 在此步骤中,您将使用 PaliGemmaCausalLM 从 Keras Hub 下载模型。此类可帮助您管理和运行 PaliGemma 的因果视觉语言模型结构。因果视觉语言模型可根据之前的令牌预测下一个令牌。Keras Hub 提供了许多热门模型架构的实现。

使用 from_preset 方法创建模型并输出其摘要。此过程大约需要一分钟才能完成。

paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()

创建实用方法

为了帮助您从模型生成回答,请创建两个实用方法:

  • crop_and_resizeread_img 的辅助方法。此方法会将图片剪裁并调整为传入的大小,以便调整后的最终图片不会使图片的比例失真。
  • read_imgread_img_from_url 的辅助方法。此方法实际上会打开图片,调整其大小以使其符合模型的约束条件,并将其放入可由模型解读的数组中。
  • read_img_from_url:通过有效网址接收图片。您需要使用此方法将图片传递给模型。

您将在本记事的下一步中使用 read_img_from_url

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, l = target_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      y, x, y2, x2 = (boxes[i]*h)
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      plt.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, bounding_box, segment_mask, target_image_size):
    # Initialize a full mask with the target size
    full_mask = np.zeros(target_image_size, dtype=np.uint8)
    target_width, target_height = target_image_size

    for bbox, mask in zip(bounding_box, segment_mask):
        y1, x1, y2, x2 = bbox
        x1 = int(x1 * target_width)
        y1 = int(y1 * target_height)
        x2 = int(x2 * target_width)
        y2 = int(y2 * target_height)

        # Ensure mask is 2D before converting to Image
        if mask.ndim == 3:
            mask = mask.squeeze(axis=-1)
        mask = Image.fromarray(mask)
        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
        mask = np.array(mask)
        binary_mask = (mask > 0.5).astype(np.uint8)


        # Place the binary mask onto the full mask
        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
    cmap = plt.get_cmap('jet')
    colored_mask = cmap(full_mask / 1.0)
    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
    if isinstance(image, Image.Image):
        image = np.array(image)
    blended_image = image.copy()
    mask_indices = full_mask > 0
    alpha = 0.5

    for c in range(3):
        blended_image[:, :, c] = np.where(mask_indices,
                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
                                          image[:, :, c])

    fig, ax = plt.subplots()
    ax.imshow(blended_image)
    plt.show()

生成输出

加载模型并创建实用方法后,您可以使用图片和文本数据提示模型以生成回答。PaliGemma 模型是使用特定任务的特定提示语法进行训练的,例如 answercaptiondetect。如需详细了解 PaliGemma 提示任务语法,请参阅 PaliGemma 提示和系统说明

使用以下代码将测试图片加载到对象中,以便在生成提示中使用该图片:

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

以特定语言回答

以下示例代码展示了如何提示 PaliGemma 模型提供有关所提供图片中显示的对象的信息。以下示例使用 answer {lang} 语法,并以其他语言显示其他问题:

prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'

output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

使用 detect 提示

以下示例代码使用 detect 提示语法在提供的图片中定位对象。该代码使用之前定义的 parse_bbox_and_labels()display_boxes() 函数来解读模型输出并显示生成的边界框。

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)

使用 segment 提示

以下示例代码使用 segment 提示语法来定位对象占据的图片区域。它使用 Google big_vision 库来解读模型输出,并为分割对象生成遮罩。

开始之前,请先安装 big_vision 库及其依赖项,如以下代码示例所示:

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")


# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

对于此分割示例,请加载并准备另一张包含猫的图片。

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

以下函数可帮助解析 PaliGemma 中的片段输出

import  big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))

查询 PaliGemma 以分割图片中的猫

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

直观呈现 PaliGemma 生成的遮罩

bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)

批量提示

您可以在一个提示中提供多个提示命令作为一组指令。以下示例展示了如何构建提示文本以提供多个说明。

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)