Keras を使用して PaliGemma の出力を生成する

PaliGemma モデルにはマルチモーダル機能があり、テキストと画像の両方の入力データを使用して出力を生成できます。これらのモデルで画像データを使用してリクエストに追加のコンテキストを提供したり、モデルを使用して画像のコンテンツを分析したりできます。このチュートリアルでは、Keras で PaliGemma を使用して画像を分析し、画像に関する質問に回答する方法について説明します。

このノートブックの内容

このノートブックでは、Keras で PaliGemma を使用し、次の方法を示します。

  • Keras と必要な依存関係をインストールする
  • 因果関係のある視覚言語モデル用に事前トレーニング済みの PaliGemma バリアントである PaliGemmaCausalLM をダウンロードし、モデルの作成に使用します。
  • 指定された画像に関する情報を推測するモデルの能力をテストする

始める前に

このノートブックを使用する前に、Python コードと大規模言語モデル(LLM)のトレーニング方法を理解しておく必要があります。Keras に精通している必要はありませんが、サンプルコードを読む際に Keras に関する基本的な知識があると役に立ちます。

セットアップ

以降のセクションでは、モデルへのアクセス、API キーの取得、ノートブック ランタイムの構成など、ノートブックで PaliGemma モデルを使用するための準備手順について説明します。

PaliGemma にアクセスする

PaliGemma を初めて使用する前に、次の手順に沿って Kaggle からモデルへのアクセスをリクエストする必要があります。

  1. Kaggle にログインします。まだアカウントを作成していない場合は、新しい Kaggle アカウントを作成します。
  2. PaliGemma モデルカードに移動し、[アクセスをリクエスト] をクリックします。
  3. 同意フォームに記入して利用規約に同意します。

API キーを設定する

PaliGemma を使用するには、Kaggle のユーザー名と Kaggle API キーを指定する必要があります。

Kaggle API キーを生成するには、Kaggle の [設定] ページを開き、[新しいトークンを作成] をクリックします。これにより、API 認証情報を含む kaggle.json ファイルのダウンロードがトリガーされます。

次に、Colab の左側のペインで [シークレット](🔑)を選択し、Kaggle ユーザー名と Kaggle API キーを追加します。ユーザー名を KAGGLE_USERNAME という名前で、API キーを KAGGLE_KEY という名前で保存します。

ランタイムの選択

このチュートリアルを完了するには、PaliGemma モデルを実行するのに十分なリソースがある Colab ランタイムが必要です。この場合は、T4 GPU を使用できます。

  1. Colab ウィンドウの右上にある [▾(その他の接続オプション)] プルダウン メニューをクリックします。
  2. [ランタイム タイプの変更] を選択します。
  3. [ハードウェア アクセラレータ] で [T4 GPU] を選択します。

環境変数を設定する

KAGGLE_USERNAMEKAGGLE_KEYKERAS_BACKEND の環境変数を設定します。

import os
from google.colab import userdata

# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"

Keras をインストールする

次のセルを実行して Keras をインストールします。

pip install -U -q keras-nlp keras-hub kagglehub

依存関係をインポートして Keras を構成する

このノートブックに必要な依存関係をインストールし、Keras のバックエンドを構成します。また、フレームワークで使用するメモリを減らすために、Keras で bfloat16 を使用するように設定します。

import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")

モデルを読み込む

これで、事前トレーニング済みモデルをダウンロードし、モデルがレスポンスを生成するのに役立つユーティリティ メソッドを作成できます。この手順では、Keras Hub から PaliGemmaCausalLM を使用してモデルをダウンロードします。このクラスは、PaliGemma の因果関係のある視覚言語モデル構造を管理して実行するのに役立ちます。因果関係のある視覚言語モデルは、以前のトークンに基づいて次のトークンを予測します。Keras Hub には、一般的な多くのモデル アーキテクチャの実装が用意されています。

from_preset メソッドを使用してモデルを作成し、その概要を出力します。この処理が完了するまでに 1 分ほどかかります。

paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()

ユーティリティ メソッドを作成する

モデルからレスポンスを生成できるように、2 つのユーティリティ メソッドを作成します。

  • crop_and_resize: read_img のヘルパー メソッド。このメソッドは、渡されたサイズに画像を切り抜いてサイズ変更し、画像のアスペクト比を歪めずに最終的な画像のサイズを変更します。
  • read_img: read_img_from_url のヘルパー メソッド。このメソッドは、実際に画像を開き、モデルの制約に収まるようにサイズを変更し、モデルが解釈できる配列に格納します。
  • read_img_from_url: 有効な URL 経由で画像を取得します。このメソッドは、画像をモデルに渡すために必要です。

このノートブックの次のステップで read_img_from_url を使用します。

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, l = target_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      y, x, y2, x2 = (boxes[i]*h)
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      plt.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, bounding_box, segment_mask, target_image_size):
    # Initialize a full mask with the target size
    full_mask = np.zeros(target_image_size, dtype=np.uint8)
    target_width, target_height = target_image_size

    for bbox, mask in zip(bounding_box, segment_mask):
        y1, x1, y2, x2 = bbox
        x1 = int(x1 * target_width)
        y1 = int(y1 * target_height)
        x2 = int(x2 * target_width)
        y2 = int(y2 * target_height)

        # Ensure mask is 2D before converting to Image
        if mask.ndim == 3:
            mask = mask.squeeze(axis=-1)
        mask = Image.fromarray(mask)
        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
        mask = np.array(mask)
        binary_mask = (mask > 0.5).astype(np.uint8)


        # Place the binary mask onto the full mask
        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
    cmap = plt.get_cmap('jet')
    colored_mask = cmap(full_mask / 1.0)
    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
    if isinstance(image, Image.Image):
        image = np.array(image)
    blended_image = image.copy()
    mask_indices = full_mask > 0
    alpha = 0.5

    for c in range(3):
        blended_image[:, :, c] = np.where(mask_indices,
                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
                                          image[:, :, c])

    fig, ax = plt.subplots()
    ax.imshow(blended_image)
    plt.show()

出力を生成する

モデルを読み込んでユーティリティ メソッドを作成したら、画像データとテキストデータを使用してモデルにプロンプトを送信し、レスポンスを生成できます。PaliGemma モデルは、特定のタスク(answercaptiondetect など)の特定のプロンプト構文でトレーニングされます。PaliGemma プロンプト タスクの構文の詳細については、PaliGemma プロンプトとシステム指示をご覧ください。

次のコードを使用してテスト画像をオブジェクトに読み込み、生成プロンプトで使用できるように画像を準備します。

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

特定の言語で回答する

次のサンプルコードは、指定された画像に表示されるオブジェクトに関する情報を PaliGemma モデルにプロンプトする方法を示しています。この例では、answer {lang} 構文を使用して、他の言語で追加の質問を表示します。

prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'

output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

detect プロンプトを使用する

次のサンプルコードでは、detect プロンプト構文を使用して、指定された画像内のオブジェクトを特定します。このコードは、前に定義した parse_bbox_and_labels() 関数と display_boxes() 関数を使用して、モデルの出力を解釈し、生成されたバウンディング ボックスを表示します。

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)

segment プロンプトを使用する

次のコードサンプルでは、segment プロンプト構文を使用して、オブジェクトが占有する画像領域を特定します。Google big_vision ライブラリを使用してモデル出力を解釈し、セグメント化されたオブジェクトのマスクを生成します。

始める前に、次のコードサンプルに示すように、big_vision ライブラリとその依存関係をインストールします。

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")


# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

このセグメンテーションの例では、猫を含む別の画像を読み込んで準備します。

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

以下は、PaliGemma のセグメント出力を解析する関数です。

import  big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))

PaliGemma にクエリを実行して、画像内の猫をセグメント化します

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

PaliGemma から生成されたマスクを可視化する

bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)

バッチ プロンプト

1 つのプロンプト内に複数のプロンプト コマンドを指示のバッチとして指定できます。次の例は、複数の手順を提供するプロンプト テキストの構造を示しています。

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)