Përmirëso PaliGemma-n me JAX dhe Flax

Shiko në ai.google.dev Ekzekuto në Google Colab Vraponi në Kaggle Hap në Vertex AI Shiko burimin në GitHub

Ky fletore tregon se si të akordoni imët PaliGemma në një detyrë të gjuhës vizuale me JAX . Akordimi i imët është një proces që mund të përmirësojë performancën e modelit tuaj në detyra specifike ose të ndihmojë modelin t'i përmbahet kërkesave specifike të daljes kur udhëzimet nuk janë të mjaftueshme dhe ju keni një sërë shembujsh që demonstrojnë daljet që dëshironi. Modelet e bazuara në Gemma si PaliGemma kërkojnë akordim të imët për të prodhuar rezultatet e pritura.

Çfarë ka në këtë fletore

Ky fletore përdor implementimin e referencës së modelit nga big_vision dhe tregon se si të:

  • Instaloni varësitë dhe shkarkoni pikën e kontrollit të modelit PaliGemma dhe të dhënat e trajnimit
  • Ngarko modelin në pajisjet GPU
  • Përgatitni të dhënat hyrëse të modelit për trajnim dhe nxjerrje përfundimesh
  • Përmirësoni modelin
  • Inspektoni rezultatin

Të dhënat e trajnimit për këtë fletore përbëhen nga 90 çifte imazhesh dhe mbishkrime të gjata që i përshkruajnë ato. Për ta bërë të ekzekutueshëm në një kohë ekzekutimi T4 colab, do të rregulloni vetëm shtresat e vëmendjes së modelit gjuhësor dhe do të ngrini parametrat e tjerë.

Ky shembull është vetëm për qëllime mësimi. Në një rast përdorimi real, sasia e të dhënave, parametrat e trajnueshëm, hapat e trajnimit dhe hiperparametrat, si dhe rezultatet e përftuara mund të jenë dukshëm të ndryshme.

Para se të filloni

Para se të lexoni këtë fletore, duhet të jeni të njohur me kodin Python, si dhe me mënyrën se si trajnohen modelet e mëdha gjuhësore (LLM). Nuk keni nevojë të jeni të njohur me JAX, por njohuritë bazë rreth JAX (ose teknologjive të ngjashme si Keras) janë të dobishme kur lexoni kodin shembullor.

Konfigurimi

Seksionet e mëposhtme shpjegojnë hapat paraprakë për ta bërë një fletore kompjuterike të përdorë një model PaliGemma, duke përfshirë aksesin në model, marrjen e një çelësi API dhe konfigurimin e kohës së ekzekutimit të fletores.

Merrni qasje në PaliGemma

Para se të përdorni PaliGemma për herë të parë, duhet të kërkoni qasje në model përmes Kaggle duke përfunduar hapat e mëposhtëm:

  1. Kyçu në Kaggle ose krijo një llogari të re në Kaggle nëse nuk ke një të tillë.
  2. Shko te karta e modelit PaliGemma dhe kliko te Kërkesë për Qasje .
  3. Plotësoni formularin e pëlqimit dhe pranoni termat dhe kushtet.

Konfiguro çelësin tënd API

Për të përdorur PaliGemma, duhet të jepni emrin tuaj të përdoruesit Kaggle dhe një çelës API të Kaggle.

Për të gjeneruar një çelës API të Kaggle, hapni faqen e Cilësimeve në Kaggle dhe klikoni Krijo Token të Ri . Kjo aktivizon shkarkimin e një skedari kaggle.json që përmban kredencialet tuaja të API-t.

Pastaj, në Colab, zgjidhni Sekretet (🔑) në panelin e majtë dhe shtoni emrin e përdoruesit dhe çelësin API të Kaggle. Ruajeni emrin e përdoruesit nën emrin KAGGLE_USERNAME dhe çelësin API nën emrin KAGGLE_KEY .

Zgjidhni kohën e ekzekutimit

Për të përfunduar këtë tutorial, do t'ju duhet të keni një Colab runtime me burime të mjaftueshme për të ekzekutuar modelin PaliGemma. Në këtë rast, mund të përdorni një GPU T4:

  1. Në pjesën e sipërme djathtas të dritares Colab, klikoni në menynë zbritëse ▾ (Opsione shtesë lidhjeje) .
  2. Zgjidhni Ndrysho llojin e kohës së ekzekutimit .
  3. Nën Përshpejtuesin e Pajisjeve , zgjidhni GPU-në T4 .

Instaloni paketat Python

Ekzekutoni qelizën më poshtë për të instaluar KaggleHub.

pip install -U -q kagglehub

Vendos variablat e mjedisit

Vendosni variablat e mjedisit dhe hyrjen në Kaggle.

import os
# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid out-of-memory due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

# This will prompt you to enter your Kaggle API token
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
Kaggle credentials set.
Kaggle credentials successfully validated.

Shkarkoni repozitorin big_vision në fletoren tuaj të shënimeve Colab nga GitHub dhe instaloni varësitë që lidhen me big_vision duke ekzekutuar kodin e mëposhtëm.

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 76.7/76.7 kB 2.8 MB/s eta 0:00:00

Importo JAX dhe varësi të tjera

Importoni JAX dhe varësi të tjera të nevojshme për PaliGemma, si TensorFlow dhe NumPy.

import base64
import functools
import html
import io
import os
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections

import tensorflow as tf
import sentencepiece

from IPython.core.display import display, HTML
from PIL import Image

# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

backend = jax.extend.backend.get_backend()
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices:  {jax.device_count()}")
JAX version:  0.7.2
JAX platform: gpu
JAX devices:  1

Shkarkoni dhe konfiguroni modelin

Në këtë hap, do të shkarkoni pikën e kontrollit të modelit dhe do ta konfiguroni atë në mënyrë që ta konfiguroni më vonë. Ky hap ju tregon se si të zhvendosni parametrat e modelit në memorien TPU, gjë që është e dobishme për akordimin e imët të modeleve në pajisje me burime të kufizuara.

Shkarkoni pikën e kontrollit të modelit

PaliGemma përfshin disa variacione modeli. Për këtë tutorial, do të përdorni modelin bazë të peshës JAX/FLAX PaliGemma 3B .

Shkarkoni pikën e kontrollit të modelit nga Kaggle duke ekzekutuar kodin e mëposhtëm. Ky proces zgjat disa minuta për t'u përfunduar.

import os
import kagglehub

# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224"  # Path to fetch from Kaggle.

# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"

if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
  print("Downloading the dataset...")
  !gsutil -m -q cp -n -r gs://longcap100/ .
  print(f"Data path: {DATA_DIR}")
Downloading the checkpoint from Kaggle, this could take a few minutes....
Downloading to /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz...
100%|██████████| 5.65G/5.65G [00:54<00:00, 112MB/s]
Model path: /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz
Downloading the model tokenizer...
Copying gs://big_vision/paligemma_tokenizer.model...

- [1 files][  4.1 MiB/  4.1 MiB]                                                
Operation completed over 1 objects/4.1 MiB.                                      
Tokenizer path: ./paligemma_tokenizer.model
Downloading the dataset...
Data path: ./longcap100

Konfiguro modelin

Është koha për të filluar konfigurimin e modelit që do të përdorni.

Për këtë laptop, duhet të jeni në gjendje ta përshtatni modelin tuaj në një GPU T4. Duke pasur burime të kufizuara, si kufizime hapësinore, duhet të jeni të kujdesshëm se si është konfiguruar modeli juaj.

Nëse i rregulloni imët çdo parametër, modeli juaj nuk do të jetë në gjendje të funksionojë në mjedisin e fletores. Si rezultat, në këtë pjesë të fletores, do ta konfiguroni modelin tuaj në mënyrë që të ketë aftësinë të ngrijë disa nga parametrat dhe të rregullojë vetëm parametrat që duhet të rregullohen vërtet që modeli t'ju japë rezultate të sakta. Në LLM-të, parametrat thuhet se janë të ngrirë kur nuk përdoren më në mënyrë aktive për të trajnuar modelin.

Për të konfiguruar modelin tuaj, duhet të:

  • Inicializoni model_config si një FrozenConfigDict në mënyrë që të ngrini disa nga parametrat dhe të mbani përdorimin e memories të ulët.
  • Inicializoni një instancë të klasës PaliGemma Model duke përdorur model_config si konfigurime të saj.
  • Ngarko parametrat e modelit në RAM
  • Përcaktoni një funksion decode për të marrë mostra nga rezultatet e modelit

Ky kod në këtë qelizë ekzekutohet për rreth një minutë.

# Define model

# IMPORTANT: Gemma-2 has a "final_logits_softcap" property. Set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

Zhvendos parametrat e modelit në memorien GPU/TPU

Tani duhet të zhvendosni parametrat e modelit në memorien GPU/TPU. Së pari, ndani parametrat në të gjitha GPU-të e disponueshme dhe më pas ngarkoni parametrat. Këtu, do t'i ngarkoni parametrat në mënyrë sekuenciale. Ky proces zgjat më shumë sesa ngarkimi i tyre njëkohësisht, por kërkon më shumë RAM sesa keni në dispozicion në këtë laptop.

Së fundmi, printoni të gjithë parametrat për të parë se në çfarë lloji është paraqitur secili parametër individual. Parametrat e ngrirë mbahen si float16 , ndërsa parametrat e trajnueshëm paraqiten në float32 . Kur ta inspektoni listën, do të shihni se shumica e parametrave janë ngrirë dhe janë float16 .

# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"):  return True
  if name.startswith("llm/"):              return False
  if name.startswith("img/"):              return False
  raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data"))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  # Cast others to float16, since some GPUs don't support bf16.
  return jax.tree.map(lambda p, m: p.astype(jnp.float32)
                      if m else p.astype(jnp.float16),
                      params, trainable)

# Loading all params in simultaneous - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default.
# Instead, do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)
== Model params == 
img/Transformer/encoder_norm/bias                                                (1152,)                float16
img/Transformer/encoder_norm/scale                                               (1152,)                float16
img/Transformer/encoderblock/LayerNorm_0/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_0/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias                             (27, 4304)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel                           (27, 1152, 4304)       float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias                             (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel                           (27, 4304, 1152)       float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias             (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel           (27, 1152, 16, 72)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias             (27, 1152)             float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel           (27, 16, 72, 1152)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias           (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel         (27, 1152, 16, 72)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias           (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel         (27, 1152, 16, 72)     float16
img/embedding/bias                                                               (1152,)                float16
img/embedding/kernel                                                             (14, 14, 3, 1152)      float16
img/head/bias                                                                    (2304,)                float16
img/head/kernel                                                                  (1152, 2304)           float16
img/pos_embedding                                                                (1, 256, 1152)         float16
llm/embedder/input_embedding                                                     (257152, 2304)         float16
llm/final_norm/scale                                                             (2304,)                float16
llm/layers/attn/attn_vec_einsum/w                                                (26, 8, 256, 2304)     float32
llm/layers/attn/kv_einsum/w                                                      (26, 2, 4, 2304, 256)  float32
llm/layers/attn/q_einsum/w                                                       (26, 8, 2304, 256)     float32
llm/layers/mlp/gating_einsum                                                     (26, 2, 2304, 9216)    float16
llm/layers/mlp/linear                                                            (26, 9216, 2304)       float16
llm/layers/post_attention_norm/scale                                             (26, 2304)             float16
llm/layers/post_ffw_norm/scale                                                   (26, 2304)             float16
llm/layers/pre_attention_norm/scale                                              (26, 2304)             float16
llm/layers/pre_ffw_norm/scale                                                    (26, 2304)             float16

Përgatituni për të akorduar modelin

Tani që modeli juaj është konfiguruar, mund ta konfiguroni atë. Në këtë hap, do të krijoni të dhënat hyrëse të modelit tuaj, si dhe iteratorët e trajnimit dhe validimit, do të shikoni shembujt e trajnimit dhe do të përcaktoni sythet e trajnimit dhe validimit.

Krijo hyrje modeli

Pika e kontrollit të modelit që po përdorni është trajnuar tashmë në imazhe me raporte të ndryshme aspektesh që janë ridimensionuar në 224x224 piksel, dhe për të trajtuar tekste të tokenizuara.

Kodi më poshtë përcakton tre funksione që do t'i përdorni në hapin tjetër për të krijuar të dhënat hyrëse të modelit:

  • preprocess_image : Normalizon të dhënat e imazhit. Në këtë rast, para-përpunimi e konverton imazhin e kaluar në shkallë gri, heq shtresën alfa dhe e ridimensionon imazhin e kaluar në madhësinë e kërkuar nga modeli për hyrjet e imazhit (224x224 piksel).
  • preprocess_tokens : Ndan tokenët dhe shton flamuj për të shënuar nëse një token është një token me parashtesë apo prapashtesë. Këta flamuj do të përdoren më vonë në kod, gjatë hapit të trajnimit dhe ciklit të vlerësimit.
  • postprocess_tokens : Heq çdo token të mbetur në dhe/ose pas tokenit të fundit të sekuencës (EOS) dhe kthen tokenët e mbetur të dekoduar.
def preprocess_image(image, size=224):
  # Model has been trained to handle images of different aspects ratios
  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
  # options are helpful to improve quality in some tasks.
  image = np.asarray(image)
  if image.ndim == 2:  # Convert image without last channel into greyscale.
    image = np.stack((image,)*3, axis=-1)
  image = image[..., :3]  # Remove alpha layer.
  assert image.shape[-1] == 3

  image = tf.constant(image)
  image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
  return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]

def preprocess_tokens(prefix, suffix=None, seqlen=None):
  # Model has been trained to handle tokenized text composed of a prefix with
  # full attention and a suffix with causal attention.
  separator = "\n"
  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

  if suffix:
    suffix = tokenizer.encode(suffix, add_eos=True)
    tokens += suffix
    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

  mask_input = [1] * len(tokens)    # 1 if it's a token, 0 if padding.
  if seqlen:
    padding = [0] * max(0, seqlen - len(tokens))
    tokens = tokens[:seqlen] + padding
    mask_ar = mask_ar[:seqlen] + padding
    mask_loss = mask_loss[:seqlen] + padding
    mask_input = mask_input[:seqlen] + padding

  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

def postprocess_tokens(tokens):
  tokens = tokens.tolist()  # np.array to list[int]
  try:  # Remove tokens at and after EOS if any.
    eos_pos = tokens.index(tokenizer.eos_id())
    tokens = tokens[:eos_pos]
  except ValueError:
    pass
  return tokenizer.decode(tokens)

Krijoni iteratorët e trajnimit dhe validimit

Krijoni dy iteratorë:

  • Një iterator trajnimi për të lejuar që procesi i trajnimit të kalojë nëpër të dhënat në copa në vend që t'i përpunojë të gjitha menjëherë
    • Kjo ju lejon të bëni disa përpunime paraprake të të dhënave para përdorimit.
  • Një iterator validimi që lejon procesin e trajnimit të përsërisë mbi të dhënat e validimit për të parë se sa mirë modeli i akorduar përputhet me rezultatet e dhëna.
SEQLEN = 128

train_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_train90.jsonl"),
    fopen_keys={"image": DATA_DIR})

val_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_val10.jsonl"),
    fopen_keys={"image": DATA_DIR})


def train_data_iterator():
  """Never ending iterator over training examples."""
  # Shuffle examples and repeat so one can train for many epochs.
  dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
  for example in dataset.as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    suffix = example["suffix"].decode().lower()
    tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_loss": np.asarray(mask_loss),
    }


def validation_data_iterator():
  """Single iterator over validation examples."""
  for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input),
    }

Shikoni shembuj trajnimi

Në këtë fletore, të dhënat e trajnimit përmbajnë 90 imazhe që shoqërohen me përshkrime të gjata të asaj që përshkruhet në imazh.

Kodi më poshtë printon një përzgjedhje të rastësishme imazhesh me përshkrimet e tyre nga grupi i të dhënave të trajnimit, në mënyrë që të shihni se si duken imazhet dhe përshkrimet mbi të cilat është trajnuar modeli juaj. Çdo imazh shfaqet si një JPEG me 128x128 piksel, me përshkrimin e printuar pranë imazhit në të djathtë.

def render_inline(image, resize=(128, 128)):
  """Convert image into inline html."""
  image = Image.fromarray(image)
  image.resize(resize)
  with io.BytesIO() as buffer:
    image.save(buffer, format='jpeg')
    image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
    return f"data:image/jpeg;base64,{image_b64}"

def render_example(image, caption):
  image = ((image + 1)/2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]
  return f"""
    <div style="display: inline-flex; align-items: center; justify-content: center;">
        <img style="width:128px; height:128px;" src="{render_inline(image, resize=(64,64))}" />
        <p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
    </div>
    """

html_out = ""
for idx, example in zip(range(8), train_data_iterator()):
  caption = postprocess_tokens(example["text"])  # detokenize model input.
  caption = caption[len("caption en\n"):]        # strip prefix
  html_out += render_example(example["image"], caption)

print("Training examples")
display(HTML(html_out))
Training examples

Përcaktoni ciklet e trajnimit dhe vlerësimit

Përcaktoni ciklin e trajnimit për të trajnuar modelin në të dhënat e dhëna dhe ciklin e vlerësimit për të parë të gjithë shembujt në të dhënat e validimit dhe për të bërë parashikimet e tij.

Përcaktimi i ciklit të trajnimit

Funksioni update_fn përcakton hapin e trajnimit. Gjatë hapit të trajnimit, humbja për shembull llogaritet dhe zbritja stokastike e gradientit (SGD) zbatohet në parametrat e trajnueshëm.

Kujtoni se më herët në fletore, keni përfshirë flamuj në funksionin preprocess_tokens që përfshinin mask_loss . Do ta përdorni flamurin mask_loss këtu për të përjashtuar prefiksin dhe tokenët e mbushur nga humbja. Pa të, llogaritja e humbjes do të jetë e shtrembër. Gjithashtu duhet të normalizoni çdo shembull, pasi secili prej tyre ka një numër të ndryshëm tokenësh. Pasi prefiksi dhe tokenët e mbushur të jenë përjashtuar dhe shembujt të jenë normalizuar, mund të llogaritni humbjen për shembull.

Hapi i trajnimit përfshin gjithashtu një funksion për të aplikuar një SGD për të optimizuar trajnimin.

Përcaktimi i ciklit të vlerësimit

Funksioni make_predictions është cikli juaj i vlerësimit. Cikli i vlerësimit është mjaft i drejtpërdrejtë me një ndryshim të dukshëm. Nëse e mbani mend nga fillimi i fletores, keni vetëm 90 shembuj në grupin tuaj të të dhënave të trajnimit. Ky është një numër shumë i vogël shembujsh trajnimi dhe modeli juaj përfundon duke mos pasur shembuj të mjaftueshëm për madhësinë e grupit kur ekzekutoni trajnimin. Kjo do të thotë që në ciklin e vlerësimit, duhet ta mbushni grupin duke përsëritur shembujt.

Për t'u siguruar që cikli juaj i vlerësimit numëron vetëm shembujt aktualë dhe jo shembujt e mbushur, duhet të aplikoni një maskë te shembujt e mbushur që i përjashton ato nga rezultati.

# The main update_fn using a simple stochastic gradient descent (SGD).
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
  imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]

  def loss_fn(params):
    text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
    logp = jax.nn.log_softmax(text_logits, axis=-1)

    # The model takes as input txts[:, :-1] but the loss is defined as predicting
    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
    # are part of the loss (e.g. prefix and padded tokens are not included).
    mask_loss = batch["mask_loss"][:, 1:]
    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])

    # Compute the loss per example. i.e. the mean of per token pplx.
    # Since each example has a different number of tokens, normalize it.
    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.
    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.
    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.

    # batch_loss: mean of per example loss.
    return jnp.mean(example_loss)

  loss, grads = jax.value_and_grad(loss_fn)(params)

  # Apply gradients to trainable params using SGD.
  def apply_grad(param, gradient, trainable):
    if not trainable: return param
    return param - learning_rate * gradient

  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)

  return params, loss

# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
  outputs = []
  while True:
    # Construct a list of examples in the batch.
    examples = []
    try:
      for _ in range(batch_size):
        examples.append(next(data_iterator))
        examples[-1]["_mask"] = np.array(True)  # Indicates true example.
    except StopIteration:
      if len(examples) == 0:
        return outputs

    # Not enough examples to complete a batch. Pad by repeating last example.
    while len(examples) % batch_size:
      examples.append(dict(examples[-1]))
      examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Make model predictions
    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch model predictions to device and detokenize.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    tokens = tokens[mask]  # remove padding examples.
    responses = [postprocess_tokens(t) for t in tokens]

    # Append to html output.
    for example, response in zip(examples, responses):
      outputs.append((example["image"], response))
      if num_examples and len(outputs) >= num_examples:
        return outputs

Akordoni modelin

Tani që i keni konfiguruar të gjitha dhe i keni parë të dhënat e trajnimit, është koha për ta akorduar më në fund modelin. Kodi më poshtë ekzekuton ciklin e trajnimit për modelin për 64 hapa dhe printon shkallën e të mësuarit ( lr në rezultatin e printuar) dhe shkallën e humbjes për secilin hap.

Çdo 16 hapa, modeli printon parashikimet e tij në atë hap të trajnimit. Ky kod printon parashikime për të njëjtin grup imazhesh në mënyrë që të shihni se si aftësia e modelit për të parashikuar përshkrimet përmirësohet me kalimin e kohës.

Në hapat e mëparshëm të trajnimit, ka të ngjarë të ketë probleme me përshkrimet, të tilla si fjali të përsëritura ndërsa modeli ngec në ciklin e tij parashikues ose fjali të papërfunduara. Parashikimet e modelit bëhen gjithnjë e më të sakta ndërsa trajnimi përparon. Deri në hapin 64, parashikimet e modelit duhet të ngjajnë shumë me përshkrimet e ofruara nga të dhënat e trajnimit.

Ky proces zgjat rreth 15 minuta për t'u përfunduar në TPU-të T4.

# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.
#
%%time

BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03

TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4

train_data_it = train_data_iterator()

sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)

for step in range(1, TRAIN_STEPS+1):
  # Make list of N training examples.
  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]

  # Convert list of examples into a dict of np.arrays and load onto devices.
  batch = jax.tree.map(lambda *x: np.stack(x), *examples)
  batch = big_vision.utils.reshard(batch, data_sharding)

  # Training step and report training loss
  learning_rate = sched_fn(step)
  params, loss = update_fn(params, batch, learning_rate)

  loss = jax.device_get(loss)
  print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}")

  if (step % EVAL_STEPS) == 0:
    print(f"Model predictions at step {step}")
    html_out = ""
    for image, caption in make_predictions(
        validation_data_iterator(), num_examples=4, batch_size=4):
      html_out += render_example(image, caption)
    display(HTML(html_out))
step:  1/64   lr: 0.00500   loss: 3.6567
step:  2/64   lr: 0.01000   loss: 1.9762
step:  3/64   lr: 0.01500   loss: 1.6299
step:  4/64   lr: 0.02000   loss: 1.5651
step:  5/64   lr: 0.02500   loss: 1.9813
step:  6/64   lr: 0.03000   loss: 1.9996
step:  7/64   lr: 0.02998   loss: 1.8595
step:  8/64   lr: 0.02992   loss: 1.6479
step:  9/64   lr: 0.02981   loss: 1.3693
step: 10/64   lr: 0.02966   loss: 1.3423
step: 11/64   lr: 0.02947   loss: 1.2122
step: 12/64   lr: 0.02924   loss: 1.0602
step: 13/64   lr: 0.02897   loss: 1.1314
step: 14/64   lr: 0.02866   loss: 1.2612
step: 15/64   lr: 0.02831   loss: 1.0132
step: 16/64   lr: 0.02792   loss: 1.2126
Model predictions at step 16
step: 17/64   lr: 0.02750   loss: 1.0986
step: 18/64   lr: 0.02704   loss: 0.9461
step: 19/64   lr: 0.02655   loss: 1.2098
step: 20/64   lr: 0.02602   loss: 1.0513
step: 21/64   lr: 0.02546   loss: 1.0979
step: 22/64   lr: 0.02488   loss: 0.9739
step: 23/64   lr: 0.02426   loss: 0.9589
step: 24/64   lr: 0.02362   loss: 0.7053
step: 25/64   lr: 0.02296   loss: 0.7347
step: 26/64   lr: 0.02227   loss: 0.6990
step: 27/64   lr: 0.02156   loss: 0.6736
step: 28/64   lr: 0.02083   loss: 0.6642
step: 29/64   lr: 0.02009   loss: 0.6908
step: 30/64   lr: 0.01933   loss: 0.7257
step: 31/64   lr: 0.01856   loss: 0.6902
step: 32/64   lr: 0.01778   loss: 0.7054
Model predictions at step 32
step: 33/64   lr: 0.01699   loss: 0.7709
step: 34/64   lr: 0.01620   loss: 0.6653
step: 35/64   lr: 0.01540   loss: 0.3811
step: 36/64   lr: 0.01460   loss: 0.3104
step: 37/64   lr: 0.01380   loss: 0.4042
step: 38/64   lr: 0.01301   loss: 0.3904
step: 39/64   lr: 0.01222   loss: 0.3339
step: 40/64   lr: 0.01144   loss: 0.4156
step: 41/64   lr: 0.01067   loss: 0.4085
step: 42/64   lr: 0.00991   loss: 0.3083
step: 43/64   lr: 0.00917   loss: 0.3757
step: 44/64   lr: 0.00844   loss: 0.3813
step: 45/64   lr: 0.00773   loss: 0.3381
step: 46/64   lr: 0.00704   loss: 0.2057
step: 47/64   lr: 0.00638   loss: 0.1287
step: 48/64   lr: 0.00574   loss: 0.1711
Model predictions at step 48
step: 49/64   lr: 0.00512   loss: 0.1183
step: 50/64   lr: 0.00454   loss: 0.1154
step: 51/64   lr: 0.00398   loss: 0.1967
step: 52/64   lr: 0.00345   loss: 0.1497
step: 53/64   lr: 0.00296   loss: 0.1688
step: 54/64   lr: 0.00250   loss: 0.1878
step: 55/64   lr: 0.00208   loss: 0.1865
step: 56/64   lr: 0.00169   loss: 0.1655
step: 57/64   lr: 0.00134   loss: 0.0911
step: 58/64   lr: 0.00103   loss: 0.1836
step: 59/64   lr: 0.00076   loss: 0.1242
step: 60/64   lr: 0.00053   loss: 0.0814
step: 61/64   lr: 0.00034   loss: 0.0866
step: 62/64   lr: 0.00019   loss: 0.1295
step: 63/64   lr: 0.00008   loss: 0.1053
step: 64/64   lr: 0.00002   loss: 0.0730
Model predictions at step 64
CPU times: user 2min 18s, sys: 8.98 s, total: 2min 27s
Wall time: 15min 45s

Prodhimi

Të dhënat e validimit për këtë fletore shënimesh përbëhen vetëm nga 10 imazhe. Në kodin normal, ka të ngjarë të keni shumë më tepër pika të dhënash për validim, por për këtë fletore shënimesh, ekzekutoni kodin e mëposhtëm për të gjeneruar përshkrime për të 10 imazhet. Pas akordimit të modelit, këto përshkrime duhet të jenë shumë të ngjashme në formë dhe mbulim përmbajtjeje me përshkrimet e përfshira me të dhënat e trajnimit që keni parë më parë në këtë fletore shënimesh.

Ekzekutoni kodin më poshtë për të gjeneruar përshkrime për grupin e të dhënave të validimit.

# The validation data consists of 10 images in a different domain than training
# data.
%%time

print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
  html_out += render_example(image, caption)
display(HTML(html_out))
Model predictions
CPU times: user 1.87 s, sys: 283 ms, total: 2.15 s
Wall time: 39.3 s