สร้างเอาต์พุต PaliGemma ด้วย Keras

โมเดล PaliGemma มีความสามารถหลายรูปแบบ ซึ่งช่วยให้คุณสร้างเอาต์พุตได้โดยใช้ทั้งข้อมูลอินพุตที่เป็นข้อความและรูปภาพ คุณสามารถใช้ข้อมูลรูปภาพกับโมเดลเหล่านี้เพื่อระบุบริบทเพิ่มเติมสำหรับคำขอ หรือใช้โมเดลเพื่อวิเคราะห์เนื้อหาของรูปภาพได้ บทแนะนำนี้จะแสดงวิธีใช้ PaliGemma กับ Keras เพื่อวิเคราะห์รูปภาพและตอบคำถามเกี่ยวกับรูปภาพ

มีอะไรอยู่ในสมุดโน้ตนี้

โน้ตบุ๊กนี้ใช้ PaliGemma กับ Keras และแสดงวิธีต่อไปนี้

  • ติดตั้ง Keras และทรัพยากร Dependency ที่จําเป็น
  • ดาวน์โหลด PaliGemmaCausalLM ซึ่งเป็นตัวแปร PaliGemma ที่ผ่านการฝึกอบรมไว้ล่วงหน้าสําหรับการประมาณภาษาภาพที่เกิดจากสาเหตุ และนําไปใช้สร้างโมเดล
  • ทดสอบความสามารถของโมเดลในการอนุมานข้อมูลเกี่ยวกับรูปภาพที่ให้ไว้

ก่อนเริ่มต้น

ก่อนอ่านโน้ตบุ๊กนี้ คุณควรคุ้นเคยกับโค้ด Python รวมถึงวิธีฝึกโมเดลภาษาขนาดใหญ่ (LLM) คุณไม่จำเป็นต้องคุ้นเคยกับ Keras แต่ความรู้พื้นฐานเกี่ยวกับ Keras จะมีประโยชน์เมื่ออ่านโค้ดตัวอย่าง

ตั้งค่า

ส่วนต่อไปนี้อธิบายขั้นตอนเบื้องต้นในการทำให้โน้ตบุ๊กใช้โมเดล PaliGemma ซึ่งรวมถึงการเข้าถึงโมเดล การรับคีย์ API และการกําหนดค่ารันไทม์ของโน้ตบุ๊ก

รับสิทธิ์เข้าถึง PaliGemma

ก่อนใช้ PaliGemma เป็นครั้งแรก คุณต้องขอสิทธิ์เข้าถึงโมเดลผ่าน Kaggle โดยทําตามขั้นตอนต่อไปนี้

  1. เข้าสู่ระบบ Kaggle หรือสร้างบัญชี Kaggle ใหม่หากยังไม่มี
  2. ไปที่การ์ดโมเดล PaliGemma แล้วคลิกขอสิทธิ์เข้าถึง
  3. กรอกแบบฟอร์มความยินยอมและยอมรับข้อกำหนดและเงื่อนไข

กำหนดค่าคีย์ API

หากต้องการใช้ PaliGemma คุณต้องระบุชื่อผู้ใช้ Kaggle และคีย์ API ของ Kaggle

หากต้องการสร้างคีย์ API ของ Kaggle ให้เปิดหน้าการตั้งค่าใน Kaggle แล้วคลิกสร้างโทเค็นใหม่ ซึ่งจะทริกเกอร์การดาวน์โหลดไฟล์ kaggle.json ที่มีข้อมูลเข้าสู่ระบบ API

จากนั้นใน Colab ให้เลือกข้อมูลลับ (🔑) ในแผงด้านซ้าย แล้วเพิ่มชื่อผู้ใช้ Kaggle และคีย์ Kaggle API จัดเก็บชื่อผู้ใช้โดยใช้ชื่อ KAGGLE_USERNAME และคีย์ API โดยใช้ชื่อ KAGGLE_KEY

เลือกรันไทม์

คุณจะต้องมีรันไทม์ Colab ที่มีทรัพยากรเพียงพอที่จะเรียกใช้โมเดล PaliGemma จึงจะทำตามบทแนะนำนี้ให้เสร็จสมบูรณ์ได้ ในกรณีนี้ คุณสามารถใช้ GPU T4 ได้

  1. ที่ด้านขวาบนของหน้าต่าง Colab ให้คลิกเมนูแบบเลื่อนลง ▾ (ตัวเลือกการเชื่อมต่อเพิ่มเติม)
  2. เลือกเปลี่ยนประเภทรันไทม์
  3. เลือก GPU T4 ในส่วนตัวเร่งฮาร์ดแวร์

ตั้งค่าตัวแปรสภาพแวดล้อม

ตั้งค่าตัวแปรสภาพแวดล้อมสําหรับ KAGGLE_USERNAME, KAGGLE_KEY และ KERAS_BACKEND

import os
from google.colab import userdata

# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"

ติดตั้ง Keras

เรียกใช้เซลล์ด้านล่างเพื่อติดตั้ง Keras

pip install -U -q keras-nlp keras-hub kagglehub

นําเข้า Dependency และกำหนดค่า Keras

ติดตั้ง Dependency ที่จําเป็นสําหรับโน้ตบุ๊กนี้และกําหนดค่าแบ็กเอนด์ของ Keras นอกจากนี้ คุณยังตั้งค่าให้ Keras ใช้ bfloat16 เพื่อให้เฟรมเวิร์กใช้หน่วยความจำน้อยลงด้วย

import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")

โหลดโมเดล

เมื่อตั้งค่าทุกอย่างแล้ว คุณสามารถดาวน์โหลดโมเดลที่ผ่านการฝึกอบรมล่วงหน้าและสร้างเมธอดยูทิลิตีบางอย่างเพื่อช่วยโมเดลในการสร้างคำตอบ ในขั้นตอนนี้ คุณจะดาวน์โหลดโมเดลโดยใช้ PaliGemmaCausalLM จาก Keras Hub คลาสนี้ช่วยคุณจัดการและเรียกใช้โครงสร้างโมเดลภาษาภาพเชิงสาเหตุของ PaliGemma โมเดลภาษาภาพที่เกิดจากสาเหตุจะคาดการณ์โทเค็นถัดไปโดยอิงตามโทเค็นก่อนหน้า Keras Hub มีการใช้งานสถาปัตยกรรมโมเดลยอดนิยมหลายรายการ

สร้างโมเดลโดยใช้เมธอด from_preset และพิมพ์ข้อมูลสรุป การดำเนินการนี้จะใช้เวลาประมาณ 1 นาที

paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()

สร้างเมธอดยูทิลิตี

สร้างเมธอดยูทิลิตี 2 รายการเพื่อช่วยสร้างคำตอบจากโมเดล

  • crop_and_resize: เมธอดตัวช่วยสําหรับ read_img วิธีนี้จะครอบตัดและปรับขนาดรูปภาพเป็นขนาดที่ส่งมาเพื่อให้รูปภาพสุดท้ายมีการปรับขนาดโดยไม่บิดเบือนสัดส่วนของรูปภาพ
  • read_img: เมธอดตัวช่วยสําหรับ read_img_from_url วิธีการนี้จะเปิดรูปภาพ ปรับขนาดรูปภาพให้พอดีกับข้อจำกัดของโมเดล และใส่รูปภาพลงในอาร์เรย์ที่โมเดลตีความได้
  • read_img_from_url: รับรูปภาพผ่าน URL ที่ถูกต้อง คุณต้องใช้เมธอดนี้เพื่อส่งรูปภาพไปยังโมเดล

คุณจะใช้ read_img_from_url ในขั้นตอนถัดไปของสมุดบันทึกนี้

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, l = target_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      y, x, y2, x2 = (boxes[i]*h)
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      plt.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, bounding_box, segment_mask, target_image_size):
    # Initialize a full mask with the target size
    full_mask = np.zeros(target_image_size, dtype=np.uint8)
    target_width, target_height = target_image_size

    for bbox, mask in zip(bounding_box, segment_mask):
        y1, x1, y2, x2 = bbox
        x1 = int(x1 * target_width)
        y1 = int(y1 * target_height)
        x2 = int(x2 * target_width)
        y2 = int(y2 * target_height)

        # Ensure mask is 2D before converting to Image
        if mask.ndim == 3:
            mask = mask.squeeze(axis=-1)
        mask = Image.fromarray(mask)
        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
        mask = np.array(mask)
        binary_mask = (mask > 0.5).astype(np.uint8)


        # Place the binary mask onto the full mask
        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
    cmap = plt.get_cmap('jet')
    colored_mask = cmap(full_mask / 1.0)
    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
    if isinstance(image, Image.Image):
        image = np.array(image)
    blended_image = image.copy()
    mask_indices = full_mask > 0
    alpha = 0.5

    for c in range(3):
        blended_image[:, :, c] = np.where(mask_indices,
                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
                                          image[:, :, c])

    fig, ax = plt.subplots()
    ax.imshow(blended_image)
    plt.show()

สร้างเอาต์พุต

หลังจากโหลดโมเดลและสร้างเมธอดยูทิลิตีแล้ว คุณสามารถส่งพรอมต์รูปภาพและข้อมูลข้อความไปยังโมเดลเพื่อสร้างคำตอบได้ โมเดล PaliGemma ได้รับการฝึกด้วยไวยากรณ์พรอมต์ที่เฉพาะเจาะจงสำหรับงานบางอย่าง เช่น answer, caption และ detect ดูข้อมูลเพิ่มเติมเกี่ยวกับไวยากรณ์ของงานพรอมต์ PaliGemma ได้ที่พรอมต์ PaliGemma และวิธีการของระบบ

เตรียมรูปภาพเพื่อใช้ในพรอมต์การสร้างโดยใช้โค้ดต่อไปนี้เพื่อโหลดรูปภาพทดสอบลงในออบเจ็กต์

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

ตอบเป็นภาษาที่เฉพาะเจาะจง

โค้ดตัวอย่างต่อไปนี้แสดงวิธีแจ้งให้โมเดล PaliGemma แสดงข้อมูลเกี่ยวกับวัตถุที่ปรากฏในรูปภาพที่ระบุ ตัวอย่างนี้ใช้ไวยากรณ์ answer {lang} และแสดงคำถามเพิ่มเติมในภาษาอื่นๆ

prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'

output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

ใช้พรอมต์ detect

ตัวอย่างโค้ดต่อไปนี้ใช้ไวยากรณ์พรอมต์ detect เพื่อค้นหาวัตถุในรูปภาพที่ระบุ โค้ดใช้ฟังก์ชัน parse_bbox_and_labels() และ display_boxes() ที่กําหนดไว้ก่อนหน้านี้เพื่อตีความเอาต์พุตของโมเดลและแสดงกล่องขอบเขตที่สร้างขึ้น

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)

ใช้พรอมต์ segment

โค้ดตัวอย่างต่อไปนี้ใช้ไวยากรณ์พรอมต์ segment เพื่อค้นหาพื้นที่ของรูปภาพที่มีวัตถุ โดยจะใช้ไลบรารี big_vision ของ Google เพื่อตีความเอาต์พุตของโมเดลและสร้างมาสก์สําหรับวัตถุที่แบ่งออกเป็นส่วนๆ

ก่อนเริ่มต้น ให้ติดตั้งไลบรารี big_vision และไลบรารีที่ต้องพึ่งพา ดังที่แสดงในตัวอย่างโค้ดนี้

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")


# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

สําหรับตัวอย่างการแบ่งกลุ่มนี้ ให้โหลดและเตรียมรูปภาพอื่นที่มีแมว

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

ฟังก์ชันที่จะช่วยแยกวิเคราะห์เอาต์พุตกลุ่มจาก PaliGemma มีดังนี้

import  big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))

ค้นหา PaliGemma เพื่อแบ่งกลุ่มแมวในรูปภาพ

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

แสดงภาพมาสก์ที่สร้างขึ้นจาก PaliGemma

bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)

พรอมต์แบบเป็นกลุ่ม

คุณสามารถระบุคำสั่งพรอมต์ได้มากกว่า 1 รายการภายในพรอมต์เดียวเป็นชุดคำสั่ง ตัวอย่างต่อไปนี้แสดงวิธีจัดโครงสร้างข้อความพรอมต์เพื่อให้คําแนะนําหลายรายการ

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)