PaliGemma 模型具有多模态功能,可让您同时使用文本和图片输入数据生成输出。您可以将图片数据与这些模型搭配使用,为请求提供更多背景信息,或使用模型分析图片内容。本教程介绍了如何将 PaliGemma 与 Keras 搭配使用,以便分析图片并回答与图片相关的问题。
此手册的内容
此笔记本将 PaliGemma 与 Keras 搭配使用,并向您展示如何:
- 安装 Keras 和所需的依赖项
- 下载
PaliGemmaCausalLM
,这是一个用于因果视觉语言建模的预训练 PaliGemma 变体,并使用它创建模型 - 测试模型推断所提供图片的相关信息的能力
准备工作
在学习本记事之前,您应熟悉 Python 代码以及大语言模型 (LLM) 的训练方式。您无需熟悉 Keras,但在阅读示例代码时,了解 Keras 的基本知识会很有帮助。
设置
以下部分介绍了让笔记本使用 PaliGemma 模型的初始步骤,包括模型访问权限、获取 API 密钥和配置笔记本运行时。
获取 PaliGemma 的访问权限
在首次使用 PaliGemma 之前,您必须完成以下步骤,通过 Kaggle 申请访问该模型的权限:
- 登录 Kaggle,或者创建一个新的 Kaggle 账号(如果您还没有)。
- 前往 PaliGemma 模型卡片,然后点击申请访问权限。
- 填写同意书并接受条款及条件。
配置 API 密钥
如需使用 PaliGemma,您必须提供 Kaggle 用户名和 Kaggle API 密钥。
如需生成 Kaggle API 密钥,请在 Kaggle 中打开设置页面,然后点击创建新令牌。这会触发下载包含您的 API 凭据的 kaggle.json
文件。
然后,在 Colab 中,选择左侧窗格中的 Secrets(密钥)图标 (🔑),然后添加您的 Kaggle 用户名和 Kaggle API 密钥。将您的用户名存储在名称为 KAGGLE_USERNAME
的变量下,将您的 API 密钥存储在名称为 KAGGLE_KEY
的变量下。
选择运行时
如需完成本教程,您需要拥有一个具有足够资源的 Colab 运行时,以便运行 PaliGemma 模型。在这种情况下,您可以使用 T4 GPU:
- 在 Colab 窗口的右上角,点击 ▾(其他连接选项)下拉菜单。
- 选择更改运行时类型。
- 在硬件加速器下,选择 T4 GPU。
设置环境变量
为 KAGGLE_USERNAME
、KAGGLE_KEY
和 KERAS_BACKEND
设置环境变量。
import os
from google.colab import userdata
# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"
安装 Keras
运行以下单元格以安装 Keras。
pip install -U -q keras-nlp keras-hub kagglehub
导入依赖项并配置 Keras
安装此记事本所需的依赖项,并配置 Keras 的后端。您还将设置 Keras 以使用 bfloat16
,以便该框架使用更少的内存。
import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
keras.config.set_floatx("bfloat16")
加载模型
现在,您已完成所有设置,可以下载预训练模型并创建一些实用方法来帮助模型生成回答。
在此步骤中,您将使用 PaliGemmaCausalLM
从 Keras Hub 下载模型。此类可帮助您管理和运行 PaliGemma 的因果视觉语言模型结构。因果视觉语言模型可根据之前的令牌预测下一个令牌。Keras Hub 提供了许多热门模型架构的实现。
使用 from_preset
方法创建模型并输出其摘要。此过程大约需要一分钟才能完成。
paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()
创建实用方法
为了帮助您从模型生成回答,请创建两个实用方法:
crop_and_resize
:read_img
的辅助方法。此方法会将图片剪裁并调整为传入的大小,以便调整后的最终图片不会使图片的比例失真。read_img
:read_img_from_url
的辅助方法。此方法实际上会打开图片,调整其大小以使其符合模型的约束条件,并将其放入可由模型解读的数组中。read_img_from_url
:通过有效网址接收图片。您需要使用此方法将图片传递给模型。
您将在本记事的下一步中使用 read_img_from_url
。
def crop_and_resize(image, target_size):
width, height = image.size
source_size = min(image.size)
left = width // 2 - source_size // 2
top = height // 2 - source_size // 2
right, bottom = left + source_size, top + source_size
return image.resize(target_size, box=(left, top, right, bottom))
def read_image(url, target_size):
contents = io.BytesIO(requests.get(url).content)
image = PIL.Image.open(contents)
image = crop_and_resize(image, target_size)
image = np.array(image)
# Remove alpha channel if necessary.
if image.shape[2] == 4:
image = image[:, :, :3]
return image
def parse_bbox_and_labels(detokenized_output: str):
matches = re.finditer(
'<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
' (?P<label>.+?)( ;|$)',
detokenized_output,
)
labels, boxes = [], []
fmt = lambda x: float(x) / 1024.0
for m in matches:
d = m.groupdict()
boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
labels.append(d['label'])
return np.array(boxes), np.array(labels)
def display_boxes(image, boxes, labels, target_image_size):
h, l = target_size
fig, ax = plt.subplots()
ax.imshow(image)
for i in range(boxes.shape[0]):
y, x, y2, x2 = (boxes[i]*h)
width = x2 - x
height = y2 - y
# Create a Rectangle patch
rect = patches.Rectangle((x, y),
width,
height,
linewidth=1,
edgecolor='r',
facecolor='none')
# Add label
plt.text(x, y, labels[i], color='red', fontsize=12)
# Add the patch to the Axes
ax.add_patch(rect)
plt.show()
def display_segment_output(image, bounding_box, segment_mask, target_image_size):
# Initialize a full mask with the target size
full_mask = np.zeros(target_image_size, dtype=np.uint8)
target_width, target_height = target_image_size
for bbox, mask in zip(bounding_box, segment_mask):
y1, x1, y2, x2 = bbox
x1 = int(x1 * target_width)
y1 = int(y1 * target_height)
x2 = int(x2 * target_width)
y2 = int(y2 * target_height)
# Ensure mask is 2D before converting to Image
if mask.ndim == 3:
mask = mask.squeeze(axis=-1)
mask = Image.fromarray(mask)
mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
mask = np.array(mask)
binary_mask = (mask > 0.5).astype(np.uint8)
# Place the binary mask onto the full mask
full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
cmap = plt.get_cmap('jet')
colored_mask = cmap(full_mask / 1.0)
colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
if isinstance(image, Image.Image):
image = np.array(image)
blended_image = image.copy()
mask_indices = full_mask > 0
alpha = 0.5
for c in range(3):
blended_image[:, :, c] = np.where(mask_indices,
(1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
image[:, :, c])
fig, ax = plt.subplots()
ax.imshow(blended_image)
plt.show()
生成输出
加载模型并创建实用方法后,您可以使用图片和文本数据提示模型以生成回答。PaliGemma 模型是使用特定任务的特定提示语法进行训练的,例如 answer
、caption
和 detect
。如需详细了解 PaliGemma 提示任务语法,请参阅 PaliGemma 提示和系统说明。
使用以下代码将测试图片加载到对象中,以便在生成提示中使用该图片:
target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)
以特定语言回答
以下示例代码展示了如何提示 PaliGemma 模型提供有关所提供图片中显示的对象的信息。以下示例使用 answer {lang}
语法,并以其他语言显示其他问题:
prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'
output = paligemma.generate(
inputs={
"images": cow_image,
"prompts": prompt,
}
)
print(output)
使用 detect
提示
以下示例代码使用 detect
提示语法在提供的图片中定位对象。该代码使用之前定义的 parse_bbox_and_labels()
和 display_boxes()
函数来解读模型输出并显示生成的边界框。
prompt = 'detect cow\n'
output = paligemma.generate(
inputs={
"images": cow_image,
"prompts": prompt,
}
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)
使用 segment
提示
以下示例代码使用 segment
提示语法来定位对象占据的图片区域。它使用 Google big_vision
库来解读模型输出,并为分割对象生成遮罩。
开始之前,请先安装 big_vision
库及其依赖项,如以下代码示例所示:
import os
import sys
# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
raise "It seems you are using Colab with remote TPUs which is not supported."
# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
!git clone --quiet --branch=main --depth=1 \
https://github.com/google-research/big_vision big_vision_repo
# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
sys.path.append("big_vision_repo")
# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
对于此分割示例,请加载并准备另一张包含猫的图片。
cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)
以下函数可帮助解析 PaliGemma 中的片段输出
import big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
matches = re.finditer(
'<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
+ ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
detokenized_output,
)
boxes, segs = [], []
fmt_box = lambda x: float(x) / 1024.0
for m in matches:
d = m.groupdict()
boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
segs.append([int(d[f's{i}']) for i in range(16)])
return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))
查询 PaliGemma 以分割图片中的猫
prompt = 'segment cat\n'
output = paligemma.generate(
inputs={
"images": cat,
"prompts": prompt,
}
)
直观呈现 PaliGemma 生成的遮罩
bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)
批量提示
您可以在一个提示中提供多个提示命令作为一组指令。以下示例展示了如何构建提示文本以提供多个说明。
prompts = [
'answer en where is the cow standing?\n',
'answer en what color is the cow?\n',
'describe en\n',
'detect cow\n',
'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
inputs={
"images": images,
"prompts": prompts,
}
)
for output in outputs:
print(output)