| | در گوگل کولب اجرا کنید | | | مشاهده منبع در گیتهاب |
این دفترچه یادداشت نحوه تنظیم دقیق PaliGemma را در یک کار زبان بینایی با JAX نشان میدهد. تنظیم دقیق فرآیندی است که میتواند عملکرد مدل شما را در وظایف خاص بهبود بخشد یا به مدل کمک کند تا در صورت کافی نبودن دستورالعملها و داشتن مجموعهای از مثالها که خروجیهای مورد نظر شما را نشان میدهند، به الزامات خروجی خاص پایبند باشد. مدلهای مبتنی بر Gemma مانند PaliGemma برای تولید نتایج مورد انتظار نیاز به تنظیم دقیق دارند.
آنچه در این دفترچه یادداشت است
این دفترچه یادداشت از پیادهسازی مرجع مدل از big_vision استفاده میکند و نحوهی انجام موارد زیر را نشان میدهد:
- وابستگیها را نصب کنید و ایست بازرسی مدل PaliGemma و دادههای آموزشی را دانلود کنید
- مدل را روی دستگاههای GPU بارگذاری کنید
- آمادهسازی ورودیهای مدل برای آموزش و استنتاج
- مدل را دقیق تنظیم کنید
- خروجی را بررسی کنید
دادههای آموزشی این دفترچه یادداشت شامل ۹۰ جفت تصویر و توضیحات طولانی است که آنها را توصیف میکند. برای اینکه بتوان آن را در یک محیط برنامهنویسی T4 اجرا کرد، فقط لایههای توجه مدل زبان را تنظیم دقیق کرده و سایر پارامترها را ثابت نگه میدارید.
این مثال فقط برای اهداف یادگیری است. در یک مورد استفاده واقعی، مقدار دادهها، پارامترهای قابل آموزش، مراحل آموزش و ابرپارامترها و نتایج بهدستآمده میتوانند بهطور قابلتوجهی متفاوت باشند.
قبل از اینکه شروع کنی
قبل از مطالعهی این دفترچه، باید با کد پایتون و همچنین نحوهی آموزش مدلهای زبانی بزرگ (LLM) آشنا باشید. نیازی به آشنایی با JAX نیست، اما دانش اولیه در مورد JAX (یا فناوریهای مشابه مانند Keras) هنگام خواندن کد نمونه مفید است.
راهاندازی
بخشهای زیر مراحل اولیه برای استفاده از مدل PaliGemma در یک نوتبوک، از جمله دسترسی به مدل، دریافت کلید API و پیکربندی زمان اجرای نوتبوک را توضیح میدهند.
به PaliGemma دسترسی پیدا کنید
قبل از اولین استفاده از PaliGemma، باید با انجام مراحل زیر، از طریق Kaggle درخواست دسترسی به مدل را بدهید:
- وارد حساب کاگل شوید، یا اگر از قبل حساب کاگل ندارید، یک حساب کاگل جدید ایجاد کنید.
- به کارت مدل PaliGemma بروید و روی درخواست دسترسی کلیک کنید.
- فرم رضایتنامه را تکمیل کنید و شرایط و ضوابط را بپذیرید.
کلید API خود را پیکربندی کنید
برای استفاده از PaliGemma، باید نام کاربری Kaggle و یک کلید API Kaggle خود را ارائه دهید.
برای تولید کلید API کاگل، صفحه تنظیمات خود را در کاگل باز کنید و روی ایجاد توکن جدید کلیک کنید. این کار باعث دانلود فایل kaggle.json حاوی اطلاعات احراز هویت API شما میشود.
سپس، در Colab، در پنل سمت چپ، گزینه Secrets (🔑) را انتخاب کنید و نام کاربری Kaggle و کلید API Kaggle خود را اضافه کنید. نام کاربری خود را با نام KAGGLE_USERNAME و کلید API خود را با نام KAGGLE_KEY ذخیره کنید.
زمان اجرا را انتخاب کنید
برای تکمیل این آموزش، به یک محیط اجرای Colab با منابع کافی برای اجرای مدل PaliGemma نیاز دارید. در این حالت، میتوانید از یک پردازنده گرافیکی T4 استفاده کنید:
- در سمت راست بالای پنجره Colab، روی منوی کشویی ▾ (گزینههای اتصال اضافی) کلیک کنید.
- تغییر نوع زمان اجرا را انتخاب کنید.
- در قسمت شتابدهنده سختافزاری ، T4 GPU را انتخاب کنید.
نصب بستههای پایتون
برای نصب KaggleHub، سلول زیر را اجرا کنید.
pip install -U -q kagglehubتنظیم متغیرهای محیطی
متغیرهای محیطی و ورود به سیستم Kaggle را تنظیم کنید.
import os
# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid out-of-memory due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
# This will prompt you to enter your Kaggle API token
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle… Kaggle credentials set. Kaggle credentials successfully validated.
مخزن big_vision را دریافت کنید و وابستگیهای مرتبط را نصب کنید
مخزن big_vision را از گیتهاب روی نوتبوک Colab خود دانلود کنید و با اجرای کد زیر، وابستگیهای مربوط به big_vision را نصب کنید.
import os
import sys
# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
raise "It seems you are using Colab with remote TPUs which is not supported."
# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
!git clone --quiet --branch=main --depth=1 \
https://github.com/google-research/big_vision big_vision_repo
# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
sys.path.append("big_vision_repo")
# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 76.7/76.7 kB 2.8 MB/s eta 0:00:00
وارد کردن JAX و سایر وابستگیها
JAX و سایر وابستگیهای مورد نیاز برای PaliGemma، مانند TensorFlow و NumPy را وارد کنید.
import base64
import functools
import html
import io
import os
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from IPython.core.display import display, HTML
from PIL import Image
# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
backend = jax.extend.backend.get_backend()
print(f"JAX version: {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices: {jax.device_count()}")
JAX version: 0.7.2 JAX platform: gpu JAX devices: 1
مدل را دانلود و پیکربندی کنید
در این مرحله، شما نقطه کنترل مدل را دانلود و پیکربندی میکنید تا بتوانید بعداً آن را به طور دقیق تنظیم کنید. این مرحله به شما نشان میدهد که چگونه پارامترهای مدل را به حافظه TPU منتقل کنید، که برای تنظیم دقیق مدلها در دستگاههایی با منابع محدود مفید است.
دانلود نمونه چک پوینت
PaliGemma شامل چندین مدل مختلف است. برای این آموزش، از مدل وزن پایه JAX/FLAX PaliGemma 3B استفاده خواهید کرد.
با اجرای کد زیر، نقطه بررسی مدل را از Kaggle دانلود کنید. این فرآیند چند دقیقه طول میکشد تا تکمیل شود.
import os
import kagglehub
# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224" # Path to fetch from Kaggle.
# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"
if not os.path.exists(MODEL_PATH):
print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
print(f"Model path: {MODEL_PATH}")
TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
print("Downloading the model tokenizer...")
!gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
print(f"Tokenizer path: {TOKENIZER_PATH}")
DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
print("Downloading the dataset...")
!gsutil -m -q cp -n -r gs://longcap100/ .
print(f"Data path: {DATA_DIR}")
Downloading the checkpoint from Kaggle, this could take a few minutes.... Downloading to /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz... 100%|██████████| 5.65G/5.65G [00:54<00:00, 112MB/s] Model path: /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz Downloading the model tokenizer... Copying gs://big_vision/paligemma_tokenizer.model... - [1 files][ 4.1 MiB/ 4.1 MiB] Operation completed over 1 objects/4.1 MiB. Tokenizer path: ./paligemma_tokenizer.model Downloading the dataset... Data path: ./longcap100
پیکربندی مدل
وقت آن است که عملاً پیکربندی مدلی را که قرار است استفاده کنید، شروع کنید.
برای این نوتبوک، باید بتوانید مدل خود را روی یک پردازنده گرافیکی T4 قرار دهید. داشتن منابع محدود مانند محدودیت فضا به این معنی است که باید به نحوه پیکربندی مدل خود توجه داشته باشید.
اگر هر پارامتر را به دقت تنظیم کنید، مدل شما قادر به اجرا در محیط نوتبوک نخواهد بود. در نتیجه، در این بخش از نوتبوک، مدل خود را طوری پیکربندی میکنید که قابلیت ثابت کردن برخی از پارامترها را داشته باشد و فقط پارامترهایی را که واقعاً برای تنظیم دقیق مدل نیاز دارند تا نتایج دقیقی به شما ارائه دهد، تنظیم کند. در LLMها، زمانی گفته میشود پارامترها ثابت شدهاند که دیگر به طور فعال برای آموزش مدل استفاده نمیشوند.
برای پیکربندی مدل خود، باید:
- مقداردهی اولیه
model_configبه عنوانFrozenConfigDictتا بتوانید برخی از پارامترها را فریز کنید و میزان استفاده از حافظه را پایین نگه دارید. - با استفاده از
model_configبه عنوان پیکربندیهای آن، یک نمونه از کلاسModelPaliGemma را مقداردهی اولیه کنید. - پارامترهای مدل را در RAM بارگذاری کنید
- تعریف یک تابع
decodeبرای نمونهبرداری از خروجیهای مدل
اجرای کامل این کد در این سلول حدود یک دقیقه طول میکشد.
# Define model
# IMPORTANT: Gemma-2 has a "final_logits_softcap" property. Set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
"llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
"img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)
# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)
# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())
پارامترهای مدل را به حافظه GPU/TPU منتقل کنید
حالا باید پارامترهای مدل را به حافظه GPU/TPU منتقل کنید. ابتدا پارامترها را در GPUهای موجود تقسیمبندی کنید، سپس پارامترها را بارگذاری کنید. در اینجا، پارامترها را به صورت متوالی بارگذاری خواهید کرد. این فرآیند بیشتر از بارگذاری همزمان آنها طول میکشد، اما به رم بیشتری نسبت به آنچه در این نوتبوک موجود است نیاز دارد.
در نهایت، تمام پارامترها را چاپ کنید تا ببینید هر پارامتر به چه نوعی تبدیل شده است. پارامترهای ثابت (freeze parameters) به صورت float16 نگه داشته میشوند، در حالی که پارامترهای قابل آموزش به float32 تبدیل میشوند. وقتی لیست را بررسی میکنید، خواهید دید که بیشتر پارامترها ثابت شدهاند و float16 هستند.
# Create a pytree mask of the trainable params.
def is_trainable_param(name, param): # pylint: disable=unused-argument
if name.startswith("llm/layers/attn/"): return True
if name.startswith("llm/"): return False
if name.startswith("img/"): return False
raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data"))
data_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("data"))
params_sharding = big_vision.sharding.infer_sharding(
params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)
# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
"ignore", message="Some donated buffers were not usable")
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
# Cast others to float16, since some GPUs don't support bf16.
return jax.tree.map(lambda p, m: p.astype(jnp.float32)
if m else p.astype(jnp.float16),
params, trainable)
# Loading all params in simultaneous - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default.
# Instead, do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
params[idx] = big_vision.utils.reshard(params[idx], sharding)
params[idx] = maybe_cast_to_f32(params[idx], trainable)
params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)
# Print params to show what the model is made of.
def parameter_overview(params):
for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")
print(" == Model params == ")
parameter_overview(params)
== Model params == img/Transformer/encoder_norm/bias (1152,) float16 img/Transformer/encoder_norm/scale (1152,) float16 img/Transformer/encoderblock/LayerNorm_0/bias (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_0/scale (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_1/bias (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_1/scale (27, 1152) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias (27, 4304) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel (27, 1152, 4304) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias (27, 1152) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel (27, 4304, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel (27, 1152, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias (27, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel (27, 16, 72, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel (27, 1152, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel (27, 1152, 16, 72) float16 img/embedding/bias (1152,) float16 img/embedding/kernel (14, 14, 3, 1152) float16 img/head/bias (2304,) float16 img/head/kernel (1152, 2304) float16 img/pos_embedding (1, 256, 1152) float16 llm/embedder/input_embedding (257152, 2304) float16 llm/final_norm/scale (2304,) float16 llm/layers/attn/attn_vec_einsum/w (26, 8, 256, 2304) float32 llm/layers/attn/kv_einsum/w (26, 2, 4, 2304, 256) float32 llm/layers/attn/q_einsum/w (26, 8, 2304, 256) float32 llm/layers/mlp/gating_einsum (26, 2, 2304, 9216) float16 llm/layers/mlp/linear (26, 9216, 2304) float16 llm/layers/post_attention_norm/scale (26, 2304) float16 llm/layers/post_ffw_norm/scale (26, 2304) float16 llm/layers/pre_attention_norm/scale (26, 2304) float16 llm/layers/pre_ffw_norm/scale (26, 2304) float16
برای تنظیم مدل آماده شوید
اکنون که مدل شما پیکربندی شده است، میتوانید آن را تنظیم کنید. در این مرحله، ورودیهای مدل خود و همچنین تکرارکنندههای آموزش و اعتبارسنجی را ایجاد خواهید کرد، مثالهای آموزشی را مشاهده خواهید کرد و حلقههای آموزش و اعتبارسنجی را تعریف خواهید کرد.
ایجاد ورودیهای مدل
مدل Checkpoint که شما استفاده میکنید، قبلاً روی تصاویری با نسبت ابعاد مختلف که به 224x224 پیکسل تغییر اندازه دادهاند و برای مدیریت متون توکنیزه شده آموزش دیده است.
کد زیر سه تابع را تعریف میکند که در مرحله بعدی برای ایجاد ورودیهای مدل از آنها استفاده خواهید کرد:
-
preprocess_image: دادههای تصویر را نرمالسازی میکند. در این حالت، پیشپردازش، تصویر ورودی را به خاکستری تبدیل میکند، لایه آلفا را حذف میکند و تصویر ورودی را به اندازه مورد نیاز مدل برای ورودیهای تصویر (۲۲۴x۲۲۴ پیکسل) تغییر اندازه میدهد. -
preprocess_tokens: توکنها را تقسیم میکند و پرچمهایی را برای مشخص کردن اینکه آیا یک توکن پیشوندی است یا پسوندی، اضافه میکند. این پرچمها بعداً در کد، در طول مرحله آموزش و حلقه ارزیابی استفاده خواهند شد. -
postprocess_tokens: هر توکنی که در و/یا بعد از توکن پایان توالی (EOS) باقی مانده باشد را حذف میکند و توکنهای رمزگشایی شده باقی مانده را برمیگرداند.
def preprocess_image(image, size=224):
# Model has been trained to handle images of different aspects ratios
# resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
# options are helpful to improve quality in some tasks.
image = np.asarray(image)
if image.ndim == 2: # Convert image without last channel into greyscale.
image = np.stack((image,)*3, axis=-1)
image = image[..., :3] # Remove alpha layer.
assert image.shape[-1] == 3
image = tf.constant(image)
image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1]
def preprocess_tokens(prefix, suffix=None, seqlen=None):
# Model has been trained to handle tokenized text composed of a prefix with
# full attention and a suffix with causal attention.
separator = "\n"
tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
mask_ar = [0] * len(tokens) # 0 to use full attention for prefix.
mask_loss = [0] * len(tokens) # 0 to not use prefix tokens in the loss.
if suffix:
suffix = tokenizer.encode(suffix, add_eos=True)
tokens += suffix
mask_ar += [1] * len(suffix) # 1 to use causal attention for suffix.
mask_loss += [1] * len(suffix) # 1 to use suffix tokens in the loss.
mask_input = [1] * len(tokens) # 1 if it's a token, 0 if padding.
if seqlen:
padding = [0] * max(0, seqlen - len(tokens))
tokens = tokens[:seqlen] + padding
mask_ar = mask_ar[:seqlen] + padding
mask_loss = mask_loss[:seqlen] + padding
mask_input = mask_input[:seqlen] + padding
return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))
def postprocess_tokens(tokens):
tokens = tokens.tolist() # np.array to list[int]
try: # Remove tokens at and after EOS if any.
eos_pos = tokens.index(tokenizer.eos_id())
tokens = tokens[:eos_pos]
except ValueError:
pass
return tokenizer.decode(tokens)
ایجاد تکرارکنندههای آموزش و اعتبارسنجی
دو تکرارکننده ایجاد کنید:
- یک تکرارکنندهی آموزشی که به فرآیند آموزش اجازه میدهد دادهها را به صورت تکهتکه بررسی کند، نه اینکه همه را به طور همزمان پردازش کند.
- این به شما امکان میدهد قبل از استفاده، برخی پیشپردازشها را روی دادهها انجام دهید.
- یک تکرارکننده اعتبارسنجی که به فرآیند آموزش اجازه میدهد تا روی مجموعه دادههای اعتبارسنجی تکرار شود تا ببیند مدل تنظیمشده چقدر با نتایج ارائهشده همسو است.
SEQLEN = 128
train_dataset = big_vision.datasets.jsonl.DataSource(
os.path.join(DATA_DIR, "data_train90.jsonl"),
fopen_keys={"image": DATA_DIR})
val_dataset = big_vision.datasets.jsonl.DataSource(
os.path.join(DATA_DIR, "data_val10.jsonl"),
fopen_keys={"image": DATA_DIR})
def train_data_iterator():
"""Never ending iterator over training examples."""
# Shuffle examples and repeat so one can train for many epochs.
dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
for example in dataset.as_numpy_iterator():
image = Image.open(io.BytesIO(example["image"]))
image = preprocess_image(image)
prefix = "caption en" # Could also be a different prefix per example.
suffix = example["suffix"].decode().lower()
tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)
yield {
"image": np.asarray(image),
"text": np.asarray(tokens),
"mask_ar": np.asarray(mask_ar),
"mask_loss": np.asarray(mask_loss),
}
def validation_data_iterator():
"""Single iterator over validation examples."""
for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
image = Image.open(io.BytesIO(example["image"]))
image = preprocess_image(image)
prefix = "caption en" # Could also be a different prefix per example.
tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)
yield {
"image": np.asarray(image),
"text": np.asarray(tokens),
"mask_ar": np.asarray(mask_ar),
"mask_input": np.asarray(mask_input),
}
مشاهده نمونههای آموزشی
در این دفترچه، دادههای آموزشی شامل ۹۰ تصویر است که با توضیحات مفصلی از آنچه در تصویر نشان داده شده است، جفت شدهاند.
کد زیر مجموعهای تصادفی از تصاویر را به همراه توضیحات آنها از مجموعه دادههای آموزشی چاپ میکند تا بتوانید ببینید تصاویر و توضیحاتی که مدل شما بر اساس آنها آموزش دیده است، چگونه به نظر میرسند. هر تصویر به صورت JPEG با ابعاد ۱۲۸x۱۲۸ پیکسل نمایش داده میشود و توضیحات آن در کنار تصویر سمت راست چاپ میشود.
def render_inline(image, resize=(128, 128)):
"""Convert image into inline html."""
image = Image.fromarray(image)
image.resize(resize)
with io.BytesIO() as buffer:
image.save(buffer, format='jpeg')
image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
return f"data:image/jpeg;base64,{image_b64}"
def render_example(image, caption):
image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]
return f"""
<div style="display: inline-flex; align-items: center; justify-content: center;">
<img style="width:128px; height:128px;" src="{render_inline(image, resize=(64,64))}" />
<p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
</div>
"""
html_out = ""
for idx, example in zip(range(8), train_data_iterator()):
caption = postprocess_tokens(example["text"]) # detokenize model input.
caption = caption[len("caption en\n"):] # strip prefix
html_out += render_example(example["image"], caption)
print("Training examples")
display(HTML(html_out))
Training examples
حلقههای آموزش و ارزیابی را تعریف کنید
حلقه آموزش را برای آموزش مدل روی مجموعه داده ارائه شده و حلقه ارزیابی را برای بررسی تمام مثالهای موجود در مجموعه داده اعتبارسنجی و انجام پیشبینیهای آن تعریف کنید.
تعریف حلقه آموزشی
تابع update_fn مرحله آموزش را تعریف میکند. در طول مرحله آموزش، میزان خطا به ازای هر مثال محاسبه میشود و گرادیان نزولی تصادفی (SGD) بر پارامترهای قابل آموزش اعمال میشود.
به یاد بیاورید که قبلاً در دفترچه یادداشت، پرچمهایی را در تابع preprocess_tokens قرار دادید که شامل mask_loss نیز میشدند. در اینجا از پرچم mask_loss برای حذف توکنهای پیشوند و پر شده از زیان استفاده خواهید کرد. بدون آن، محاسبه زیان اریب خواهد بود. همچنین باید هر مثال را نرمالسازی کنید، زیرا هر یک از آنها تعداد متفاوتی توکن دارند. پس از حذف توکنهای پیشوند و پر شده و نرمالسازی مثالها، میتوانید زیان را به ازای هر مثال محاسبه کنید.
مرحله آموزش همچنین شامل تابعی برای اعمال SGD برای بهینهسازی آموزش است.
تعریف حلقه ارزیابی
تابع make_predictions حلقه ارزیابی شماست. حلقه ارزیابی نسبتاً سرراست است، با یک تغییر قابل توجه. اگر از ابتدای دفترچه یادداشت به یاد داشته باشید، فقط ۹۰ مثال در مجموعه دادههای آموزشی خود دارید. این تعداد بسیار کمی از مثالهای آموزشی است و مدل شما در نهایت هنگام اجرای آموزش، مثالهای کافی برای اندازه دسته ندارد. این بدان معناست که در حلقه ارزیابی، باید دسته را با تکرار مثالها پر کنید.
برای اطمینان از اینکه حلقه ارزیابی شما فقط مثالهای واقعی را شمارش میکند و نه مثالهای پر شده را، باید یک ماسک به مثالهای پر شده اعمال کنید که آنها را از خروجی حذف میکند.
# The main update_fn using a simple stochastic gradient descent (SGD).
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]
def loss_fn(params):
text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
logp = jax.nn.log_softmax(text_logits, axis=-1)
# The model takes as input txts[:, :-1] but the loss is defined as predicting
# next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
# are part of the loss (e.g. prefix and padded tokens are not included).
mask_loss = batch["mask_loss"][:, 1:]
targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])
# Compute the loss per example. i.e. the mean of per token pplx.
# Since each example has a different number of tokens, normalize it.
token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.
example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.
example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.
# batch_loss: mean of per example loss.
return jnp.mean(example_loss)
loss, grads = jax.value_and_grad(loss_fn)(params)
# Apply gradients to trainable params using SGD.
def apply_grad(param, gradient, trainable):
if not trainable: return param
return param - learning_rate * gradient
params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)
return params, loss
# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
batch_size=4, seqlen=SEQLEN, sampler="greedy"):
outputs = []
while True:
# Construct a list of examples in the batch.
examples = []
try:
for _ in range(batch_size):
examples.append(next(data_iterator))
examples[-1]["_mask"] = np.array(True) # Indicates true example.
except StopIteration:
if len(examples) == 0:
return outputs
# Not enough examples to complete a batch. Pad by repeating last example.
while len(examples) % batch_size:
examples.append(dict(examples[-1]))
examples[-1]["_mask"] = np.array(False) # Indicates padding example.
# Convert list of examples into a dict of np.arrays and load onto devices.
batch = jax.tree.map(lambda *x: np.stack(x), *examples)
batch = big_vision.utils.reshard(batch, data_sharding)
# Make model predictions
tokens = decode({"params": params}, batch=batch,
max_decode_len=seqlen, sampler=sampler)
# Fetch model predictions to device and detokenize.
tokens, mask = jax.device_get((tokens, batch["_mask"]))
tokens = tokens[mask] # remove padding examples.
responses = [postprocess_tokens(t) for t in tokens]
# Append to html output.
for example, response in zip(examples, responses):
outputs.append((example["image"], response))
if num_examples and len(outputs) >= num_examples:
return outputs
مدل را تنظیم کنید
حالا که همه چیز را تنظیم کردهاید و نگاهی به دادههای آموزشی انداختهاید، وقت آن رسیده که مدل را تنظیم کنید. کد زیر حلقه آموزشی را برای مدل به مدت ۶۴ مرحله اجرا میکند و نرخ یادگیری ( lr در خروجی چاپ شده) و نرخ تلفات را برای هر مرحله چاپ میکند.
هر ۱۶ مرحله، مدل پیشبینیهای خود را در آن مرحله از آموزش چاپ میکند. این کد پیشبینیها را برای همان مجموعه تصاویر چاپ میکند تا بتوانید ببینید که توانایی مدل در پیشبینی توصیفات با گذشت زمان بهبود مییابد.
در مراحل اولیه آموزش، احتمالاً مشکلاتی در توصیفات وجود دارد، مانند جملات تکراری که مدل در حلقه پیشبینی خود گیر میکند یا جملات ناتمام. پیشبینیهای مدل با پیشرفت آموزش به طور پیوسته دقیقتر میشوند. تا مرحله ۶۴، پیشبینیهای مدل باید شباهت زیادی به توصیفات ارائه شده توسط دادههای آموزشی داشته باشند.
این فرآیند روی TPU های T4 حدود ۱۵ دقیقه طول میکشد.
# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.
#
%%time
BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03
TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4
train_data_it = train_data_iterator()
sched_fn = big_vision.utils.create_learning_rate_schedule(
total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
decay_type="cosine", warmup_percent=0.10)
for step in range(1, TRAIN_STEPS+1):
# Make list of N training examples.
examples = [next(train_data_it) for _ in range(BATCH_SIZE)]
# Convert list of examples into a dict of np.arrays and load onto devices.
batch = jax.tree.map(lambda *x: np.stack(x), *examples)
batch = big_vision.utils.reshard(batch, data_sharding)
# Training step and report training loss
learning_rate = sched_fn(step)
params, loss = update_fn(params, batch, learning_rate)
loss = jax.device_get(loss)
print(f"step: {step:2d}/{TRAIN_STEPS:2d} lr: {learning_rate:.5f} loss: {loss:.4f}")
if (step % EVAL_STEPS) == 0:
print(f"Model predictions at step {step}")
html_out = ""
for image, caption in make_predictions(
validation_data_iterator(), num_examples=4, batch_size=4):
html_out += render_example(image, caption)
display(HTML(html_out))
step: 1/64 lr: 0.00500 loss: 3.6567 step: 2/64 lr: 0.01000 loss: 1.9762 step: 3/64 lr: 0.01500 loss: 1.6299 step: 4/64 lr: 0.02000 loss: 1.5651 step: 5/64 lr: 0.02500 loss: 1.9813 step: 6/64 lr: 0.03000 loss: 1.9996 step: 7/64 lr: 0.02998 loss: 1.8595 step: 8/64 lr: 0.02992 loss: 1.6479 step: 9/64 lr: 0.02981 loss: 1.3693 step: 10/64 lr: 0.02966 loss: 1.3423 step: 11/64 lr: 0.02947 loss: 1.2122 step: 12/64 lr: 0.02924 loss: 1.0602 step: 13/64 lr: 0.02897 loss: 1.1314 step: 14/64 lr: 0.02866 loss: 1.2612 step: 15/64 lr: 0.02831 loss: 1.0132 step: 16/64 lr: 0.02792 loss: 1.2126 Model predictions at step 16
step: 17/64 lr: 0.02750 loss: 1.0986 step: 18/64 lr: 0.02704 loss: 0.9461 step: 19/64 lr: 0.02655 loss: 1.2098 step: 20/64 lr: 0.02602 loss: 1.0513 step: 21/64 lr: 0.02546 loss: 1.0979 step: 22/64 lr: 0.02488 loss: 0.9739 step: 23/64 lr: 0.02426 loss: 0.9589 step: 24/64 lr: 0.02362 loss: 0.7053 step: 25/64 lr: 0.02296 loss: 0.7347 step: 26/64 lr: 0.02227 loss: 0.6990 step: 27/64 lr: 0.02156 loss: 0.6736 step: 28/64 lr: 0.02083 loss: 0.6642 step: 29/64 lr: 0.02009 loss: 0.6908 step: 30/64 lr: 0.01933 loss: 0.7257 step: 31/64 lr: 0.01856 loss: 0.6902 step: 32/64 lr: 0.01778 loss: 0.7054 Model predictions at step 32
step: 33/64 lr: 0.01699 loss: 0.7709 step: 34/64 lr: 0.01620 loss: 0.6653 step: 35/64 lr: 0.01540 loss: 0.3811 step: 36/64 lr: 0.01460 loss: 0.3104 step: 37/64 lr: 0.01380 loss: 0.4042 step: 38/64 lr: 0.01301 loss: 0.3904 step: 39/64 lr: 0.01222 loss: 0.3339 step: 40/64 lr: 0.01144 loss: 0.4156 step: 41/64 lr: 0.01067 loss: 0.4085 step: 42/64 lr: 0.00991 loss: 0.3083 step: 43/64 lr: 0.00917 loss: 0.3757 step: 44/64 lr: 0.00844 loss: 0.3813 step: 45/64 lr: 0.00773 loss: 0.3381 step: 46/64 lr: 0.00704 loss: 0.2057 step: 47/64 lr: 0.00638 loss: 0.1287 step: 48/64 lr: 0.00574 loss: 0.1711 Model predictions at step 48
step: 49/64 lr: 0.00512 loss: 0.1183 step: 50/64 lr: 0.00454 loss: 0.1154 step: 51/64 lr: 0.00398 loss: 0.1967 step: 52/64 lr: 0.00345 loss: 0.1497 step: 53/64 lr: 0.00296 loss: 0.1688 step: 54/64 lr: 0.00250 loss: 0.1878 step: 55/64 lr: 0.00208 loss: 0.1865 step: 56/64 lr: 0.00169 loss: 0.1655 step: 57/64 lr: 0.00134 loss: 0.0911 step: 58/64 lr: 0.00103 loss: 0.1836 step: 59/64 lr: 0.00076 loss: 0.1242 step: 60/64 lr: 0.00053 loss: 0.0814 step: 61/64 lr: 0.00034 loss: 0.0866 step: 62/64 lr: 0.00019 loss: 0.1295 step: 63/64 lr: 0.00008 loss: 0.1053 step: 64/64 lr: 0.00002 loss: 0.0730 Model predictions at step 64
CPU times: user 2min 18s, sys: 8.98 s, total: 2min 27s Wall time: 15min 45s
خروجی
دادههای اعتبارسنجی این دفترچه یادداشت فقط شامل ۱۰ تصویر است. در کد معمولی، احتمالاً نقاط داده بسیار بیشتری برای اعتبارسنجی خواهید داشت، اما برای این دفترچه یادداشت، کد زیر را اجرا کنید تا توضیحاتی برای هر ۱۰ تصویر ایجاد شود. پس از تنظیم مدل، این توضیحات باید از نظر شکل و پوشش محتوا بسیار شبیه به توضیحات موجود در دادههای آموزشی باشند که قبلاً در این دفترچه یادداشت به آنها نگاه کردید.
کد زیر را اجرا کنید تا توضیحاتی برای مجموعه دادههای اعتبارسنجی ایجاد شود.
# The validation data consists of 10 images in a different domain than training
# data.
%%time
print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
html_out += render_example(image, caption)
display(HTML(html_out))
Model predictions
CPU times: user 1.87 s, sys: 283 ms, total: 2.15 s Wall time: 39.3 s
در گوگل کولب اجرا کنید
مشاهده منبع در گیتهاب