تنظیم دقیق PaliGemma با JAX و Flax

مشاهده در ai.google.dev در گوگل کولب اجرا کنید دویدن در کاگل باز کردن در Vertex AI مشاهده منبع در گیت‌هاب

این دفترچه یادداشت نحوه تنظیم دقیق PaliGemma را در یک کار زبان بینایی با JAX نشان می‌دهد. تنظیم دقیق فرآیندی است که می‌تواند عملکرد مدل شما را در وظایف خاص بهبود بخشد یا به مدل کمک کند تا در صورت کافی نبودن دستورالعمل‌ها و داشتن مجموعه‌ای از مثال‌ها که خروجی‌های مورد نظر شما را نشان می‌دهند، به الزامات خروجی خاص پایبند باشد. مدل‌های مبتنی بر Gemma مانند PaliGemma برای تولید نتایج مورد انتظار نیاز به تنظیم دقیق دارند.

آنچه در این دفترچه یادداشت است

این دفترچه یادداشت از پیاده‌سازی مرجع مدل از big_vision استفاده می‌کند و نحوه‌ی انجام موارد زیر را نشان می‌دهد:

  • وابستگی‌ها را نصب کنید و ایست بازرسی مدل PaliGemma و داده‌های آموزشی را دانلود کنید
  • مدل را روی دستگاه‌های GPU بارگذاری کنید
  • آماده‌سازی ورودی‌های مدل برای آموزش و استنتاج
  • مدل را دقیق تنظیم کنید
  • خروجی را بررسی کنید

داده‌های آموزشی این دفترچه یادداشت شامل ۹۰ جفت تصویر و توضیحات طولانی است که آنها را توصیف می‌کند. برای اینکه بتوان آن را در یک محیط برنامه‌نویسی T4 اجرا کرد، فقط لایه‌های توجه مدل زبان را تنظیم دقیق کرده و سایر پارامترها را ثابت نگه می‌دارید.

این مثال فقط برای اهداف یادگیری است. در یک مورد استفاده واقعی، مقدار داده‌ها، پارامترهای قابل آموزش، مراحل آموزش و ابرپارامترها و نتایج به‌دست‌آمده می‌توانند به‌طور قابل‌توجهی متفاوت باشند.

قبل از اینکه شروع کنی

قبل از مطالعه‌ی این دفترچه، باید با کد پایتون و همچنین نحوه‌ی آموزش مدل‌های زبانی بزرگ (LLM) آشنا باشید. نیازی به آشنایی با JAX نیست، اما دانش اولیه در مورد JAX (یا فناوری‌های مشابه مانند Keras) هنگام خواندن کد نمونه مفید است.

راه‌اندازی

بخش‌های زیر مراحل اولیه برای استفاده از مدل PaliGemma در یک نوت‌بوک، از جمله دسترسی به مدل، دریافت کلید API و پیکربندی زمان اجرای نوت‌بوک را توضیح می‌دهند.

به PaliGemma دسترسی پیدا کنید

قبل از اولین استفاده از PaliGemma، باید با انجام مراحل زیر، از طریق Kaggle درخواست دسترسی به مدل را بدهید:

  1. وارد حساب کاگل شوید، یا اگر از قبل حساب کاگل ندارید، یک حساب کاگل جدید ایجاد کنید.
  2. به کارت مدل PaliGemma بروید و روی درخواست دسترسی کلیک کنید.
  3. فرم رضایت‌نامه را تکمیل کنید و شرایط و ضوابط را بپذیرید.

کلید API خود را پیکربندی کنید

برای استفاده از PaliGemma، باید نام کاربری Kaggle و یک کلید API Kaggle خود را ارائه دهید.

برای تولید کلید API کاگل، صفحه تنظیمات خود را در کاگل باز کنید و روی ایجاد توکن جدید کلیک کنید. این کار باعث دانلود فایل kaggle.json حاوی اطلاعات احراز هویت API شما می‌شود.

سپس، در Colab، در پنل سمت چپ، گزینه Secrets (🔑) را انتخاب کنید و نام کاربری Kaggle و کلید API Kaggle خود را اضافه کنید. نام کاربری خود را با نام KAGGLE_USERNAME و کلید API خود را با نام KAGGLE_KEY ذخیره کنید.

زمان اجرا را انتخاب کنید

برای تکمیل این آموزش، به یک محیط اجرای Colab با منابع کافی برای اجرای مدل PaliGemma نیاز دارید. در این حالت، می‌توانید از یک پردازنده گرافیکی T4 استفاده کنید:

  1. در سمت راست بالای پنجره Colab، روی منوی کشویی ▾ (گزینه‌های اتصال اضافی) کلیک کنید.
  2. تغییر نوع زمان اجرا را انتخاب کنید.
  3. در قسمت شتاب‌دهنده سخت‌افزاری ، T4 GPU را انتخاب کنید.

نصب بسته‌های پایتون

برای نصب KaggleHub، سلول زیر را اجرا کنید.

pip install -U -q kagglehub

تنظیم متغیرهای محیطی

متغیرهای محیطی و ورود به سیستم Kaggle را تنظیم کنید.

import os
# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid out-of-memory due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

# This will prompt you to enter your Kaggle API token
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
Kaggle credentials set.
Kaggle credentials successfully validated.

مخزن big_vision را از گیت‌هاب روی نوت‌بوک Colab خود دانلود کنید و با اجرای کد زیر، وابستگی‌های مربوط به big_vision را نصب کنید.

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 76.7/76.7 kB 2.8 MB/s eta 0:00:00

وارد کردن JAX و سایر وابستگی‌ها

JAX و سایر وابستگی‌های مورد نیاز برای PaliGemma، مانند TensorFlow و NumPy را وارد کنید.

import base64
import functools
import html
import io
import os
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections

import tensorflow as tf
import sentencepiece

from IPython.core.display import display, HTML
from PIL import Image

# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

backend = jax.extend.backend.get_backend()
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices:  {jax.device_count()}")
JAX version:  0.7.2
JAX platform: gpu
JAX devices:  1

مدل را دانلود و پیکربندی کنید

در این مرحله، شما نقطه کنترل مدل را دانلود و پیکربندی می‌کنید تا بتوانید بعداً آن را به طور دقیق تنظیم کنید. این مرحله به شما نشان می‌دهد که چگونه پارامترهای مدل را به حافظه TPU منتقل کنید، که برای تنظیم دقیق مدل‌ها در دستگاه‌هایی با منابع محدود مفید است.

دانلود نمونه چک پوینت

PaliGemma شامل چندین مدل مختلف است. برای این آموزش، از مدل وزن پایه JAX/FLAX PaliGemma 3B استفاده خواهید کرد.

با اجرای کد زیر، نقطه بررسی مدل را از Kaggle دانلود کنید. این فرآیند چند دقیقه طول می‌کشد تا تکمیل شود.

import os
import kagglehub

# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224"  # Path to fetch from Kaggle.

# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"

if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
  print("Downloading the dataset...")
  !gsutil -m -q cp -n -r gs://longcap100/ .
  print(f"Data path: {DATA_DIR}")
Downloading the checkpoint from Kaggle, this could take a few minutes....
Downloading to /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz...
100%|██████████| 5.65G/5.65G [00:54<00:00, 112MB/s]
Model path: /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz
Downloading the model tokenizer...
Copying gs://big_vision/paligemma_tokenizer.model...

- [1 files][  4.1 MiB/  4.1 MiB]                                                
Operation completed over 1 objects/4.1 MiB.                                      
Tokenizer path: ./paligemma_tokenizer.model
Downloading the dataset...
Data path: ./longcap100

پیکربندی مدل

وقت آن است که عملاً پیکربندی مدلی را که قرار است استفاده کنید، شروع کنید.

برای این نوت‌بوک، باید بتوانید مدل خود را روی یک پردازنده گرافیکی T4 قرار دهید. داشتن منابع محدود مانند محدودیت فضا به این معنی است که باید به نحوه پیکربندی مدل خود توجه داشته باشید.

اگر هر پارامتر را به دقت تنظیم کنید، مدل شما قادر به اجرا در محیط نوت‌بوک نخواهد بود. در نتیجه، در این بخش از نوت‌بوک، مدل خود را طوری پیکربندی می‌کنید که قابلیت ثابت کردن برخی از پارامترها را داشته باشد و فقط پارامترهایی را که واقعاً برای تنظیم دقیق مدل نیاز دارند تا نتایج دقیقی به شما ارائه دهد، تنظیم کند. در LLMها، زمانی گفته می‌شود پارامترها ثابت شده‌اند که دیگر به طور فعال برای آموزش مدل استفاده نمی‌شوند.

برای پیکربندی مدل خود، باید:

  • مقداردهی اولیه model_config به عنوان FrozenConfigDict تا بتوانید برخی از پارامترها را فریز کنید و میزان استفاده از حافظه را پایین نگه دارید.
  • با استفاده از model_config به عنوان پیکربندی‌های آن، یک نمونه از کلاس Model PaliGemma را مقداردهی اولیه کنید.
  • پارامترهای مدل را در RAM بارگذاری کنید
  • تعریف یک تابع decode برای نمونه‌برداری از خروجی‌های مدل

اجرای کامل این کد در این سلول حدود یک دقیقه طول می‌کشد.

# Define model

# IMPORTANT: Gemma-2 has a "final_logits_softcap" property. Set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

پارامترهای مدل را به حافظه GPU/TPU منتقل کنید

حالا باید پارامترهای مدل را به حافظه GPU/TPU منتقل کنید. ابتدا پارامترها را در GPUهای موجود تقسیم‌بندی کنید، سپس پارامترها را بارگذاری کنید. در اینجا، پارامترها را به صورت متوالی بارگذاری خواهید کرد. این فرآیند بیشتر از بارگذاری همزمان آنها طول می‌کشد، اما به رم بیشتری نسبت به آنچه در این نوت‌بوک موجود است نیاز دارد.

در نهایت، تمام پارامترها را چاپ کنید تا ببینید هر پارامتر به چه نوعی تبدیل شده است. پارامترهای ثابت (freeze parameters) به صورت float16 نگه داشته می‌شوند، در حالی که پارامترهای قابل آموزش به float32 تبدیل می‌شوند. وقتی لیست را بررسی می‌کنید، خواهید دید که بیشتر پارامترها ثابت شده‌اند و float16 هستند.

# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"):  return True
  if name.startswith("llm/"):              return False
  if name.startswith("img/"):              return False
  raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data"))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  # Cast others to float16, since some GPUs don't support bf16.
  return jax.tree.map(lambda p, m: p.astype(jnp.float32)
                      if m else p.astype(jnp.float16),
                      params, trainable)

# Loading all params in simultaneous - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default.
# Instead, do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)
== Model params == 
img/Transformer/encoder_norm/bias                                                (1152,)                float16
img/Transformer/encoder_norm/scale                                               (1152,)                float16
img/Transformer/encoderblock/LayerNorm_0/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_0/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias                             (27, 4304)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel                           (27, 1152, 4304)       float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias                             (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel                           (27, 4304, 1152)       float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias             (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel           (27, 1152, 16, 72)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias             (27, 1152)             float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel           (27, 16, 72, 1152)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias           (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel         (27, 1152, 16, 72)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias           (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel         (27, 1152, 16, 72)     float16
img/embedding/bias                                                               (1152,)                float16
img/embedding/kernel                                                             (14, 14, 3, 1152)      float16
img/head/bias                                                                    (2304,)                float16
img/head/kernel                                                                  (1152, 2304)           float16
img/pos_embedding                                                                (1, 256, 1152)         float16
llm/embedder/input_embedding                                                     (257152, 2304)         float16
llm/final_norm/scale                                                             (2304,)                float16
llm/layers/attn/attn_vec_einsum/w                                                (26, 8, 256, 2304)     float32
llm/layers/attn/kv_einsum/w                                                      (26, 2, 4, 2304, 256)  float32
llm/layers/attn/q_einsum/w                                                       (26, 8, 2304, 256)     float32
llm/layers/mlp/gating_einsum                                                     (26, 2, 2304, 9216)    float16
llm/layers/mlp/linear                                                            (26, 9216, 2304)       float16
llm/layers/post_attention_norm/scale                                             (26, 2304)             float16
llm/layers/post_ffw_norm/scale                                                   (26, 2304)             float16
llm/layers/pre_attention_norm/scale                                              (26, 2304)             float16
llm/layers/pre_ffw_norm/scale                                                    (26, 2304)             float16

برای تنظیم مدل آماده شوید

اکنون که مدل شما پیکربندی شده است، می‌توانید آن را تنظیم کنید. در این مرحله، ورودی‌های مدل خود و همچنین تکرارکننده‌های آموزش و اعتبارسنجی را ایجاد خواهید کرد، مثال‌های آموزشی را مشاهده خواهید کرد و حلقه‌های آموزش و اعتبارسنجی را تعریف خواهید کرد.

ایجاد ورودی‌های مدل

مدل Checkpoint که شما استفاده می‌کنید، قبلاً روی تصاویری با نسبت ابعاد مختلف که به 224x224 پیکسل تغییر اندازه داده‌اند و برای مدیریت متون توکنیزه شده آموزش دیده است.

کد زیر سه تابع را تعریف می‌کند که در مرحله بعدی برای ایجاد ورودی‌های مدل از آنها استفاده خواهید کرد:

  • preprocess_image : داده‌های تصویر را نرمال‌سازی می‌کند. در این حالت، پیش‌پردازش، تصویر ورودی را به خاکستری تبدیل می‌کند، لایه آلفا را حذف می‌کند و تصویر ورودی را به اندازه مورد نیاز مدل برای ورودی‌های تصویر (۲۲۴x۲۲۴ پیکسل) تغییر اندازه می‌دهد.
  • preprocess_tokens : توکن‌ها را تقسیم می‌کند و پرچم‌هایی را برای مشخص کردن اینکه آیا یک توکن پیشوندی است یا پسوندی، اضافه می‌کند. این پرچم‌ها بعداً در کد، در طول مرحله آموزش و حلقه ارزیابی استفاده خواهند شد.
  • postprocess_tokens : هر توکنی که در و/یا بعد از توکن پایان توالی (EOS) باقی مانده باشد را حذف می‌کند و توکن‌های رمزگشایی شده باقی مانده را برمی‌گرداند.
def preprocess_image(image, size=224):
  # Model has been trained to handle images of different aspects ratios
  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
  # options are helpful to improve quality in some tasks.
  image = np.asarray(image)
  if image.ndim == 2:  # Convert image without last channel into greyscale.
    image = np.stack((image,)*3, axis=-1)
  image = image[..., :3]  # Remove alpha layer.
  assert image.shape[-1] == 3

  image = tf.constant(image)
  image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
  return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]

def preprocess_tokens(prefix, suffix=None, seqlen=None):
  # Model has been trained to handle tokenized text composed of a prefix with
  # full attention and a suffix with causal attention.
  separator = "\n"
  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

  if suffix:
    suffix = tokenizer.encode(suffix, add_eos=True)
    tokens += suffix
    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

  mask_input = [1] * len(tokens)    # 1 if it's a token, 0 if padding.
  if seqlen:
    padding = [0] * max(0, seqlen - len(tokens))
    tokens = tokens[:seqlen] + padding
    mask_ar = mask_ar[:seqlen] + padding
    mask_loss = mask_loss[:seqlen] + padding
    mask_input = mask_input[:seqlen] + padding

  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

def postprocess_tokens(tokens):
  tokens = tokens.tolist()  # np.array to list[int]
  try:  # Remove tokens at and after EOS if any.
    eos_pos = tokens.index(tokenizer.eos_id())
    tokens = tokens[:eos_pos]
  except ValueError:
    pass
  return tokenizer.decode(tokens)

ایجاد تکرارکننده‌های آموزش و اعتبارسنجی

دو تکرارکننده ایجاد کنید:

  • یک تکرارکننده‌ی آموزشی که به فرآیند آموزش اجازه می‌دهد داده‌ها را به صورت تکه‌تکه بررسی کند، نه اینکه همه را به طور همزمان پردازش کند.
    • این به شما امکان می‌دهد قبل از استفاده، برخی پیش‌پردازش‌ها را روی داده‌ها انجام دهید.
  • یک تکرارکننده اعتبارسنجی که به فرآیند آموزش اجازه می‌دهد تا روی مجموعه داده‌های اعتبارسنجی تکرار شود تا ببیند مدل تنظیم‌شده چقدر با نتایج ارائه‌شده همسو است.
SEQLEN = 128

train_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_train90.jsonl"),
    fopen_keys={"image": DATA_DIR})

val_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_val10.jsonl"),
    fopen_keys={"image": DATA_DIR})


def train_data_iterator():
  """Never ending iterator over training examples."""
  # Shuffle examples and repeat so one can train for many epochs.
  dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
  for example in dataset.as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    suffix = example["suffix"].decode().lower()
    tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_loss": np.asarray(mask_loss),
    }


def validation_data_iterator():
  """Single iterator over validation examples."""
  for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input),
    }

مشاهده نمونه‌های آموزشی

در این دفترچه، داده‌های آموزشی شامل ۹۰ تصویر است که با توضیحات مفصلی از آنچه در تصویر نشان داده شده است، جفت شده‌اند.

کد زیر مجموعه‌ای تصادفی از تصاویر را به همراه توضیحات آنها از مجموعه داده‌های آموزشی چاپ می‌کند تا بتوانید ببینید تصاویر و توضیحاتی که مدل شما بر اساس آنها آموزش دیده است، چگونه به نظر می‌رسند. هر تصویر به صورت JPEG با ابعاد ۱۲۸x۱۲۸ پیکسل نمایش داده می‌شود و توضیحات آن در کنار تصویر سمت راست چاپ می‌شود.

def render_inline(image, resize=(128, 128)):
  """Convert image into inline html."""
  image = Image.fromarray(image)
  image.resize(resize)
  with io.BytesIO() as buffer:
    image.save(buffer, format='jpeg')
    image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
    return f"data:image/jpeg;base64,{image_b64}"

def render_example(image, caption):
  image = ((image + 1)/2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]
  return f"""
    <div style="display: inline-flex; align-items: center; justify-content: center;">
        <img style="width:128px; height:128px;" src="{render_inline(image, resize=(64,64))}" />
        <p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
    </div>
    """

html_out = ""
for idx, example in zip(range(8), train_data_iterator()):
  caption = postprocess_tokens(example["text"])  # detokenize model input.
  caption = caption[len("caption en\n"):]        # strip prefix
  html_out += render_example(example["image"], caption)

print("Training examples")
display(HTML(html_out))
Training examples

حلقه‌های آموزش و ارزیابی را تعریف کنید

حلقه آموزش را برای آموزش مدل روی مجموعه داده ارائه شده و حلقه ارزیابی را برای بررسی تمام مثال‌های موجود در مجموعه داده اعتبارسنجی و انجام پیش‌بینی‌های آن تعریف کنید.

تعریف حلقه آموزشی

تابع update_fn مرحله آموزش را تعریف می‌کند. در طول مرحله آموزش، میزان خطا به ازای هر مثال محاسبه می‌شود و گرادیان نزولی تصادفی (SGD) بر پارامترهای قابل آموزش اعمال می‌شود.

به یاد بیاورید که قبلاً در دفترچه یادداشت، پرچم‌هایی را در تابع preprocess_tokens قرار دادید که شامل mask_loss نیز می‌شدند. در اینجا از پرچم mask_loss برای حذف توکن‌های پیشوند و پر شده از زیان استفاده خواهید کرد. بدون آن، محاسبه زیان اریب خواهد بود. همچنین باید هر مثال را نرمال‌سازی کنید، زیرا هر یک از آنها تعداد متفاوتی توکن دارند. پس از حذف توکن‌های پیشوند و پر شده و نرمال‌سازی مثال‌ها، می‌توانید زیان را به ازای هر مثال محاسبه کنید.

مرحله آموزش همچنین شامل تابعی برای اعمال SGD برای بهینه‌سازی آموزش است.

تعریف حلقه ارزیابی

تابع make_predictions حلقه ارزیابی شماست. حلقه ارزیابی نسبتاً سرراست است، با یک تغییر قابل توجه. اگر از ابتدای دفترچه یادداشت به یاد داشته باشید، فقط ۹۰ مثال در مجموعه داده‌های آموزشی خود دارید. این تعداد بسیار کمی از مثال‌های آموزشی است و مدل شما در نهایت هنگام اجرای آموزش، مثال‌های کافی برای اندازه دسته ندارد. این بدان معناست که در حلقه ارزیابی، باید دسته را با تکرار مثال‌ها پر کنید.

برای اطمینان از اینکه حلقه ارزیابی شما فقط مثال‌های واقعی را شمارش می‌کند و نه مثال‌های پر شده را، باید یک ماسک به مثال‌های پر شده اعمال کنید که آنها را از خروجی حذف می‌کند.

# The main update_fn using a simple stochastic gradient descent (SGD).
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
  imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]

  def loss_fn(params):
    text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
    logp = jax.nn.log_softmax(text_logits, axis=-1)

    # The model takes as input txts[:, :-1] but the loss is defined as predicting
    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
    # are part of the loss (e.g. prefix and padded tokens are not included).
    mask_loss = batch["mask_loss"][:, 1:]
    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])

    # Compute the loss per example. i.e. the mean of per token pplx.
    # Since each example has a different number of tokens, normalize it.
    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.
    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.
    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.

    # batch_loss: mean of per example loss.
    return jnp.mean(example_loss)

  loss, grads = jax.value_and_grad(loss_fn)(params)

  # Apply gradients to trainable params using SGD.
  def apply_grad(param, gradient, trainable):
    if not trainable: return param
    return param - learning_rate * gradient

  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)

  return params, loss

# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
  outputs = []
  while True:
    # Construct a list of examples in the batch.
    examples = []
    try:
      for _ in range(batch_size):
        examples.append(next(data_iterator))
        examples[-1]["_mask"] = np.array(True)  # Indicates true example.
    except StopIteration:
      if len(examples) == 0:
        return outputs

    # Not enough examples to complete a batch. Pad by repeating last example.
    while len(examples) % batch_size:
      examples.append(dict(examples[-1]))
      examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Make model predictions
    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch model predictions to device and detokenize.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    tokens = tokens[mask]  # remove padding examples.
    responses = [postprocess_tokens(t) for t in tokens]

    # Append to html output.
    for example, response in zip(examples, responses):
      outputs.append((example["image"], response))
      if num_examples and len(outputs) >= num_examples:
        return outputs

مدل را تنظیم کنید

حالا که همه چیز را تنظیم کرده‌اید و نگاهی به داده‌های آموزشی انداخته‌اید، وقت آن رسیده که مدل را تنظیم کنید. کد زیر حلقه آموزشی را برای مدل به مدت ۶۴ مرحله اجرا می‌کند و نرخ یادگیری ( lr در خروجی چاپ شده) و نرخ تلفات را برای هر مرحله چاپ می‌کند.

هر ۱۶ مرحله، مدل پیش‌بینی‌های خود را در آن مرحله از آموزش چاپ می‌کند. این کد پیش‌بینی‌ها را برای همان مجموعه تصاویر چاپ می‌کند تا بتوانید ببینید که توانایی مدل در پیش‌بینی توصیفات با گذشت زمان بهبود می‌یابد.

در مراحل اولیه آموزش، احتمالاً مشکلاتی در توصیفات وجود دارد، مانند جملات تکراری که مدل در حلقه پیش‌بینی خود گیر می‌کند یا جملات ناتمام. پیش‌بینی‌های مدل با پیشرفت آموزش به طور پیوسته دقیق‌تر می‌شوند. تا مرحله ۶۴، پیش‌بینی‌های مدل باید شباهت زیادی به توصیفات ارائه شده توسط داده‌های آموزشی داشته باشند.

این فرآیند روی TPU های T4 حدود ۱۵ دقیقه طول می‌کشد.

# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.
#
%%time

BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03

TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4

train_data_it = train_data_iterator()

sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)

for step in range(1, TRAIN_STEPS+1):
  # Make list of N training examples.
  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]

  # Convert list of examples into a dict of np.arrays and load onto devices.
  batch = jax.tree.map(lambda *x: np.stack(x), *examples)
  batch = big_vision.utils.reshard(batch, data_sharding)

  # Training step and report training loss
  learning_rate = sched_fn(step)
  params, loss = update_fn(params, batch, learning_rate)

  loss = jax.device_get(loss)
  print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}")

  if (step % EVAL_STEPS) == 0:
    print(f"Model predictions at step {step}")
    html_out = ""
    for image, caption in make_predictions(
        validation_data_iterator(), num_examples=4, batch_size=4):
      html_out += render_example(image, caption)
    display(HTML(html_out))
step:  1/64   lr: 0.00500   loss: 3.6567
step:  2/64   lr: 0.01000   loss: 1.9762
step:  3/64   lr: 0.01500   loss: 1.6299
step:  4/64   lr: 0.02000   loss: 1.5651
step:  5/64   lr: 0.02500   loss: 1.9813
step:  6/64   lr: 0.03000   loss: 1.9996
step:  7/64   lr: 0.02998   loss: 1.8595
step:  8/64   lr: 0.02992   loss: 1.6479
step:  9/64   lr: 0.02981   loss: 1.3693
step: 10/64   lr: 0.02966   loss: 1.3423
step: 11/64   lr: 0.02947   loss: 1.2122
step: 12/64   lr: 0.02924   loss: 1.0602
step: 13/64   lr: 0.02897   loss: 1.1314
step: 14/64   lr: 0.02866   loss: 1.2612
step: 15/64   lr: 0.02831   loss: 1.0132
step: 16/64   lr: 0.02792   loss: 1.2126
Model predictions at step 16
step: 17/64   lr: 0.02750   loss: 1.0986
step: 18/64   lr: 0.02704   loss: 0.9461
step: 19/64   lr: 0.02655   loss: 1.2098
step: 20/64   lr: 0.02602   loss: 1.0513
step: 21/64   lr: 0.02546   loss: 1.0979
step: 22/64   lr: 0.02488   loss: 0.9739
step: 23/64   lr: 0.02426   loss: 0.9589
step: 24/64   lr: 0.02362   loss: 0.7053
step: 25/64   lr: 0.02296   loss: 0.7347
step: 26/64   lr: 0.02227   loss: 0.6990
step: 27/64   lr: 0.02156   loss: 0.6736
step: 28/64   lr: 0.02083   loss: 0.6642
step: 29/64   lr: 0.02009   loss: 0.6908
step: 30/64   lr: 0.01933   loss: 0.7257
step: 31/64   lr: 0.01856   loss: 0.6902
step: 32/64   lr: 0.01778   loss: 0.7054
Model predictions at step 32
step: 33/64   lr: 0.01699   loss: 0.7709
step: 34/64   lr: 0.01620   loss: 0.6653
step: 35/64   lr: 0.01540   loss: 0.3811
step: 36/64   lr: 0.01460   loss: 0.3104
step: 37/64   lr: 0.01380   loss: 0.4042
step: 38/64   lr: 0.01301   loss: 0.3904
step: 39/64   lr: 0.01222   loss: 0.3339
step: 40/64   lr: 0.01144   loss: 0.4156
step: 41/64   lr: 0.01067   loss: 0.4085
step: 42/64   lr: 0.00991   loss: 0.3083
step: 43/64   lr: 0.00917   loss: 0.3757
step: 44/64   lr: 0.00844   loss: 0.3813
step: 45/64   lr: 0.00773   loss: 0.3381
step: 46/64   lr: 0.00704   loss: 0.2057
step: 47/64   lr: 0.00638   loss: 0.1287
step: 48/64   lr: 0.00574   loss: 0.1711
Model predictions at step 48
step: 49/64   lr: 0.00512   loss: 0.1183
step: 50/64   lr: 0.00454   loss: 0.1154
step: 51/64   lr: 0.00398   loss: 0.1967
step: 52/64   lr: 0.00345   loss: 0.1497
step: 53/64   lr: 0.00296   loss: 0.1688
step: 54/64   lr: 0.00250   loss: 0.1878
step: 55/64   lr: 0.00208   loss: 0.1865
step: 56/64   lr: 0.00169   loss: 0.1655
step: 57/64   lr: 0.00134   loss: 0.0911
step: 58/64   lr: 0.00103   loss: 0.1836
step: 59/64   lr: 0.00076   loss: 0.1242
step: 60/64   lr: 0.00053   loss: 0.0814
step: 61/64   lr: 0.00034   loss: 0.0866
step: 62/64   lr: 0.00019   loss: 0.1295
step: 63/64   lr: 0.00008   loss: 0.1053
step: 64/64   lr: 0.00002   loss: 0.0730
Model predictions at step 64
CPU times: user 2min 18s, sys: 8.98 s, total: 2min 27s
Wall time: 15min 45s

خروجی

داده‌های اعتبارسنجی این دفترچه یادداشت فقط شامل ۱۰ تصویر است. در کد معمولی، احتمالاً نقاط داده بسیار بیشتری برای اعتبارسنجی خواهید داشت، اما برای این دفترچه یادداشت، کد زیر را اجرا کنید تا توضیحاتی برای هر ۱۰ تصویر ایجاد شود. پس از تنظیم مدل، این توضیحات باید از نظر شکل و پوشش محتوا بسیار شبیه به توضیحات موجود در داده‌های آموزشی باشند که قبلاً در این دفترچه یادداشت به آنها نگاه کردید.

کد زیر را اجرا کنید تا توضیحاتی برای مجموعه داده‌های اعتبارسنجی ایجاد شود.

# The validation data consists of 10 images in a different domain than training
# data.
%%time

print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
  html_out += render_example(image, caption)
display(HTML(html_out))
Model predictions
CPU times: user 1.87 s, sys: 283 ms, total: 2.15 s
Wall time: 39.3 s