Keras로 PaliGemma 출력 생성

ai.google.dev에서 보기 Google Colab에서 실행 Vertex AI에서 열기 GitHub에서 소스 보기

PaliGemma 모델에는 멀티모달 기능이 있어 텍스트와 이미지 입력 데이터를 모두 사용하여 출력을 생성할 수 있습니다. 이러한 모델과 함께 이미지 데이터를 사용하여 요청에 추가 컨텍스트를 제공하거나 모델을 사용하여 이미지의 콘텐츠를 분석할 수 있습니다. 이 튜토리얼에서는 Keras와 함께 PaliGemma를 사용하여 이미지를 분석하고 이미지에 관한 질문에 답변하는 방법을 보여줍니다.

이 노트북의 내용

이 노트북은 Keras와 함께 PaliGemma를 사용하며 다음 방법을 보여줍니다.

  • Keras 및 필수 종속 항목 설치
  • 사전 학습된 인과적 시각 언어 모델링용 PaliGemma 변형인 PaliGemmaCausalLM를 다운로드하고 이를 사용하여 모델을 만듭니다.
  • 제공된 이미지에 관한 정보를 추론하는 모델의 기능 테스트

시작하기 전에

이 노트북을 살펴보기 전에 Python 코드와 대규모 언어 모델 (LLM)의 학습 방법을 숙지해야 합니다. Keras에 익숙하지 않아도 되지만 예시 코드를 읽을 때 Keras에 관한 기본 지식이 있으면 도움이 됩니다.

설정

다음 섹션에서는 모델 액세스, API 키 가져오기, 노트북 런타임 구성 등 PaliGemma 모델을 사용하도록 노트북을 설정하기 위한 사전 단계를 설명합니다.

PaliGemma 액세스 권한 획득

PaliGemma를 처음 사용하기 전에 다음 단계를 완료하여 Kaggle을 통해 모델에 대한 액세스 권한을 요청해야 합니다.

  1. Kaggle에 로그인하거나 Kaggle 계정이 없는 경우 새 계정을 만듭니다.
  2. PaliGemma 모델 카드로 이동하여 액세스 요청을 클릭합니다.
  3. 동의 양식을 작성하고 이용약관에 동의합니다.

API 키 구성

PaliGemma를 사용하려면 Kaggle 사용자 이름과 Kaggle API 키를 제공해야 합니다.

Kaggle API 키를 생성하려면 Kaggle에서 설정 페이지를 열고 새 토큰 만들기를 클릭합니다. 이렇게 하면 API 사용자 인증 정보가 포함된 kaggle.json 파일의 다운로드가 트리거됩니다.

그런 다음 Colab의 왼쪽 창에서 보안 비밀 (🔑)을 선택하고 Kaggle 사용자 이름과 Kaggle API 키를 추가합니다. 사용자 이름을 KAGGLE_USERNAME 이름으로, API 키를 KAGGLE_KEY 이름으로 저장합니다.

런타임 선택

이 튜토리얼을 완료하려면 PaliGemma 모델을 실행할 수 있는 충분한 리소스가 있는 Colab 런타임이 필요합니다. 이 경우 T4 GPU를 사용할 수 있습니다.

  1. Colab 창의 오른쪽 상단에서 ▾ (추가 연결 옵션) 드롭다운 메뉴를 클릭합니다.
  2. 런타임 유형 변경을 선택합니다.
  3. 하드웨어 가속기에서 T4 GPU를 선택합니다.

환경 변수 설정하기

KAGGLE_USERNAME, KAGGLE_KEY, KERAS_BACKEND의 환경 변수를 설정합니다.

import os
from google.colab import userdata

# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"

Keras 설치

아래 셀을 실행하여 Keras를 설치합니다.

pip install -U -q keras-nlp keras-hub kagglehub

종속 항목 가져오기 및 Keras 구성

이 노트북에 필요한 종속 항목을 설치하고 Keras 백엔드를 구성합니다. 프레임워크에서 메모리를 적게 사용하도록 Keras가 bfloat16를 사용하도록 설정합니다.

import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")

모델 로드

이제 모든 항목을 설정했으므로 사전 학습된 모델을 다운로드하고 모델이 응답을 생성하는 데 도움이 되는 유틸리티 메서드를 만들 수 있습니다. 이 단계에서는 Keras Hub에서 PaliGemmaCausalLM를 사용하여 모델을 다운로드합니다. 이 클래스는 PaliGemma의 인과관계 시각 언어 모델 구조를 관리하고 실행하는 데 도움이 됩니다. 인과 관계 시각 언어 모델은 이전 토큰을 기반으로 다음 토큰을 예측합니다. Keras Hub는 널리 사용되는 다양한 모델 아키텍처의 구현을 제공합니다.

from_preset 메서드를 사용하여 모델을 만들고 요약을 출력합니다. 이 프로세스를 완료하는 데 약 1분이 소요됩니다.

paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()

유틸리티 메서드 만들기

모델에서 응답을 생성하는 데 도움이 되도록 다음 두 가지 유틸리티 메서드를 만듭니다.

  • crop_and_resize: read_img의 도우미 메서드입니다. 이 메서드는 전달된 크기에 맞게 이미지를 자르고 크기를 조절하므로 최종 이미지는 이미지의 비율이 기울어지지 않고 크기가 조절됩니다.
  • read_img: read_img_from_url의 도우미 메서드입니다. 이 메서드는 실제로 이미지를 열고 모델의 제약 조건에 맞게 크기를 조절하고 모델에서 해석할 수 있는 배열에 넣습니다.
  • read_img_from_url: 유효한 URL을 통해 이미지를 가져옵니다. 이 메서드는 이미지를 모델에 전달하는 데 필요합니다.

이 노트북의 다음 단계에서 read_img_from_url를 사용합니다.

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, l = target_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      y, x, y2, x2 = (boxes[i]*h)
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      plt.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, bounding_box, segment_mask, target_image_size):
    # Initialize a full mask with the target size
    full_mask = np.zeros(target_image_size, dtype=np.uint8)
    target_width, target_height = target_image_size

    for bbox, mask in zip(bounding_box, segment_mask):
        y1, x1, y2, x2 = bbox
        x1 = int(x1 * target_width)
        y1 = int(y1 * target_height)
        x2 = int(x2 * target_width)
        y2 = int(y2 * target_height)

        # Ensure mask is 2D before converting to Image
        if mask.ndim == 3:
            mask = mask.squeeze(axis=-1)
        mask = Image.fromarray(mask)
        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
        mask = np.array(mask)
        binary_mask = (mask > 0.5).astype(np.uint8)


        # Place the binary mask onto the full mask
        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
    cmap = plt.get_cmap('jet')
    colored_mask = cmap(full_mask / 1.0)
    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
    if isinstance(image, Image.Image):
        image = np.array(image)
    blended_image = image.copy()
    mask_indices = full_mask > 0
    alpha = 0.5

    for c in range(3):
        blended_image[:, :, c] = np.where(mask_indices,
                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
                                          image[:, :, c])

    fig, ax = plt.subplots()
    ax.imshow(blended_image)
    plt.show()

출력 생성

모델을 로드하고 유틸리티 메서드를 만든 후 이미지 및 텍스트 데이터로 모델에 프롬프트를 표시하여 대답을 생성할 수 있습니다. PaliGemma 모델은 answer, caption, detect와 같은 특정 작업에 대한 특정 프롬프트 문법으로 학습됩니다. PaliGemma 프롬프트 작업 구문에 대한 자세한 내용은 PaliGemma 프롬프트 및 시스템 요청 사항을 참고하세요.

다음 코드를 사용하여 테스트 이미지를 객체에 로드하여 생성 프롬프트에 사용할 이미지를 준비합니다.

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

특정 언어로 대답

다음 예시 코드에서는 제공된 이미지에 표시된 객체에 관한 정보를 PaliGemma 모델에 요청하는 방법을 보여줍니다. 이 예에서는 answer {lang} 구문을 사용하고 다른 언어로 추가 질문을 표시합니다.

prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'

output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

detect 프롬프트 사용

다음 예시 코드에서는 detect 프롬프트 구문을 사용하여 제공된 이미지에서 객체를 찾습니다. 이 코드에서는 이전에 정의된 parse_bbox_and_labels()display_boxes() 함수를 사용하여 모델 출력을 해석하고 생성된 경계 상자를 표시합니다.

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)

segment 프롬프트 사용

다음 예시 코드에서는 segment 프롬프트 구문을 사용하여 객체가 차지하는 이미지 영역을 찾습니다. Google big_vision 라이브러리를 사용하여 모델 출력을 해석하고 분할된 객체의 마스크를 생성합니다.

시작하기 전에 다음 코드 예와 같이 big_vision 라이브러리와 종속 항목을 설치합니다.

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")


# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

이 분할 예시에서는 고양이가 포함된 다른 이미지를 로드하고 준비합니다.

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

다음은 PaliGemma의 세그먼트 출력을 파싱하는 데 도움이 되는 함수입니다.

import  big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))

PaliGemma에 이미지를 보내 고양이를 분할하도록 요청

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

PaliGemma에서 생성된 마스크 시각화

bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)

일괄 프롬프트

단일 프롬프트 내에서 여러 프롬프트 명령어를 일괄적으로 제공할 수 있습니다. 다음 예에서는 여러 명령어를 제공하기 위해 프롬프트 텍스트를 구성하는 방법을 보여줍니다.

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)