| | Ekzekuto në Google Colab | | | Shiko burimin në GitHub |
Ky fletore tregon se si të akordoni imët PaliGemma në një detyrë të gjuhës vizuale me JAX . Akordimi i imët është një proces që mund të përmirësojë performancën e modelit tuaj në detyra specifike ose të ndihmojë modelin t'i përmbahet kërkesave specifike të daljes kur udhëzimet nuk janë të mjaftueshme dhe ju keni një sërë shembujsh që demonstrojnë daljet që dëshironi. Modelet e bazuara në Gemma si PaliGemma kërkojnë akordim të imët për të prodhuar rezultatet e pritura.
Çfarë ka në këtë fletore
Ky fletore përdor implementimin e referencës së modelit nga big_vision dhe tregon se si të:
- Instaloni varësitë dhe shkarkoni pikën e kontrollit të modelit PaliGemma dhe të dhënat e trajnimit
- Ngarko modelin në pajisjet GPU
- Përgatitni të dhënat hyrëse të modelit për trajnim dhe nxjerrje përfundimesh
- Përmirësoni modelin
- Inspektoni rezultatin
Të dhënat e trajnimit për këtë fletore përbëhen nga 90 çifte imazhesh dhe mbishkrime të gjata që i përshkruajnë ato. Për ta bërë të ekzekutueshëm në një kohë ekzekutimi T4 colab, do të rregulloni vetëm shtresat e vëmendjes së modelit gjuhësor dhe do të ngrini parametrat e tjerë.
Ky shembull është vetëm për qëllime mësimi. Në një rast përdorimi real, sasia e të dhënave, parametrat e trajnueshëm, hapat e trajnimit dhe hiperparametrat, si dhe rezultatet e përftuara mund të jenë dukshëm të ndryshme.
Para se të filloni
Para se të lexoni këtë fletore, duhet të jeni të njohur me kodin Python, si dhe me mënyrën se si trajnohen modelet e mëdha gjuhësore (LLM). Nuk keni nevojë të jeni të njohur me JAX, por njohuritë bazë rreth JAX (ose teknologjive të ngjashme si Keras) janë të dobishme kur lexoni kodin shembullor.
Konfigurimi
Seksionet e mëposhtme shpjegojnë hapat paraprakë për ta bërë një fletore kompjuterike të përdorë një model PaliGemma, duke përfshirë aksesin në model, marrjen e një çelësi API dhe konfigurimin e kohës së ekzekutimit të fletores.
Merrni qasje në PaliGemma
Para se të përdorni PaliGemma për herë të parë, duhet të kërkoni qasje në model përmes Kaggle duke përfunduar hapat e mëposhtëm:
- Kyçu në Kaggle ose krijo një llogari të re në Kaggle nëse nuk ke një të tillë.
- Shko te karta e modelit PaliGemma dhe kliko te Kërkesë për Qasje .
- Plotësoni formularin e pëlqimit dhe pranoni termat dhe kushtet.
Konfiguro çelësin tënd API
Për të përdorur PaliGemma, duhet të jepni emrin tuaj të përdoruesit Kaggle dhe një çelës API të Kaggle.
Për të gjeneruar një çelës API të Kaggle, hapni faqen e Cilësimeve në Kaggle dhe klikoni Krijo Token të Ri . Kjo aktivizon shkarkimin e një skedari kaggle.json që përmban kredencialet tuaja të API-t.
Pastaj, në Colab, zgjidhni Sekretet (🔑) në panelin e majtë dhe shtoni emrin e përdoruesit dhe çelësin API të Kaggle. Ruajeni emrin e përdoruesit nën emrin KAGGLE_USERNAME dhe çelësin API nën emrin KAGGLE_KEY .
Zgjidhni kohën e ekzekutimit
Për të përfunduar këtë tutorial, do t'ju duhet të keni një Colab runtime me burime të mjaftueshme për të ekzekutuar modelin PaliGemma. Në këtë rast, mund të përdorni një GPU T4:
- Në pjesën e sipërme djathtas të dritares Colab, klikoni në menynë zbritëse ▾ (Opsione shtesë lidhjeje) .
- Zgjidhni Ndrysho llojin e kohës së ekzekutimit .
- Nën Përshpejtuesin e Pajisjeve , zgjidhni GPU-në T4 .
Instaloni paketat Python
Ekzekutoni qelizën më poshtë për të instaluar KaggleHub.
pip install -U -q kagglehubVendos variablat e mjedisit
Vendosni variablat e mjedisit dhe hyrjen në Kaggle.
import os
# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid out-of-memory due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
# This will prompt you to enter your Kaggle API token
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle… Kaggle credentials set. Kaggle credentials successfully validated.
Merrni repozitorin big_vision dhe instaloni varësitë përkatëse
Shkarkoni repozitorin big_vision në fletoren tuaj të shënimeve Colab nga GitHub dhe instaloni varësitë që lidhen me big_vision duke ekzekutuar kodin e mëposhtëm.
import os
import sys
# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
raise "It seems you are using Colab with remote TPUs which is not supported."
# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
!git clone --quiet --branch=main --depth=1 \
https://github.com/google-research/big_vision big_vision_repo
# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
sys.path.append("big_vision_repo")
# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 76.7/76.7 kB 2.8 MB/s eta 0:00:00
Importo JAX dhe varësi të tjera
Importoni JAX dhe varësi të tjera të nevojshme për PaliGemma, si TensorFlow dhe NumPy.
import base64
import functools
import html
import io
import os
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from IPython.core.display import display, HTML
from PIL import Image
# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
backend = jax.extend.backend.get_backend()
print(f"JAX version: {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices: {jax.device_count()}")
JAX version: 0.7.2 JAX platform: gpu JAX devices: 1
Shkarkoni dhe konfiguroni modelin
Në këtë hap, do të shkarkoni pikën e kontrollit të modelit dhe do ta konfiguroni atë në mënyrë që ta konfiguroni më vonë. Ky hap ju tregon se si të zhvendosni parametrat e modelit në memorien TPU, gjë që është e dobishme për akordimin e imët të modeleve në pajisje me burime të kufizuara.
Shkarkoni pikën e kontrollit të modelit
PaliGemma përfshin disa variacione modeli. Për këtë tutorial, do të përdorni modelin bazë të peshës JAX/FLAX PaliGemma 3B .
Shkarkoni pikën e kontrollit të modelit nga Kaggle duke ekzekutuar kodin e mëposhtëm. Ky proces zgjat disa minuta për t'u përfunduar.
import os
import kagglehub
# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224" # Path to fetch from Kaggle.
# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"
if not os.path.exists(MODEL_PATH):
print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
print(f"Model path: {MODEL_PATH}")
TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
print("Downloading the model tokenizer...")
!gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
print(f"Tokenizer path: {TOKENIZER_PATH}")
DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
print("Downloading the dataset...")
!gsutil -m -q cp -n -r gs://longcap100/ .
print(f"Data path: {DATA_DIR}")
Downloading the checkpoint from Kaggle, this could take a few minutes.... Downloading to /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz... 100%|██████████| 5.65G/5.65G [00:54<00:00, 112MB/s] Model path: /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz Downloading the model tokenizer... Copying gs://big_vision/paligemma_tokenizer.model... - [1 files][ 4.1 MiB/ 4.1 MiB] Operation completed over 1 objects/4.1 MiB. Tokenizer path: ./paligemma_tokenizer.model Downloading the dataset... Data path: ./longcap100
Konfiguro modelin
Është koha për të filluar konfigurimin e modelit që do të përdorni.
Për këtë laptop, duhet të jeni në gjendje ta përshtatni modelin tuaj në një GPU T4. Duke pasur burime të kufizuara, si kufizime hapësinore, duhet të jeni të kujdesshëm se si është konfiguruar modeli juaj.
Nëse i rregulloni imët çdo parametër, modeli juaj nuk do të jetë në gjendje të funksionojë në mjedisin e fletores. Si rezultat, në këtë pjesë të fletores, do ta konfiguroni modelin tuaj në mënyrë që të ketë aftësinë të ngrijë disa nga parametrat dhe të rregullojë vetëm parametrat që duhet të rregullohen vërtet që modeli t'ju japë rezultate të sakta. Në LLM-të, parametrat thuhet se janë të ngrirë kur nuk përdoren më në mënyrë aktive për të trajnuar modelin.
Për të konfiguruar modelin tuaj, duhet të:
- Inicializoni
model_configsi njëFrozenConfigDictnë mënyrë që të ngrini disa nga parametrat dhe të mbani përdorimin e memories të ulët. - Inicializoni një instancë të klasës PaliGemma
Modelduke përdorurmodel_configsi konfigurime të saj. - Ngarko parametrat e modelit në RAM
- Përcaktoni një funksion
decodepër të marrë mostra nga rezultatet e modelit
Ky kod në këtë qelizë ekzekutohet për rreth një minutë.
# Define model
# IMPORTANT: Gemma-2 has a "final_logits_softcap" property. Set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
"llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
"img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)
# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)
# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())
Zhvendos parametrat e modelit në memorien GPU/TPU
Tani duhet të zhvendosni parametrat e modelit në memorien GPU/TPU. Së pari, ndani parametrat në të gjitha GPU-të e disponueshme dhe më pas ngarkoni parametrat. Këtu, do t'i ngarkoni parametrat në mënyrë sekuenciale. Ky proces zgjat më shumë sesa ngarkimi i tyre njëkohësisht, por kërkon më shumë RAM sesa keni në dispozicion në këtë laptop.
Së fundmi, printoni të gjithë parametrat për të parë se në çfarë lloji është paraqitur secili parametër individual. Parametrat e ngrirë mbahen si float16 , ndërsa parametrat e trajnueshëm paraqiten në float32 . Kur ta inspektoni listën, do të shihni se shumica e parametrave janë ngrirë dhe janë float16 .
# Create a pytree mask of the trainable params.
def is_trainable_param(name, param): # pylint: disable=unused-argument
if name.startswith("llm/layers/attn/"): return True
if name.startswith("llm/"): return False
if name.startswith("img/"): return False
raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data"))
data_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("data"))
params_sharding = big_vision.sharding.infer_sharding(
params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)
# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
"ignore", message="Some donated buffers were not usable")
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
# Cast others to float16, since some GPUs don't support bf16.
return jax.tree.map(lambda p, m: p.astype(jnp.float32)
if m else p.astype(jnp.float16),
params, trainable)
# Loading all params in simultaneous - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default.
# Instead, do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
params[idx] = big_vision.utils.reshard(params[idx], sharding)
params[idx] = maybe_cast_to_f32(params[idx], trainable)
params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)
# Print params to show what the model is made of.
def parameter_overview(params):
for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")
print(" == Model params == ")
parameter_overview(params)
== Model params == img/Transformer/encoder_norm/bias (1152,) float16 img/Transformer/encoder_norm/scale (1152,) float16 img/Transformer/encoderblock/LayerNorm_0/bias (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_0/scale (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_1/bias (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_1/scale (27, 1152) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias (27, 4304) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel (27, 1152, 4304) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias (27, 1152) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel (27, 4304, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel (27, 1152, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias (27, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel (27, 16, 72, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel (27, 1152, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel (27, 1152, 16, 72) float16 img/embedding/bias (1152,) float16 img/embedding/kernel (14, 14, 3, 1152) float16 img/head/bias (2304,) float16 img/head/kernel (1152, 2304) float16 img/pos_embedding (1, 256, 1152) float16 llm/embedder/input_embedding (257152, 2304) float16 llm/final_norm/scale (2304,) float16 llm/layers/attn/attn_vec_einsum/w (26, 8, 256, 2304) float32 llm/layers/attn/kv_einsum/w (26, 2, 4, 2304, 256) float32 llm/layers/attn/q_einsum/w (26, 8, 2304, 256) float32 llm/layers/mlp/gating_einsum (26, 2, 2304, 9216) float16 llm/layers/mlp/linear (26, 9216, 2304) float16 llm/layers/post_attention_norm/scale (26, 2304) float16 llm/layers/post_ffw_norm/scale (26, 2304) float16 llm/layers/pre_attention_norm/scale (26, 2304) float16 llm/layers/pre_ffw_norm/scale (26, 2304) float16
Përgatituni për të akorduar modelin
Tani që modeli juaj është konfiguruar, mund ta konfiguroni atë. Në këtë hap, do të krijoni të dhënat hyrëse të modelit tuaj, si dhe iteratorët e trajnimit dhe validimit, do të shikoni shembujt e trajnimit dhe do të përcaktoni sythet e trajnimit dhe validimit.
Krijo hyrje modeli
Pika e kontrollit të modelit që po përdorni është trajnuar tashmë në imazhe me raporte të ndryshme aspektesh që janë ridimensionuar në 224x224 piksel, dhe për të trajtuar tekste të tokenizuara.
Kodi më poshtë përcakton tre funksione që do t'i përdorni në hapin tjetër për të krijuar të dhënat hyrëse të modelit:
-
preprocess_image: Normalizon të dhënat e imazhit. Në këtë rast, para-përpunimi e konverton imazhin e kaluar në shkallë gri, heq shtresën alfa dhe e ridimensionon imazhin e kaluar në madhësinë e kërkuar nga modeli për hyrjet e imazhit (224x224 piksel). -
preprocess_tokens: Ndan tokenët dhe shton flamuj për të shënuar nëse një token është një token me parashtesë apo prapashtesë. Këta flamuj do të përdoren më vonë në kod, gjatë hapit të trajnimit dhe ciklit të vlerësimit. -
postprocess_tokens: Heq çdo token të mbetur në dhe/ose pas tokenit të fundit të sekuencës (EOS) dhe kthen tokenët e mbetur të dekoduar.
def preprocess_image(image, size=224):
# Model has been trained to handle images of different aspects ratios
# resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
# options are helpful to improve quality in some tasks.
image = np.asarray(image)
if image.ndim == 2: # Convert image without last channel into greyscale.
image = np.stack((image,)*3, axis=-1)
image = image[..., :3] # Remove alpha layer.
assert image.shape[-1] == 3
image = tf.constant(image)
image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1]
def preprocess_tokens(prefix, suffix=None, seqlen=None):
# Model has been trained to handle tokenized text composed of a prefix with
# full attention and a suffix with causal attention.
separator = "\n"
tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
mask_ar = [0] * len(tokens) # 0 to use full attention for prefix.
mask_loss = [0] * len(tokens) # 0 to not use prefix tokens in the loss.
if suffix:
suffix = tokenizer.encode(suffix, add_eos=True)
tokens += suffix
mask_ar += [1] * len(suffix) # 1 to use causal attention for suffix.
mask_loss += [1] * len(suffix) # 1 to use suffix tokens in the loss.
mask_input = [1] * len(tokens) # 1 if it's a token, 0 if padding.
if seqlen:
padding = [0] * max(0, seqlen - len(tokens))
tokens = tokens[:seqlen] + padding
mask_ar = mask_ar[:seqlen] + padding
mask_loss = mask_loss[:seqlen] + padding
mask_input = mask_input[:seqlen] + padding
return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))
def postprocess_tokens(tokens):
tokens = tokens.tolist() # np.array to list[int]
try: # Remove tokens at and after EOS if any.
eos_pos = tokens.index(tokenizer.eos_id())
tokens = tokens[:eos_pos]
except ValueError:
pass
return tokenizer.decode(tokens)
Krijoni iteratorët e trajnimit dhe validimit
Krijoni dy iteratorë:
- Një iterator trajnimi për të lejuar që procesi i trajnimit të kalojë nëpër të dhënat në copa në vend që t'i përpunojë të gjitha menjëherë
- Kjo ju lejon të bëni disa përpunime paraprake të të dhënave para përdorimit.
- Një iterator validimi që lejon procesin e trajnimit të përsërisë mbi të dhënat e validimit për të parë se sa mirë modeli i akorduar përputhet me rezultatet e dhëna.
SEQLEN = 128
train_dataset = big_vision.datasets.jsonl.DataSource(
os.path.join(DATA_DIR, "data_train90.jsonl"),
fopen_keys={"image": DATA_DIR})
val_dataset = big_vision.datasets.jsonl.DataSource(
os.path.join(DATA_DIR, "data_val10.jsonl"),
fopen_keys={"image": DATA_DIR})
def train_data_iterator():
"""Never ending iterator over training examples."""
# Shuffle examples and repeat so one can train for many epochs.
dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
for example in dataset.as_numpy_iterator():
image = Image.open(io.BytesIO(example["image"]))
image = preprocess_image(image)
prefix = "caption en" # Could also be a different prefix per example.
suffix = example["suffix"].decode().lower()
tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)
yield {
"image": np.asarray(image),
"text": np.asarray(tokens),
"mask_ar": np.asarray(mask_ar),
"mask_loss": np.asarray(mask_loss),
}
def validation_data_iterator():
"""Single iterator over validation examples."""
for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
image = Image.open(io.BytesIO(example["image"]))
image = preprocess_image(image)
prefix = "caption en" # Could also be a different prefix per example.
tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)
yield {
"image": np.asarray(image),
"text": np.asarray(tokens),
"mask_ar": np.asarray(mask_ar),
"mask_input": np.asarray(mask_input),
}
Shikoni shembuj trajnimi
Në këtë fletore, të dhënat e trajnimit përmbajnë 90 imazhe që shoqërohen me përshkrime të gjata të asaj që përshkruhet në imazh.
Kodi më poshtë printon një përzgjedhje të rastësishme imazhesh me përshkrimet e tyre nga grupi i të dhënave të trajnimit, në mënyrë që të shihni se si duken imazhet dhe përshkrimet mbi të cilat është trajnuar modeli juaj. Çdo imazh shfaqet si një JPEG me 128x128 piksel, me përshkrimin e printuar pranë imazhit në të djathtë.
def render_inline(image, resize=(128, 128)):
"""Convert image into inline html."""
image = Image.fromarray(image)
image.resize(resize)
with io.BytesIO() as buffer:
image.save(buffer, format='jpeg')
image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
return f"data:image/jpeg;base64,{image_b64}"
def render_example(image, caption):
image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]
return f"""
<div style="display: inline-flex; align-items: center; justify-content: center;">
<img style="width:128px; height:128px;" src="{render_inline(image, resize=(64,64))}" />
<p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
</div>
"""
html_out = ""
for idx, example in zip(range(8), train_data_iterator()):
caption = postprocess_tokens(example["text"]) # detokenize model input.
caption = caption[len("caption en\n"):] # strip prefix
html_out += render_example(example["image"], caption)
print("Training examples")
display(HTML(html_out))
Training examples
Përcaktoni ciklet e trajnimit dhe vlerësimit
Përcaktoni ciklin e trajnimit për të trajnuar modelin në të dhënat e dhëna dhe ciklin e vlerësimit për të parë të gjithë shembujt në të dhënat e validimit dhe për të bërë parashikimet e tij.
Përcaktimi i ciklit të trajnimit
Funksioni update_fn përcakton hapin e trajnimit. Gjatë hapit të trajnimit, humbja për shembull llogaritet dhe zbritja stokastike e gradientit (SGD) zbatohet në parametrat e trajnueshëm.
Kujtoni se më herët në fletore, keni përfshirë flamuj në funksionin preprocess_tokens që përfshinin mask_loss . Do ta përdorni flamurin mask_loss këtu për të përjashtuar prefiksin dhe tokenët e mbushur nga humbja. Pa të, llogaritja e humbjes do të jetë e shtrembër. Gjithashtu duhet të normalizoni çdo shembull, pasi secili prej tyre ka një numër të ndryshëm tokenësh. Pasi prefiksi dhe tokenët e mbushur të jenë përjashtuar dhe shembujt të jenë normalizuar, mund të llogaritni humbjen për shembull.
Hapi i trajnimit përfshin gjithashtu një funksion për të aplikuar një SGD për të optimizuar trajnimin.
Përcaktimi i ciklit të vlerësimit
Funksioni make_predictions është cikli juaj i vlerësimit. Cikli i vlerësimit është mjaft i drejtpërdrejtë me një ndryshim të dukshëm. Nëse e mbani mend nga fillimi i fletores, keni vetëm 90 shembuj në grupin tuaj të të dhënave të trajnimit. Ky është një numër shumë i vogël shembujsh trajnimi dhe modeli juaj përfundon duke mos pasur shembuj të mjaftueshëm për madhësinë e grupit kur ekzekutoni trajnimin. Kjo do të thotë që në ciklin e vlerësimit, duhet ta mbushni grupin duke përsëritur shembujt.
Për t'u siguruar që cikli juaj i vlerësimit numëron vetëm shembujt aktualë dhe jo shembujt e mbushur, duhet të aplikoni një maskë te shembujt e mbushur që i përjashton ato nga rezultati.
# The main update_fn using a simple stochastic gradient descent (SGD).
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]
def loss_fn(params):
text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
logp = jax.nn.log_softmax(text_logits, axis=-1)
# The model takes as input txts[:, :-1] but the loss is defined as predicting
# next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
# are part of the loss (e.g. prefix and padded tokens are not included).
mask_loss = batch["mask_loss"][:, 1:]
targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])
# Compute the loss per example. i.e. the mean of per token pplx.
# Since each example has a different number of tokens, normalize it.
token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.
example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.
example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.
# batch_loss: mean of per example loss.
return jnp.mean(example_loss)
loss, grads = jax.value_and_grad(loss_fn)(params)
# Apply gradients to trainable params using SGD.
def apply_grad(param, gradient, trainable):
if not trainable: return param
return param - learning_rate * gradient
params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)
return params, loss
# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
batch_size=4, seqlen=SEQLEN, sampler="greedy"):
outputs = []
while True:
# Construct a list of examples in the batch.
examples = []
try:
for _ in range(batch_size):
examples.append(next(data_iterator))
examples[-1]["_mask"] = np.array(True) # Indicates true example.
except StopIteration:
if len(examples) == 0:
return outputs
# Not enough examples to complete a batch. Pad by repeating last example.
while len(examples) % batch_size:
examples.append(dict(examples[-1]))
examples[-1]["_mask"] = np.array(False) # Indicates padding example.
# Convert list of examples into a dict of np.arrays and load onto devices.
batch = jax.tree.map(lambda *x: np.stack(x), *examples)
batch = big_vision.utils.reshard(batch, data_sharding)
# Make model predictions
tokens = decode({"params": params}, batch=batch,
max_decode_len=seqlen, sampler=sampler)
# Fetch model predictions to device and detokenize.
tokens, mask = jax.device_get((tokens, batch["_mask"]))
tokens = tokens[mask] # remove padding examples.
responses = [postprocess_tokens(t) for t in tokens]
# Append to html output.
for example, response in zip(examples, responses):
outputs.append((example["image"], response))
if num_examples and len(outputs) >= num_examples:
return outputs
Akordoni modelin
Tani që i keni konfiguruar të gjitha dhe i keni parë të dhënat e trajnimit, është koha për ta akorduar më në fund modelin. Kodi më poshtë ekzekuton ciklin e trajnimit për modelin për 64 hapa dhe printon shkallën e të mësuarit ( lr në rezultatin e printuar) dhe shkallën e humbjes për secilin hap.
Çdo 16 hapa, modeli printon parashikimet e tij në atë hap të trajnimit. Ky kod printon parashikime për të njëjtin grup imazhesh në mënyrë që të shihni se si aftësia e modelit për të parashikuar përshkrimet përmirësohet me kalimin e kohës.
Në hapat e mëparshëm të trajnimit, ka të ngjarë të ketë probleme me përshkrimet, të tilla si fjali të përsëritura ndërsa modeli ngec në ciklin e tij parashikues ose fjali të papërfunduara. Parashikimet e modelit bëhen gjithnjë e më të sakta ndërsa trajnimi përparon. Deri në hapin 64, parashikimet e modelit duhet të ngjajnë shumë me përshkrimet e ofruara nga të dhënat e trajnimit.
Ky proces zgjat rreth 15 minuta për t'u përfunduar në TPU-të T4.
# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.
#
%%time
BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03
TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4
train_data_it = train_data_iterator()
sched_fn = big_vision.utils.create_learning_rate_schedule(
total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
decay_type="cosine", warmup_percent=0.10)
for step in range(1, TRAIN_STEPS+1):
# Make list of N training examples.
examples = [next(train_data_it) for _ in range(BATCH_SIZE)]
# Convert list of examples into a dict of np.arrays and load onto devices.
batch = jax.tree.map(lambda *x: np.stack(x), *examples)
batch = big_vision.utils.reshard(batch, data_sharding)
# Training step and report training loss
learning_rate = sched_fn(step)
params, loss = update_fn(params, batch, learning_rate)
loss = jax.device_get(loss)
print(f"step: {step:2d}/{TRAIN_STEPS:2d} lr: {learning_rate:.5f} loss: {loss:.4f}")
if (step % EVAL_STEPS) == 0:
print(f"Model predictions at step {step}")
html_out = ""
for image, caption in make_predictions(
validation_data_iterator(), num_examples=4, batch_size=4):
html_out += render_example(image, caption)
display(HTML(html_out))
step: 1/64 lr: 0.00500 loss: 3.6567 step: 2/64 lr: 0.01000 loss: 1.9762 step: 3/64 lr: 0.01500 loss: 1.6299 step: 4/64 lr: 0.02000 loss: 1.5651 step: 5/64 lr: 0.02500 loss: 1.9813 step: 6/64 lr: 0.03000 loss: 1.9996 step: 7/64 lr: 0.02998 loss: 1.8595 step: 8/64 lr: 0.02992 loss: 1.6479 step: 9/64 lr: 0.02981 loss: 1.3693 step: 10/64 lr: 0.02966 loss: 1.3423 step: 11/64 lr: 0.02947 loss: 1.2122 step: 12/64 lr: 0.02924 loss: 1.0602 step: 13/64 lr: 0.02897 loss: 1.1314 step: 14/64 lr: 0.02866 loss: 1.2612 step: 15/64 lr: 0.02831 loss: 1.0132 step: 16/64 lr: 0.02792 loss: 1.2126 Model predictions at step 16
step: 17/64 lr: 0.02750 loss: 1.0986 step: 18/64 lr: 0.02704 loss: 0.9461 step: 19/64 lr: 0.02655 loss: 1.2098 step: 20/64 lr: 0.02602 loss: 1.0513 step: 21/64 lr: 0.02546 loss: 1.0979 step: 22/64 lr: 0.02488 loss: 0.9739 step: 23/64 lr: 0.02426 loss: 0.9589 step: 24/64 lr: 0.02362 loss: 0.7053 step: 25/64 lr: 0.02296 loss: 0.7347 step: 26/64 lr: 0.02227 loss: 0.6990 step: 27/64 lr: 0.02156 loss: 0.6736 step: 28/64 lr: 0.02083 loss: 0.6642 step: 29/64 lr: 0.02009 loss: 0.6908 step: 30/64 lr: 0.01933 loss: 0.7257 step: 31/64 lr: 0.01856 loss: 0.6902 step: 32/64 lr: 0.01778 loss: 0.7054 Model predictions at step 32
step: 33/64 lr: 0.01699 loss: 0.7709 step: 34/64 lr: 0.01620 loss: 0.6653 step: 35/64 lr: 0.01540 loss: 0.3811 step: 36/64 lr: 0.01460 loss: 0.3104 step: 37/64 lr: 0.01380 loss: 0.4042 step: 38/64 lr: 0.01301 loss: 0.3904 step: 39/64 lr: 0.01222 loss: 0.3339 step: 40/64 lr: 0.01144 loss: 0.4156 step: 41/64 lr: 0.01067 loss: 0.4085 step: 42/64 lr: 0.00991 loss: 0.3083 step: 43/64 lr: 0.00917 loss: 0.3757 step: 44/64 lr: 0.00844 loss: 0.3813 step: 45/64 lr: 0.00773 loss: 0.3381 step: 46/64 lr: 0.00704 loss: 0.2057 step: 47/64 lr: 0.00638 loss: 0.1287 step: 48/64 lr: 0.00574 loss: 0.1711 Model predictions at step 48
step: 49/64 lr: 0.00512 loss: 0.1183 step: 50/64 lr: 0.00454 loss: 0.1154 step: 51/64 lr: 0.00398 loss: 0.1967 step: 52/64 lr: 0.00345 loss: 0.1497 step: 53/64 lr: 0.00296 loss: 0.1688 step: 54/64 lr: 0.00250 loss: 0.1878 step: 55/64 lr: 0.00208 loss: 0.1865 step: 56/64 lr: 0.00169 loss: 0.1655 step: 57/64 lr: 0.00134 loss: 0.0911 step: 58/64 lr: 0.00103 loss: 0.1836 step: 59/64 lr: 0.00076 loss: 0.1242 step: 60/64 lr: 0.00053 loss: 0.0814 step: 61/64 lr: 0.00034 loss: 0.0866 step: 62/64 lr: 0.00019 loss: 0.1295 step: 63/64 lr: 0.00008 loss: 0.1053 step: 64/64 lr: 0.00002 loss: 0.0730 Model predictions at step 64
CPU times: user 2min 18s, sys: 8.98 s, total: 2min 27s Wall time: 15min 45s
Prodhimi
Të dhënat e validimit për këtë fletore shënimesh përbëhen vetëm nga 10 imazhe. Në kodin normal, ka të ngjarë të keni shumë më tepër pika të dhënash për validim, por për këtë fletore shënimesh, ekzekutoni kodin e mëposhtëm për të gjeneruar përshkrime për të 10 imazhet. Pas akordimit të modelit, këto përshkrime duhet të jenë shumë të ngjashme në formë dhe mbulim përmbajtjeje me përshkrimet e përfshira me të dhënat e trajnimit që keni parë më parë në këtë fletore shënimesh.
Ekzekutoni kodin më poshtë për të gjeneruar përshkrime për grupin e të dhënave të validimit.
# The validation data consists of 10 images in a different domain than training
# data.
%%time
print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
html_out += render_example(image, caption)
display(HTML(html_out))
Model predictions
CPU times: user 1.87 s, sys: 283 ms, total: 2.15 s Wall time: 39.3 s
Ekzekuto në Google Colab
Shiko burimin në GitHub