使用 Keras 產生 PaliGemma 輸出內容

在 ai.google.dev 上查看 在 Google Colab 中執行 在 Vertex AI 中開啟 在 GitHub 上查看來源

PaliGemma 模型具備多模態功能,可根據文字和圖片輸入資料生成輸出內容。您可以搭配這些模型使用圖片資料,為要求提供額外背景資訊,也可以使用模型分析圖片內容。本教學課程說明如何搭配 Keras 使用 PaliGemma,分析圖片並回答相關問題。

這個筆記本的內容

本筆記本使用 Keras 搭配 PaliGemma,並說明如何:

  • 安裝 Keras 和必要依附元件
  • 下載 PaliGemmaCausalLM,這是預先訓練的 PaliGemma 變體,適用於因果視覺語言模型,並用來建立模型
  • 測試模型推斷所提供圖片資訊的能力

事前準備

在瀏覽這個筆記本之前,您應該熟悉 Python 程式碼,以及大型語言模型 (LLM) 的訓練方式。您不需要熟悉 Keras,但閱讀範例程式碼時,具備 Keras 的基本知識會很有幫助。

設定

以下章節說明如何取得筆記本,並使用 PaliGemma 模型,包括模型存取權、取得 API 金鑰,以及設定筆記本執行階段。

取得 PaliGemma 存取權

首次使用 PaliGemma 前,請先透過 Kaggle 要求存取模型,方法如下:

  1. 登入 Kaggle,或建立新的 Kaggle 帳戶 (如果還沒有)。
  2. 前往 PaliGemma 模型資訊卡,然後按一下「Request Access」(要求存取權)
  3. 填寫同意聲明表單,並接受條款及細則。

設定 API 金鑰

如要使用 PaliGemma,您必須提供 Kaggle 使用者名稱和 Kaggle API 金鑰。

如要產生 Kaggle API 金鑰,請在 Kaggle 中開啟「設定」頁面,然後按一下「Create New Token」(建立新權杖)。系統會下載含有 API 憑證的 kaggle.json 檔案。

接著在 Colab 中,選取左側窗格的「祕密」 (🔑),然後新增 Kaggle 使用者名稱和 Kaggle API 金鑰。將使用者名稱儲存為 KAGGLE_USERNAME,API 金鑰儲存為 KAGGLE_KEY

選取執行階段

如要完成本教學課程,您必須擁有 Colab 執行階段,且具備足夠資源來執行 PaliGemma 模型。在這種情況下,您可以使用 T4 GPU:

  1. 在 Colab 視窗的右上角,按一下「▾ (其他連線選項)」下拉式選單。
  2. 選取「變更執行階段類型」
  3. 在「硬體加速器」下方,選取「T4 GPU」

設定環境變數

設定 KAGGLE_USERNAMEKAGGLE_KEYKERAS_BACKEND 的環境變數。

import os
from google.colab import userdata

# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"

安裝 Keras

執行下列儲存格來安裝 Keras。

pip install -U -q keras-nlp keras-hub kagglehub

匯入依附元件並設定 Keras

安裝這個筆記本所需的依附元件,並設定 Keras 的後端。您也會將 Keras 設為使用 bfloat16,讓架構耗用較少記憶體。

import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")

載入模型

完成所有設定後,您就可以下載預先訓練的模型,並建立一些實用方法,協助模型生成回覆。 在這個步驟中,您會使用 Keras Hub 的 PaliGemmaCausalLM 下載模型。這個類別可協助您管理及執行 PaliGemma 的因果視覺語言模型結構。因果視覺語言模型會根據先前的符記預測下一個符記。Keras Hub 提供許多熱門模型架構的實作項目。

使用 from_preset 方法建立模型,並列印摘要。這項程序大約需要一分鐘才能完成。

paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()

建立公用程式方法

為協助您從模型生成回覆,請建立下列兩種公用程式方法:

  • crop_and_resizeread_img 的輔助方法。這個方法會裁剪圖片並調整大小,以符合傳入的大小,因此最終圖片會調整大小,但不會扭曲圖片比例。
  • read_imgread_img_from_url 的輔助方法。這個方法會實際開啟圖片、調整圖片大小以符合模型限制,並將圖片放入模型可解讀的陣列。
  • read_img_from_url透過有效網址取得圖片。您需要使用這個方法將圖片傳遞至模型。

您將在筆記本的下一個步驟中使用 read_img_from_url

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, l = target_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      y, x, y2, x2 = (boxes[i]*h)
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      plt.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, bounding_box, segment_mask, target_image_size):
    # Initialize a full mask with the target size
    full_mask = np.zeros(target_image_size, dtype=np.uint8)
    target_width, target_height = target_image_size

    for bbox, mask in zip(bounding_box, segment_mask):
        y1, x1, y2, x2 = bbox
        x1 = int(x1 * target_width)
        y1 = int(y1 * target_height)
        x2 = int(x2 * target_width)
        y2 = int(y2 * target_height)

        # Ensure mask is 2D before converting to Image
        if mask.ndim == 3:
            mask = mask.squeeze(axis=-1)
        mask = Image.fromarray(mask)
        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
        mask = np.array(mask)
        binary_mask = (mask > 0.5).astype(np.uint8)


        # Place the binary mask onto the full mask
        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
    cmap = plt.get_cmap('jet')
    colored_mask = cmap(full_mask / 1.0)
    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
    if isinstance(image, Image.Image):
        image = np.array(image)
    blended_image = image.copy()
    mask_indices = full_mask > 0
    alpha = 0.5

    for c in range(3):
        blended_image[:, :, c] = np.where(mask_indices,
                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
                                          image[:, :, c])

    fig, ax = plt.subplots()
    ax.imshow(blended_image)
    plt.show()

生成輸出內容

載入模型並建立公用程式方法後,您可以使用圖片和文字資料提示模型,生成回覆。PaliGemma 模型會使用特定提示語法訓練,以執行特定工作,例如 answercaptiondetect。如要進一步瞭解 PaliGemma 提示詞工作語法,請參閱「PaliGemma 提示詞和系統指令」。

準備圖片以用於生成提示,方法是使用下列程式碼將測試圖片載入物件:

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

以特定語言回答

下列程式碼範例說明如何提示 PaliGemma 模型,提供所提供圖片中顯示的物件相關資訊。這個範例使用 answer {lang} 語法,並以其他語言顯示更多問題:

prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'

output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

使用 detect 提示

以下程式碼範例使用 detect 提示語法,在提供的圖片中找出物件。程式碼會使用先前定義的 parse_bbox_and_labels()display_boxes() 函式解讀模型輸出內容,並顯示產生的邊界方塊。

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)

使用 segment 提示

以下程式碼範例使用 segment 提示語法,找出圖片中物件所占的區域。這項工具會使用 Google big_vision 程式庫解讀模型輸出內容,並為區隔物件產生遮罩。

開始之前,請先安裝 big_vision 程式庫及其依附元件,如以下程式碼範例所示:

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")


# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

在這個區隔範例中,請載入並準備包含貓的圖片。

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

以下函式有助於剖析 PaliGemma 的區段輸出內容

import  big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))

查詢 PaliGemma,將圖片中的貓分割出來

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

以視覺化方式呈現 PaliGemma 生成的遮罩

bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)

批次提示

您可以在單一提示中提供多個提示指令,做為一組指令。以下範例說明如何建構提示文字,提供多項指令。

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)