יצירת פלט של PaliGemma באמצעות Keras

לצפייה בכתובת ai.google.dev הרצה ב-Google Colab פתיחה ב-Vertex AI צפייה במקור ב-GitHub

למודלים של PaliGemma יש יכולות מולטי-מודאליות, שמאפשרות ליצור פלט באמצעות נתוני קלט של טקסט ותמונות. אתם יכולים להשתמש בנתוני תמונות עם המודלים האלה כדי לספק הקשר נוסף לבקשות שלכם, או להשתמש במודל כדי לנתח את התוכן של תמונות. במדריך הזה נסביר איך להשתמש ב-PaliGemma עם Keras כדי לנתח תמונות ולענות על שאלות לגביהן.

מה כולל ה-notebook הזה

במחברת הזו נעשה שימוש ב-PaliGemma עם Keras, והיא מראה לכם איך:

  • התקנת Keras ויחסי התלות הנדרשים
  • מורידים את PaliGemmaCausalLM, וריאציה של PaliGemma שעברה אימון מראש ליצירת מודלים של שפה חזותית סיבתית, ומשתמשים בה כדי ליצור מודל
  • בדיקת היכולת של המודל להסיק מידע על תמונות שסופקו

לפני שמתחילים

לפני שמתחילים לעבוד עם המחברת הזו, כדאי להכיר את קוד Python ואת תהליך האימון של מודלים גדולים של שפה (LLM). לא צריך להכיר את Keras, אבל ידע בסיסי ב-Keras יעזור לכם לקרוא את קוד הדוגמה.

הגדרה

בקטעים הבאים מוסבר על השלבים המקדימים להפעלת מחברת לשימוש במודל PaliGemma, כולל גישה למודל, קבלת מפתח API והגדרת זמן הריצה של המחברת.

קבלת גישה ל-PaliGemma

לפני שמשתמשים ב-PaliGemma בפעם הראשונה, צריך לבקש גישה למודל דרך Kaggle. לשם כך, מבצעים את השלבים הבאים:

  1. נכנסים ל-Kaggle או יוצרים חשבון חדש ב-Kaggle אם עדיין אין לכם חשבון.
  2. עוברים אל כרטיס הדגם של PaliGemma ולוחצים על בקשת גישה.
  3. ממלאים את טופס ההסכמה ומאשרים את התנאים וההגבלות.

הגדרת מפתח API

כדי להשתמש ב-PaliGemma, צריך לספק את שם המשתמש שלכם ב-Kaggle ומפתח API של Kaggle.

כדי ליצור מפתח Kaggle API, פותחים את דף ההגדרות ב-Kaggle ולוחצים על Create New Token (יצירת אסימון חדש). הפעולה הזו תפעיל הורדה של קובץ kaggle.json שמכיל את פרטי הכניסה שלכם ל-API.

לאחר מכן, ב-Colab, בוחרים באפשרות Secrets (סודות) (🔑) בחלונית הימנית ומוסיפים את שם המשתמש שלכם ב-Kaggle ואת מפתח ה-API של Kaggle. מאחסנים את שם המשתמש בשם KAGGLE_USERNAME ואת מפתח ה-API בשם KAGGLE_KEY.

בחירת זמן הריצה

כדי להשלים את המדריך הזה, צריך סביבת ריצה של Colab עם מספיק משאבים להרצת מודל PaliGemma. במקרה כזה, אפשר להשתמש ב-GPU מסוג T4:

  1. בפינה השמאלית העליונה של חלון Colab, לוחצים על התפריט הנפתח ▾ (אפשרויות נוספות לחיבור).
  2. בוחרים באפשרות שינוי הסוג של סביבת זמן הריצה.
  3. בקטע Hardware accelerator (שיפור המהירות באמצעות חומרה), בוחרים באפשרות T4 GPU.

הגדרה של משתני סביבה

מגדירים את משתני הסביבה KAGGLE_USERNAME,‏ KAGGLE_KEY ו-KERAS_BACKEND.

import os
from google.colab import userdata

# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"

התקנה של Keras

מריצים את התא שלמטה כדי להתקין את Keras.

pip install -U -q keras-nlp keras-hub kagglehub

ייבוא פריטים בקשרי תלות והגדרת Keras

מתקינים את יחסי התלות שנדרשים למחברת הזו ומגדירים את הקצה העורפי של Keras. בנוסף, מגדירים את Keras להשתמש ב-bfloat16 כדי שה-framework ישתמש בפחות זיכרון.

import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")

טעינת המודל

אחרי שמסיימים את ההגדרה, אפשר להוריד את המודל שאומן מראש וליצור כמה שיטות עזר שיעזרו למודל ליצור את התשובות שלו. בשלב הזה, מורידים מודל באמצעות PaliGemmaCausalLM מ-Keras Hub. המחלקות האלה עוזרות לכם לנהל ולהפעיל את המבנה של מודל השפה החזותי הסיבתי של PaliGemma. מודל שפה חזותי סיבתי חוזה את האסימון הבא על סמך אסימונים קודמים. ב-Keras Hub יש יישומים של הרבה ארכיטקטורות מודלים פופולריות.

יוצרים את המודל באמצעות השיטה from_preset ומדפיסים את הסיכום שלו. התהליך יימשך כדקה.

paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()

יצירת שיטות עזר

כדי לעזור לכם ליצור תשובות מהמודל, אתם יכולים ליצור שתי שיטות עזר:

  • crop_and_resize: שיטת עזר ל-read_img. בשיטה הזו התמונה נחתכת ומשנה את הגודל שלה בהתאם לגודל שמועבר, כך שהגודל של התמונה הסופית משתנה בלי שהפרופורציות שלה ישתנו.
  • read_img: שיטת עזר ל-read_img_from_url. השיטה הזו היא מה שפותח את התמונה, משנה את הגודל שלה כך שתתאים למגבלות של המודל ומכניס אותה למערך שאפשר לפרש באמצעות המודל.
  • read_img_from_url: מקבלת תמונה באמצעות כתובת URL תקינה. צריך להשתמש בשיטה הזו כדי להעביר את התמונה למודל.

תשתמשו ב-read_img_from_url בשלב הבא של ה-notebook הזה.

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, l = target_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      y, x, y2, x2 = (boxes[i]*h)
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      plt.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, bounding_box, segment_mask, target_image_size):
    # Initialize a full mask with the target size
    full_mask = np.zeros(target_image_size, dtype=np.uint8)
    target_width, target_height = target_image_size

    for bbox, mask in zip(bounding_box, segment_mask):
        y1, x1, y2, x2 = bbox
        x1 = int(x1 * target_width)
        y1 = int(y1 * target_height)
        x2 = int(x2 * target_width)
        y2 = int(y2 * target_height)

        # Ensure mask is 2D before converting to Image
        if mask.ndim == 3:
            mask = mask.squeeze(axis=-1)
        mask = Image.fromarray(mask)
        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
        mask = np.array(mask)
        binary_mask = (mask > 0.5).astype(np.uint8)


        # Place the binary mask onto the full mask
        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
    cmap = plt.get_cmap('jet')
    colored_mask = cmap(full_mask / 1.0)
    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
    if isinstance(image, Image.Image):
        image = np.array(image)
    blended_image = image.copy()
    mask_indices = full_mask > 0
    alpha = 0.5

    for c in range(3):
        blended_image[:, :, c] = np.where(mask_indices,
                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
                                          image[:, :, c])

    fig, ax = plt.subplots()
    ax.imshow(blended_image)
    plt.show()

יצירת פלט

אחרי טעינת המודל ויצירת שיטות עזר, אפשר להזין למודל הנחיות עם נתוני תמונה וטקסט כדי ליצור תשובות. מודלים של PaliGemma מאומנים עם תחביר ספציפי של הנחיות למשימות ספציפיות, כמו answer, caption ו-detect. מידע נוסף על תחביר של משימות בהנחיות ל-PaliGemma זמין במאמר הנחיות ומערכת הוראות ל-PaliGemma.

כדי להכין תמונה לשימוש בהנחיה ליצירה, משתמשים בקוד הבא כדי לטעון תמונה לבדיקה לאובייקט:

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

תשובה בשפה מסוימת

בדוגמת הקוד הבאה אפשר לראות איך מבקשים ממודל PaliGemma מידע על אובייקט שמופיע בתמונה שסופקה. בדוגמה הזו נעשה שימוש בתחביר answer {lang} ומוצגות שאלות נוספות בשפות אחרות:

prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'

output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

שימוש בהנחיה של detect

בדוגמת הקוד הבאה נעשה שימוש בתחביר של הנחיה detect כדי לאתר אובייקט בתמונה שסופקה. הקוד משתמש בפונקציות parse_bbox_and_labels() ו-display_boxes() שהוגדרו קודם כדי לפרש את הפלט של המודל ולהציג את תיבות התוחמות שנוצרו.

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)

שימוש בהנחיה של segment

בדוגמת הקוד הבאה נעשה שימוש בתחביר של הנחיית segment כדי לאתר את האזור בתמונה שבו נמצא אובייקט. היא משתמשת בספריית Google big_vision כדי לפרש את פלט המודל וליצור מסכה לאובייקט המפולח.

לפני שמתחילים, צריך להתקין את ספריית big_vision ואת התלויות שלה, כמו שמוצג בדוגמת הקוד הזו:

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")


# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

בדוגמה הזו של פילוח, טוענים ומכינים תמונה אחרת שכוללת חתול.

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

הנה פונקציה שיכולה לעזור בניתוח הפלט של הפלח מ-PaliGemma

import  big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))

שליחת שאילתה ל-PaliGemma כדי לבצע פילוח של החתול בתמונה

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

הדמיה של המסכה שנוצרה מ-PaliGemma

bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)

הנחיות באצווה

אפשר לספק יותר מפקודה אחת להנחיה בהנחיה אחת, כקבוצה של הוראות. בדוגמה הבאה אפשר לראות איך כותבים הנחיה עם כמה הוראות.

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)