Generare l'output di PaliGemma con Keras

I modelli PaliGemma hanno funzionalità multimodali, che ti consentono di generare output utilizzando sia i dati di input di testo che quelli di immagini. Puoi utilizzare i dati delle immagini con questi modelli per fornire un contesto aggiuntivo alle tue richieste o utilizzare il modello per analizzare i contenuti delle immagini. Questo tutorial mostra come utilizzare PaliGemma con Keras per analizzare le immagini e rispondere a domande su di esse.

Contenuto del blocco note

Questo notebook utilizza PaliGemma con Keras e mostra come:

  • Installa Keras e le dipendenze richieste
  • Scarica PaliGemmaCausalLM, una variante di PaliGemma preaddestrata per la creazione di modelli di linguaggio visivo causale, e utilizzala per creare un modello
  • Verifica la capacità del modello di dedurre informazioni sulle immagini fornite

Prima di iniziare

Prima di esaminare questo notebook, devi conoscere il codice Python e le modalità di addestramento dei modelli linguistici di grandi dimensioni (LLM). Non è necessario conoscere Keras, ma una conoscenza di base di Keras è utile per leggere il codice di esempio.

Configurazione

Le sezioni seguenti spiegano i passaggi preliminari per consentire a un notebook di utilizzare un modello PaliGemma, tra cui l'accesso al modello, l'ottenimento di una chiave API e la configurazione del runtime del notebook.

Accedere a PaliGemma

Prima di utilizzare PaliGemma per la prima volta, devi richiedere l'accesso al modello tramite Kaggle completando i seguenti passaggi:

  1. Accedi a Kaggle o crea un nuovo account Kaggle se non ne hai già uno.
  2. Vai alla scheda del modello PaliGemma e fai clic su Richiedi accesso.
  3. Compila il modulo per il consenso e accetta i Termini e condizioni.

Configura la chiave API

Per utilizzare PaliGemma, devi fornire il tuo nome utente Kaggle e una chiave API Kaggle.

Per generare una chiave API Kaggle, apri la pagina Impostazioni in Kaggle e fai clic su Crea nuovo token. Viene attivato il download di un file kaggle.json contenente le tue credenziali API.

Poi, in Colab, seleziona Secrets (🔑) nel riquadro a sinistra e aggiungi il tuo nome utente e la tua chiave API Kaggle. Memorizza il tuo nome utente con il nome KAGGLE_USERNAME e la tua chiave API con il nome KAGGLE_KEY.

Seleziona il runtime

Per completare questo tutorial, devi disporre di un runtime Colab con risorse sufficienti per eseguire il modello PaliGemma. In questo caso, puoi utilizzare una GPU T4:

  1. In alto a destra nella finestra di Colab, fai clic sul menu a discesa ▾ (Opzioni di connessione aggiuntive).
  2. Seleziona Cambia tipo di runtime.
  3. In Acceleratore hardware, seleziona GPU T4.

Imposta le variabili di ambiente

Imposta le variabili di ambiente per KAGGLE_USERNAME, KAGGLE_KEY e KERAS_BACKEND.

import os
from google.colab import userdata

# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"

Installa Keras

Esegui la cella di seguito per installare Keras.

pip install -U -q keras-nlp keras-hub kagglehub

Importa le dipendenze e configura Keras

Installa le dipendenze necessarie per questo notebook e configura il backend di Keras. Imposterai inoltre Keras in modo che utilizzi bfloat16 in modo che il framework utilizzi meno memoria.

import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")

Carica il modello

Ora che hai configurato tutto, puoi scaricare il modello preaddestrato e creare alcuni metodi di utilità per aiutare il modello a generare le sue risposte. In questo passaggio, scarichi un modello utilizzando PaliGemmaCausalLM da Keras Hub. Questo corso ti aiuta a gestire ed eseguire la struttura del modello linguistico visivo causale di PaliGemma. Un modello linguistico visivo causale prevede il token successivo in base ai token precedenti. Keras Hub fornisce implementazioni di molte architetture di modelli popolari.

Crea il modello utilizzando il metodo from_preset e stampa il relativo riepilogo. Il completamento di questa procedura richiede circa un minuto.

paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()

Crea metodi di utilità

Per aiutarti a generare risposte dal modello, crea due metodi di utilità:

  • crop_and_resize: metodo di supporto per read_img. Questo metodo ritaglia e ridimensiona l'immagine in base alle dimensioni passate in modo che l'immagine finale venga ridimensionata senza alterare le proporzioni.
  • read_img: metodo di supporto per read_img_from_url. Questo metodo è quello che apre effettivamente l'immagine, la ridimensiona in modo che rientri nei vincoli del modello e la inserisce in un array che può essere interpretato dal modello.
  • read_img_from_url: acquisisce un'immagine tramite un URL valido. Questo metodo è necessario per passare l'immagine al modello.

Utilizzerai read_img_from_url nel passaggio successivo di questo notebook.

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, l = target_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      y, x, y2, x2 = (boxes[i]*h)
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      plt.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, bounding_box, segment_mask, target_image_size):
    # Initialize a full mask with the target size
    full_mask = np.zeros(target_image_size, dtype=np.uint8)
    target_width, target_height = target_image_size

    for bbox, mask in zip(bounding_box, segment_mask):
        y1, x1, y2, x2 = bbox
        x1 = int(x1 * target_width)
        y1 = int(y1 * target_height)
        x2 = int(x2 * target_width)
        y2 = int(y2 * target_height)

        # Ensure mask is 2D before converting to Image
        if mask.ndim == 3:
            mask = mask.squeeze(axis=-1)
        mask = Image.fromarray(mask)
        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
        mask = np.array(mask)
        binary_mask = (mask > 0.5).astype(np.uint8)


        # Place the binary mask onto the full mask
        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
    cmap = plt.get_cmap('jet')
    colored_mask = cmap(full_mask / 1.0)
    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
    if isinstance(image, Image.Image):
        image = np.array(image)
    blended_image = image.copy()
    mask_indices = full_mask > 0
    alpha = 0.5

    for c in range(3):
        blended_image[:, :, c] = np.where(mask_indices,
                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
                                          image[:, :, c])

    fig, ax = plt.subplots()
    ax.imshow(blended_image)
    plt.show()

Genera output

Dopo aver caricato il modello e creato i metodi di utilità, puoi chiedere al modello di generare risposte con dati di immagini e testo. I modelli PaliGemma vengono addestrati con una sintassi del prompt specifica per attività specifiche, come answer, caption e detect. Per ulteriori informazioni sulla sintassi del prompt di PaliGemma, consulta le istruzioni del prompt e del sistema di PaliGemma.

Prepara un'immagine per l'utilizzo in un prompt di generazione utilizzando il seguente codice per caricare un'immagine di prova in un oggetto:

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

Rispondere in una lingua specifica

Il seguente codice di esempio mostra come richiedere al modello PaliGemma informazioni su un oggetto visualizzato in un'immagine fornita. Questo esempio utilizza la sintassi answer {lang} e mostra altre domande in altre lingue:

prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'

output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

Usare il prompt detect

Il seguente codice di esempio utilizza la sintassi del prompt detect per individuare un oggetto nell'immagine fornita. Il codice utilizza le funzioni parse_bbox_and_labels() e display_boxes() definite in precedenza per interpretare l'output del modello e visualizzare le caselle delimitanti generate.

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)

Usare il prompt segment

Il seguente codice di esempio utilizza la sintassi del prompt segment per individuare l'area di un'immagine occupata da un oggetto. Utilizza la libreria big_vision di Google per interpretare l'output del modello e generare una maschera per l'oggetto segmentato.

Prima di iniziare, installa la libreria big_vision e le relative dipendenze, come mostrato in questo esempio di codice:

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")


# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

Per questo esempio di segmentazione, carica e prepara un'altra immagine che includa un gatto.

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

Ecco una funzione per aiutarti ad analizzare l'output del segmento di PaliGemma

import  big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))

Esegui una query su PaliGemma per segmentare il gatto nell'immagine

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

Visualizza la maschera generata da PaliGemma

bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)

Prompt batch

Puoi fornire più di un comando prompt all'interno di un singolo prompt come un batch di istruzioni. Il seguente esempio mostra come strutturare il testo del prompt per fornire più istruzioni.

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)