使用 Keras 產生 PaliGemma 輸出內容

PaliGemma 模型具備多模態功能,可讓您使用文字和圖片輸入資料產生輸出內容。您可以搭配這些模型使用圖片資料,為要求提供其他背景資訊,或是使用模型分析圖片內容。本教學課程將說明如何使用 PaliGemma 搭配 Keras,分析圖片並回答相關問題。

本筆記本內容

本筆記本使用 PaliGemma 搭配 Keras,並說明如何:

  • 安裝 Keras 和必要的依附元件
  • 下載 PaliGemmaCausalLM,這是預先訓練的 PaliGemma 變體,可用於因果視覺語言建模,並使用該模型建立模型
  • 測試模型推斷所提供圖片資訊的能力

事前準備

在閱讀本筆記本之前,請先熟悉 Python 程式碼,以及大型語言模型 (LLM) 的訓練方式。您不需要熟悉 Keras,但在閱讀範例程式碼時,具備 Keras 的基本知識會很有幫助。

設定

以下各節將說明取得筆記本以使用 PaliGemma 模型的初步步驟,包括模型存取權、取得 API 金鑰,以及設定筆記本執行階段。

取得 PaliGemma 存取權

首次使用 PaliGemma 前,您必須完成下列步驟,透過 Kaggle 申請模型存取權:

  1. 登入 Kaggle,或建立新的 Kaggle 帳戶 (如果尚未建立)。
  2. 前往 PaliGemma 模型資訊卡,然後點選「要求存取權」
  3. 填妥同意聲明表單,並接受條款及細則。

設定 API 金鑰

如要使用 PaliGemma,您必須提供 Kaggle 使用者名稱和 Kaggle API 金鑰。

如要產生 Kaggle API 金鑰,請在 Kaggle 中開啟「設定頁面,然後按一下「建立新的權杖」。這會觸發下載含有 API 憑證的 kaggle.json 檔案。

接著,在 Colab 中選取左側窗格中的「Secrets」(「Secrets」圖示 🔑),然後新增 Kaggle 使用者名稱和 Kaggle API 金鑰。將使用者名稱儲存為 KAGGLE_USERNAME,將 API 金鑰儲存為 KAGGLE_KEY

選取執行階段

如要完成本教學課程,您需要具備 Colab 執行階段,並有足夠的資源來執行 PaliGemma 模型。在這種情況下,您可以使用 T4 GPU:

  1. 在 Colab 視窗的右上方,按一下「▾」 (其他連線選項) 下拉式選單。
  2. 選取「變更執行階段類型」
  3. 在「硬體加速器」下方,選取「T4 GPU」

設定環境變數

設定 KAGGLE_USERNAMEKAGGLE_KEYKERAS_BACKEND 的環境變數。

import os
from google.colab import userdata

# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"

安裝 Keras

執行下方儲存格,安裝 Keras。

pip install -U -q keras-nlp keras-hub kagglehub

匯入依附元件並設定 Keras

安裝這個筆記本所需的依附元件,並設定 Keras 的後端。您也將設定 Keras 使用 bfloat16,讓架構使用較少記憶體。

import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")

載入模型

設定完所有項目後,您可以下載預先訓練的模型,並建立一些公用程式方法,協助模型產生回覆。在這個步驟中,您會使用 PaliGemmaCausalLM 從 Keras Hub 下載模型。這個類別可協助您管理及執行 PaliGemma 的因果視覺語言模型結構。因果視覺語言模型會根據先前的符記預測下一個符記。Keras Hub 提供許多熱門模型架構的導入方法。

使用 from_preset 方法建立模型,並列印其摘要。這項程序約需 1 分鐘才能完成。

paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()

建立公用程式方法

為了協助您從模型產生回應,請建立兩種公用方法:

  • crop_and_resizeread_img 的輔助方法。這個方法會裁剪圖片並將圖片調整為傳入的大小,以便調整最終圖片大小,同時不扭曲圖片比例。
  • read_imgread_img_from_url 的輔助方法。這個方法會實際開啟圖片、調整圖片大小,以便符合模型的限制條件,並將圖片放入可供模型解讀的陣列。
  • read_img_from_url透過有效網址擷取圖片。您需要使用這個方法將圖片傳送至模型。

您將在本筆記的下一個步驟中使用 read_img_from_url

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, l = target_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      y, x, y2, x2 = (boxes[i]*h)
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      plt.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, bounding_box, segment_mask, target_image_size):
    # Initialize a full mask with the target size
    full_mask = np.zeros(target_image_size, dtype=np.uint8)
    target_width, target_height = target_image_size

    for bbox, mask in zip(bounding_box, segment_mask):
        y1, x1, y2, x2 = bbox
        x1 = int(x1 * target_width)
        y1 = int(y1 * target_height)
        x2 = int(x2 * target_width)
        y2 = int(y2 * target_height)

        # Ensure mask is 2D before converting to Image
        if mask.ndim == 3:
            mask = mask.squeeze(axis=-1)
        mask = Image.fromarray(mask)
        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
        mask = np.array(mask)
        binary_mask = (mask > 0.5).astype(np.uint8)


        # Place the binary mask onto the full mask
        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
    cmap = plt.get_cmap('jet')
    colored_mask = cmap(full_mask / 1.0)
    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
    if isinstance(image, Image.Image):
        image = np.array(image)
    blended_image = image.copy()
    mask_indices = full_mask > 0
    alpha = 0.5

    for c in range(3):
        blended_image[:, :, c] = np.where(mask_indices,
                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
                                          image[:, :, c])

    fig, ax = plt.subplots()
    ax.imshow(blended_image)
    plt.show()

產生輸出內容

載入模型並建立公用程式方法後,您可以使用圖片和文字資料提示模型,產生回應。PaliGemma 模型會針對特定工作 (例如 answercaptiondetect) 使用特定提示語法進行訓練。如要進一步瞭解 PaliGemma 提示工作語法,請參閱 PaliGemma 提示和系統指示

使用下列程式碼將測試圖片載入至物件,為產生提示準備圖片:

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

以特定語言回覆

以下程式碼範例說明如何向 PaliGemma 模型提示,要求提供圖片中物件的相關資訊。這個範例使用 answer {lang} 語法,並顯示其他語言的額外問題:

prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'

output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

使用 detect 提示

以下程式碼範例使用 detect 提示語法,在提供的圖片中找出物件。程式碼會使用先前定義的 parse_bbox_and_labels()display_boxes() 函式,解讀模型輸出內容並顯示產生的邊界框。

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)

使用 segment 提示

以下程式碼範例使用 segment 提示語法,找出圖片中物件所佔據的區域。它會使用 Google big_vision 程式庫解讀模型輸出內容,並為已分割的物件產生遮罩。

開始之前,請先安裝 big_vision 程式庫及其依附元件,如以下程式碼範例所示:

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")


# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

在這個區隔範例中,請載入及準備另一張含有貓咪的圖片。

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

以下是可協助剖析 PaliGemma 區段輸出的函式

import  big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))

查詢 PaliGemma,將圖片中的貓咪區隔開來

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

以視覺化方式呈現 PaliGemma 產生的遮罩

bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)

批次提示

您可以在單一提示中提供多個提示指令,做為一組指示。以下範例說明如何安排提示文字,提供多項操作說明。

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)