Generare l'output di PaliGemma con Keras

Visualizza su ai.google.dev Esegui in Google Colab Apri in Vertex AI Visualizza il codice sorgente su GitHub

I modelli PaliGemma hanno funzionalità multimodali, che ti consentono di generare output utilizzando dati di input di testo e immagini. Puoi utilizzare i dati delle immagini con questi modelli per fornire un contesto aggiuntivo alle tue richieste o utilizzare il modello per analizzare il contenuto delle immagini. Questo tutorial mostra come utilizzare PaliGemma con Keras per analizzare le immagini e rispondere alle domande.

Contenuto di questo notebook

Questo blocco note utilizza PaliGemma con Keras e mostra come:

  • Installa Keras e le dipendenze richieste
  • Scarica PaliGemmaCausalLM, una variante preaddestrata di PaliGemma per la modellazione causale del linguaggio visivo, e utilizzala per creare un modello
  • Testare la capacità del modello di dedurre informazioni sulle immagini fornite

Prima di iniziare

Prima di esaminare questo blocco note, devi avere familiarità con il codice Python e con l'addestramento dei modelli linguistici di grandi dimensioni (LLM). Non è necessario avere familiarità con Keras, ma una conoscenza di base è utile per leggere il codice di esempio.

Configurazione

Le sezioni seguenti spiegano i passaggi preliminari per fare in modo che un notebook utilizzi un modello PaliGemma, tra cui l'accesso al modello, l'ottenimento di una chiave API e la configurazione del runtime del notebook.

Accedere a PaliGemma

Prima di utilizzare PaliGemma per la prima volta, devi richiedere l'accesso al modello tramite Kaggle completando i seguenti passaggi:

  1. Accedi a Kaggle o crea un nuovo account Kaggle se non ne hai già uno.
  2. Vai alla scheda del modello PaliGemma e fai clic su Richiedi accesso.
  3. Compila il modulo di consenso e accetta i Termini e condizioni.

Configura la chiave API

Per utilizzare PaliGemma, devi fornire il tuo nome utente Kaggle e una chiave API Kaggle.

Per generare una chiave API Kaggle, apri la pagina Impostazioni in Kaggle e fai clic su Crea nuovo token. Viene attivato il download di un file kaggle.json contenente le credenziali API.

Poi, in Colab, seleziona Secrets (🔑) nel riquadro a sinistra e aggiungi il tuo nome utente Kaggle e la chiave API Kaggle. Memorizza il tuo nome utente con il nome KAGGLE_USERNAME e la tua chiave API con il nome KAGGLE_KEY.

Seleziona il runtime

Per completare questo tutorial, devi disporre di un runtime Colab con risorse sufficienti per eseguire il modello PaliGemma. In questo caso, puoi utilizzare una GPU T4:

  1. In alto a destra nella finestra di Colab, fai clic sul menu a discesa ▾ (Opzioni di connessione aggiuntive).
  2. Seleziona Cambia tipo di runtime.
  3. In Acceleratore hardware, seleziona GPU T4.

Imposta le variabili di ambiente

Imposta le variabili di ambiente per KAGGLE_USERNAME, KAGGLE_KEY e KERAS_BACKEND.

import os
from google.colab import userdata

# Set up environmental variables
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KERAS_BACKEND"] = "jax"

Installa Keras

Esegui la cella seguente per installare Keras.

pip install -U -q keras-nlp keras-hub kagglehub

Importa le dipendenze e configura Keras

Installa le dipendenze necessarie per questo blocco note e configura il backend di Keras. Imposterai anche Keras in modo che utilizzi bfloat16, in modo che il framework utilizzi meno memoria.

import keras
import keras_hub
import numpy as np
import PIL
import requests
import io
import matplotlib
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

keras.config.set_floatx("bfloat16")

Carica il modello

Ora che hai configurato tutto, puoi scaricare il modello preaddestrato e creare alcuni metodi di utilità per aiutare il modello a generare le risposte. In questo passaggio, scarichi un modello utilizzando PaliGemmaCausalLM da Keras Hub. Questa classe ti aiuta a gestire ed eseguire la struttura del modello linguistico visivo causale di PaliGemma. Un modello linguistico visivo causale prevede il token successivo in base ai token precedenti. Keras Hub fornisce implementazioni di molte architetture di modelli popolari.

Crea il modello utilizzando il metodo from_preset e stampa il relativo riepilogo. Il completamento di questa procedura richiede circa un minuto.

paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset("kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224")
paligemma.summary()

Crea metodi di utilità

Per aiutarti a generare risposte dal modello, crea due metodi di utilità:

  • crop_and_resize: metodo helper per read_img. Questo metodo ritaglia e ridimensiona l'immagine in base alle dimensioni trasmesse, in modo che l'immagine finale venga ridimensionata senza distorcere le proporzioni.
  • read_img: metodo helper per read_img_from_url. Questo metodo apre effettivamente l'immagine, la ridimensiona in modo che rientri nei vincoli del modello e la inserisce in un array che può essere interpretato dal modello.
  • read_img_from_url: accetta un'immagine tramite un URL valido. Questo metodo è necessario per passare l'immagine al modello.

Utilizzerai read_img_from_url nel passaggio successivo di questo blocco note.

def crop_and_resize(image, target_size):
    width, height = image.size
    source_size = min(image.size)
    left = width // 2 - source_size // 2
    top = height // 2 - source_size // 2
    right, bottom = left + source_size, top + source_size
    return image.resize(target_size, box=(left, top, right, bottom))

def read_image(url, target_size):
    contents = io.BytesIO(requests.get(url).content)
    image = PIL.Image.open(contents)
    image = crop_and_resize(image, target_size)
    image = np.array(image)
    # Remove alpha channel if necessary.
    if image.shape[2] == 4:
        image = image[:, :, :3]
    return image

def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

def display_boxes(image, boxes, labels, target_image_size):
  h, l = target_size
  fig, ax = plt.subplots()
  ax.imshow(image)
  for i in range(boxes.shape[0]):
      y, x, y2, x2 = (boxes[i]*h)
      width = x2 - x
      height = y2 - y
      # Create a Rectangle patch
      rect = patches.Rectangle((x, y),
                               width,
                               height,
                               linewidth=1,
                               edgecolor='r',
                               facecolor='none')
      # Add label
      plt.text(x, y, labels[i], color='red', fontsize=12)
      # Add the patch to the Axes
      ax.add_patch(rect)

  plt.show()

def display_segment_output(image, bounding_box, segment_mask, target_image_size):
    # Initialize a full mask with the target size
    full_mask = np.zeros(target_image_size, dtype=np.uint8)
    target_width, target_height = target_image_size

    for bbox, mask in zip(bounding_box, segment_mask):
        y1, x1, y2, x2 = bbox
        x1 = int(x1 * target_width)
        y1 = int(y1 * target_height)
        x2 = int(x2 * target_width)
        y2 = int(y2 * target_height)

        # Ensure mask is 2D before converting to Image
        if mask.ndim == 3:
            mask = mask.squeeze(axis=-1)
        mask = Image.fromarray(mask)
        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
        mask = np.array(mask)
        binary_mask = (mask > 0.5).astype(np.uint8)


        # Place the binary mask onto the full mask
        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
    cmap = plt.get_cmap('jet')
    colored_mask = cmap(full_mask / 1.0)
    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
    if isinstance(image, Image.Image):
        image = np.array(image)
    blended_image = image.copy()
    mask_indices = full_mask > 0
    alpha = 0.5

    for c in range(3):
        blended_image[:, :, c] = np.where(mask_indices,
                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
                                          image[:, :, c])

    fig, ax = plt.subplots()
    ax.imshow(blended_image)
    plt.show()

Genera output

Dopo aver caricato il modello e creato i metodi di utilità, puoi chiedere al modello di generare risposte con dati di immagini e testo. I modelli PaliGemma vengono addestrati con una sintassi di prompt specifica per attività specifiche, come answer, caption e detect. Per saperne di più sulla sintassi delle attività del prompt di PaliGemma, consulta Prompt e istruzioni di sistema di PaliGemma.

Prepara un'immagine da utilizzare in un prompt di generazione utilizzando il seguente codice per caricare un'immagine di test in un oggetto:

target_size = (224, 224)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
cow_image = read_image(image_url, target_size)
matplotlib.pyplot.imshow(cow_image)

Rispondere in una lingua specifica

Il seguente codice di esempio mostra come richiedere al modello PaliGemma informazioni su un oggetto che appare in un'immagine fornita. Questo esempio utilizza la sintassi answer {lang} e mostra domande aggiuntive in altre lingue:

prompt = 'answer en where is the cow standing?\n'
# prompt = 'svar no hvor står kuen?\n'
# prompt = 'answer fr quelle couleur est le ciel?\n'
# prompt = 'responda pt qual a cor do animal?\n'

output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
print(output)

Utilizza il prompt detect

Il seguente codice di esempio utilizza la sintassi del prompt detect per individuare un oggetto nell'immagine fornita. Il codice utilizza le funzioni parse_bbox_and_labels() e display_boxes() definite in precedenza per interpretare l'output del modello e visualizzare i riquadri di selezione generati.

prompt = 'detect cow\n'
output = paligemma.generate(
    inputs={
        "images": cow_image,
        "prompts": prompt,
    }
)
boxes, labels = parse_bbox_and_labels(output)
display_boxes(cow_image, boxes, labels, target_size)

Utilizza il prompt segment

Il seguente codice di esempio utilizza la sintassi del prompt segment per individuare l'area di un'immagine occupata da un oggetto. Utilizza la libreria Google big_vision per interpretare l'output del modello e generare una maschera per l'oggetto segmentato.

Prima di iniziare, installa la libreria big_vision e le relative dipendenze, come mostrato in questo esempio di codice:

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")


# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

Per questo esempio di segmentazione, carica e prepara un'immagine diversa che includa un gatto.

cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)
matplotlib.pyplot.imshow(cat)

Ecco una funzione per analizzare l'output del segmento di PaliGemma

import  big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      + ''.join(f'<seg(?P<s{i}>\d\d\d)>' for i in range(16)),
      detokenized_output,
  )
  boxes, segs = [], []
  fmt_box = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])
    segs.append([int(d[f's{i}']) for i in range(16)])
  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))

Esegui una query su PaliGemma per segmentare il gatto nell'immagine

prompt = 'segment cat\n'
output = paligemma.generate(
    inputs={
        "images": cat,
        "prompts": prompt,
    }
)

Visualizzare la maschera generata da PaliGemma

bboxes, seg_masks = parse_segments(output)
display_segment_output(cat, bboxes, seg_masks, target_size)

Prompt batch

Puoi fornire più di un comando prompt all'interno di un singolo prompt come batch di istruzioni. L'esempio seguente mostra come strutturare il testo del prompt per fornire più istruzioni.

prompts = [
    'answer en where is the cow standing?\n',
    'answer en what color is the cow?\n',
    'describe en\n',
    'detect cow\n',
    'segment cow\n',
]
images = [cow_image, cow_image, cow_image, cow_image, cow_image]
outputs = paligemma.generate(
    inputs={
        "images": images,
        "prompts": prompts,
    }
)
for output in outputs:
    print(output)