JAX এবং Flax দিয়ে PaliGemma-কে সূক্ষ্মভাবে সাজিয়ে নিন

ai.google.dev-এ দেখুন গুগল কোলাবে চালান কাগলে দৌড়ান ভার্টেক্স এআই-তে খুলুন GitHub-এ উৎস দেখুন

এই নোটবুকটি JAX ব্যবহার করে একটি ভিশন-ল্যাঙ্গুয়েজ টাস্কে PaliGemma-কে কীভাবে সূক্ষ্ম-টিউন করতে হয় তা দেখায়। ফাইন-টিউনিং এমন একটি প্রক্রিয়া যা নির্দিষ্ট কাজে আপনার মডেলের কর্মক্ষমতা উন্নত করতে পারে অথবা যখন নির্দেশনা পর্যাপ্ত না হয় এবং আপনার কাছে এমন কিছু উদাহরণ থাকে যা আপনার পছন্দসই আউটপুট প্রদর্শন করে তখন মডেলটিকে নির্দিষ্ট আউটপুট প্রয়োজনীয়তা মেনে চলতে সাহায্য করে। PaliGemma-এর মতো Gemma-ভিত্তিক মডেলগুলিতে প্রত্যাশিত ফলাফল তৈরি করতে ফাইন-টিউনিং প্রয়োজন।

এই নোটবুকে কী আছে?

এই নোটবুকটি big_vision থেকে মডেল রেফারেন্স বাস্তবায়ন ব্যবহার করে এবং দেখায় কিভাবে:

  • নির্ভরতা ইনস্টল করুন, এবং PaliGemma মডেল চেকপয়েন্ট এবং প্রশিক্ষণ ডেটা ডাউনলোড করুন
  • মডেলটি GPU ডিভাইসে লোড করুন
  • প্রশিক্ষণ এবং অনুমানের জন্য মডেলের ইনপুট প্রস্তুত করুন।
  • মডেলটি সূক্ষ্ম করুন
  • আউটপুট পরীক্ষা করুন

এই নোটবুকের প্রশিক্ষণ ডেটাতে 90 জোড়া ছবি এবং লম্বা ক্যাপশন রয়েছে যা সেগুলি বর্ণনা করে। এটিকে T4 কোল্যাব রানটাইমে চালানোর জন্য, আপনাকে কেবল ভাষা মডেলের মনোযোগ স্তরগুলিকে সূক্ষ্মভাবে সুরক্ষিত করতে হবে এবং অন্যান্য পরামিতিগুলিকে স্থির করতে হবে।

এই উদাহরণটি শুধুমাত্র শেখার উদ্দেশ্যে। বাস্তব ব্যবহারের ক্ষেত্রে, ডেটার পরিমাণ, প্রশিক্ষণযোগ্য পরামিতি, প্রশিক্ষণের ধাপ এবং হাইপার-পরামিতি এবং প্রাপ্ত ফলাফল উল্লেখযোগ্যভাবে ভিন্ন হতে পারে।

শুরু করার আগে

এই নোটবুকটি পড়ার আগে, আপনার পাইথন কোডের সাথে পরিচিত হওয়া উচিত, সেইসাথে বৃহৎ ভাষা মডেল (LLM) কত প্রশিক্ষিত হয়। আপনার JAX এর সাথে পরিচিত হওয়ার প্রয়োজন নেই, তবে JAX (অথবা Keras এর মতো অনুরূপ প্রযুক্তি) সম্পর্কে প্রাথমিক জ্ঞান উদাহরণ কোডটি পড়ার সময় সহায়ক।

সেটআপ

নিম্নলিখিত বিভাগগুলিতে PaliGemma মডেল ব্যবহার করার জন্য একটি নোটবুক পাওয়ার প্রাথমিক পদক্ষেপগুলি ব্যাখ্যা করা হয়েছে, যার মধ্যে রয়েছে মডেল অ্যাক্সেস, একটি API কী পাওয়া এবং নোটবুক রানটাইম কনফিগার করা।

PaliGemma-তে অ্যাক্সেস পান

প্রথমবার PaliGemma ব্যবহার করার আগে, আপনাকে নিম্নলিখিত ধাপগুলি সম্পন্ন করে Kaggle-এর মাধ্যমে মডেলটিতে অ্যাক্সেসের জন্য অনুরোধ করতে হবে:

  1. Kaggle- এ লগ ইন করুন, অথবা যদি আপনার ইতিমধ্যেই Kaggle অ্যাকাউন্ট না থাকে তবে একটি নতুন Kaggle অ্যাকাউন্ট তৈরি করুন।
  2. PaliGemma মডেল কার্ডে যান এবং অ্যাক্সেসের অনুরোধ করুন এ ক্লিক করুন।
  3. সম্মতি ফর্মটি পূরণ করুন এবং শর্তাবলী গ্রহণ করুন।

আপনার API কী কনফিগার করুন

PaliGemma ব্যবহার করার জন্য, আপনাকে আপনার Kaggle ব্যবহারকারীর নাম এবং একটি Kaggle API কী প্রদান করতে হবে।

একটি Kaggle API কী তৈরি করতে, Kaggle-এ আপনার সেটিংস পৃষ্ঠাটি খুলুন এবং Create New Token এ ক্লিক করুন। এটি আপনার API শংসাপত্র ধারণকারী একটি kaggle.json ফাইল ডাউনলোড শুরু করে।

তারপর, Colab-এ, বাম দিকের প্যানে Secrets (🔑) নির্বাচন করুন এবং আপনার Kaggle ব্যবহারকারীর নাম এবং Kaggle API কী যোগ করুন। আপনার ব্যবহারকারীর নাম KAGGLE_USERNAME নামে এবং আপনার API কী KAGGLE_KEY নামে সংরক্ষণ করুন।

রানটাইম নির্বাচন করুন

এই টিউটোরিয়ালটি সম্পূর্ণ করার জন্য, আপনার PaliGemma মডেলটি চালানোর জন্য পর্যাপ্ত রিসোর্স সহ একটি Colab রানটাইম থাকতে হবে। এই ক্ষেত্রে, আপনি একটি T4 GPU ব্যবহার করতে পারেন:

  1. Colab উইন্ডোর উপরের ডানদিকে, ▾ (অতিরিক্ত সংযোগ বিকল্প) ড্রপডাউন মেনুতে ক্লিক করুন।
  2. রানটাইম টাইপ পরিবর্তন করুন নির্বাচন করুন।
  3. হার্ডওয়্যার অ্যাক্সিলারেটরের অধীনে, T4 GPU নির্বাচন করুন।

পাইথন প্যাকেজ ইনস্টল করুন

KaggleHub ইনস্টল করতে নিচের সেলটি চালান।

pip install -U -q kagglehub

পরিবেশের ভেরিয়েবল সেট করুন

পরিবেশ ভেরিয়েবল এবং Kaggle লগইন সেট করুন।

import os
# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid out-of-memory due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

# This will prompt you to enter your Kaggle API token
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
Kaggle credentials set.
Kaggle credentials successfully validated.

GitHub থেকে আপনার Colab নোটবুকে big_vision রিপোজিটরিটি ডাউনলোড করুন এবং নিম্নলিখিত কোডটি চালিয়ে big_vision সম্পর্কিত নির্ভরতা ইনস্টল করুন।

import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 76.7/76.7 kB 2.8 MB/s eta 0:00:00

JAX এবং অন্যান্য নির্ভরতা আমদানি করুন

PaliGemma-এর জন্য প্রয়োজনীয় JAX এবং অন্যান্য নির্ভরতা, যেমন TensorFlow এবং NumPy, আমদানি করুন।

import base64
import functools
import html
import io
import os
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections

import tensorflow as tf
import sentencepiece

from IPython.core.display import display, HTML
from PIL import Image

# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

backend = jax.extend.backend.get_backend()
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices:  {jax.device_count()}")
JAX version:  0.7.2
JAX platform: gpu
JAX devices:  1

মডেলটি ডাউনলোড এবং কনফিগার করুন

এই ধাপে, আপনি মডেল চেকপয়েন্টটি ডাউনলোড করবেন এবং এটিকে কনফিগার করবেন যাতে আপনি পরে এটিকে সূক্ষ্ম-টিউন করতে পারেন। এই ধাপে আপনাকে দেখানো হবে কিভাবে মডেল প্যারামিটারগুলিকে TPU মেমোরিতে স্থানান্তর করতে হয়, যা সীমিত সংস্থান সহ ডিভাইসগুলিতে মডেলগুলিকে সূক্ষ্ম-টিউন করার জন্য কার্যকর।

মডেল চেকপয়েন্টটি ডাউনলোড করুন

PaliGemma মডেলের বেশ কিছু বৈচিত্র্য রয়েছে। এই টিউটোরিয়ালের জন্য, আপনি বেস JAX/FLAX PaliGemma 3B ওজন মডেল ব্যবহার করবেন।

নিম্নলিখিত কোডটি রান করে Kaggle থেকে মডেল চেকপয়েন্টটি ডাউনলোড করুন। এই প্রক্রিয়াটি সম্পন্ন হতে কয়েক মিনিট সময় লাগে।

import os
import kagglehub

# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224"  # Path to fetch from Kaggle.

# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"

if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
  print("Downloading the dataset...")
  !gsutil -m -q cp -n -r gs://longcap100/ .
  print(f"Data path: {DATA_DIR}")
Downloading the checkpoint from Kaggle, this could take a few minutes....
Downloading to /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz...
100%|██████████| 5.65G/5.65G [00:54<00:00, 112MB/s]
Model path: /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz
Downloading the model tokenizer...
Copying gs://big_vision/paligemma_tokenizer.model...

- [1 files][  4.1 MiB/  4.1 MiB]                                                
Operation completed over 1 objects/4.1 MiB.                                      
Tokenizer path: ./paligemma_tokenizer.model
Downloading the dataset...
Data path: ./longcap100

মডেলটি কনফিগার করুন

আপনি যে মডেলটি ব্যবহার করতে যাচ্ছেন তা আসলে কনফিগার করা শুরু করার সময় এসেছে।

এই নোটবুকের জন্য, আপনার মডেলটি একটি T4 GPU-তে ফিট করতে সক্ষম হতে হবে। সীমিত রিসোর্স যেমন স্থানের সীমাবদ্ধতা থাকার অর্থ হল আপনার মডেলটি কীভাবে কনফিগার করা হয়েছে সে সম্পর্কে আপনাকে সচেতন থাকতে হবে।

যদি আপনি প্রতিটি প্যারামিটার সূক্ষ্মভাবে সুরক্ষিত করেন, তাহলে আপনার মডেলটি নোটবুক পরিবেশে চলতে সক্ষম হবে না। ফলস্বরূপ, নোটবুকের এই অংশে, আপনি আপনার মডেলটিকে এমনভাবে কনফিগার করবেন যাতে এটি কিছু প্যারামিটার হিমায়িত করতে পারে এবং কেবলমাত্র সেই প্যারামিটারগুলিকেই সূক্ষ্মভাবে সুরক্ষিত করতে পারে যেগুলি মডেলটির জন্য আপনাকে সঠিক ফলাফল দেওয়ার জন্য সত্যিই সূক্ষ্মভাবে সুরক্ষিত করা প্রয়োজন। LLM-তে, প্যারামিটারগুলিকে হিমায়িত বলা হয় যখন সেগুলি মডেলটিকে প্রশিক্ষণ দেওয়ার জন্য সক্রিয়ভাবে ব্যবহৃত হয় না।

আপনার মডেলটি কনফিগার করার জন্য, আপনাকে যা করতে হবে:

  • model_config কে FrozenConfigDict হিসেবে আরম্ভ করুন যাতে আপনি কিছু প্যারামিটার ফ্রিজ করতে পারেন এবং মেমোরির ব্যবহার কম রাখতে পারেন।
  • PaliGemma Model ক্লাসের একটি ইনস্ট্যান্স শুরু করুন, model_config এর কনফিগারেশন হিসেবে ব্যবহার করে।
  • মডেল প্যারামিটারগুলি RAM-তে লোড করুন
  • মডেল থেকে নমুনা আউটপুটগুলির জন্য একটি decode ফাংশন সংজ্ঞায়িত করুন

এই কক্ষের এই কোডটি সম্পূর্ণ হতে প্রায় এক মিনিট সময় নেয়।

# Define model

# IMPORTANT: Gemma-2 has a "final_logits_softcap" property. Set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

মডেল প্যারামিটারগুলিকে GPU/TPU মেমরিতে সরান

এখন আপনাকে মডেল প্যারামিটারগুলিকে GPU/TPU মেমোরিতে স্থানান্তর করতে হবে। প্রথমে, উপলব্ধ GPU গুলিতে প্যারামিটারগুলিকে খণ্ডন করুন, তারপর প্যারামিটারগুলি লোড করুন। এখানে, আপনি ক্রমানুসারে প্যারামিটারগুলি লোড করবেন। এই প্রক্রিয়াটি একসাথে লোড করার চেয়ে বেশি সময় নেয়, তবে এর জন্য এই নোটবুকে উপলব্ধ RAM এর চেয়ে বেশি RAM প্রয়োজন।

অবশেষে, প্রতিটি প্যারামিটার কোন ধরণের কাস্ট করা হয়েছে তা দেখার জন্য সমস্ত প্যারামিটার প্রিন্ট করুন। Frozen প্যারামিটারগুলি float16 হিসাবে রাখা হয়, যখন trainable প্যারামিটারগুলি float32 এ কাস্ট করা হয়। যখন আপনি তালিকাটি পরিদর্শন করবেন, তখন আপনি দেখতে পাবেন যে বেশিরভাগ প্যারামিটারগুলি হিমায়িত হয়ে গেছে এবং float16

# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"):  return True
  if name.startswith("llm/"):              return False
  if name.startswith("img/"):              return False
  raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data"))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  # Cast others to float16, since some GPUs don't support bf16.
  return jax.tree.map(lambda p, m: p.astype(jnp.float32)
                      if m else p.astype(jnp.float16),
                      params, trainable)

# Loading all params in simultaneous - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default.
# Instead, do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)
== Model params == 
img/Transformer/encoder_norm/bias                                                (1152,)                float16
img/Transformer/encoder_norm/scale                                               (1152,)                float16
img/Transformer/encoderblock/LayerNorm_0/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_0/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias                             (27, 4304)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel                           (27, 1152, 4304)       float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias                             (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel                           (27, 4304, 1152)       float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias             (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel           (27, 1152, 16, 72)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias             (27, 1152)             float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel           (27, 16, 72, 1152)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias           (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel         (27, 1152, 16, 72)     float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias           (27, 16, 72)           float16
img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel         (27, 1152, 16, 72)     float16
img/embedding/bias                                                               (1152,)                float16
img/embedding/kernel                                                             (14, 14, 3, 1152)      float16
img/head/bias                                                                    (2304,)                float16
img/head/kernel                                                                  (1152, 2304)           float16
img/pos_embedding                                                                (1, 256, 1152)         float16
llm/embedder/input_embedding                                                     (257152, 2304)         float16
llm/final_norm/scale                                                             (2304,)                float16
llm/layers/attn/attn_vec_einsum/w                                                (26, 8, 256, 2304)     float32
llm/layers/attn/kv_einsum/w                                                      (26, 2, 4, 2304, 256)  float32
llm/layers/attn/q_einsum/w                                                       (26, 8, 2304, 256)     float32
llm/layers/mlp/gating_einsum                                                     (26, 2, 2304, 9216)    float16
llm/layers/mlp/linear                                                            (26, 9216, 2304)       float16
llm/layers/post_attention_norm/scale                                             (26, 2304)             float16
llm/layers/post_ffw_norm/scale                                                   (26, 2304)             float16
llm/layers/pre_attention_norm/scale                                              (26, 2304)             float16
llm/layers/pre_ffw_norm/scale                                                    (26, 2304)             float16

মডেলটি সুর করার জন্য প্রস্তুত হোন

এখন যেহেতু আপনার মডেলটি কনফিগার করা হয়েছে, আপনি এটি টিউন করতে পারেন। এই ধাপে, আপনি আপনার মডেলের ইনপুটগুলির পাশাপাশি প্রশিক্ষণ এবং বৈধতা পুনরাবৃত্তিকারী তৈরি করবেন, প্রশিক্ষণের উদাহরণগুলি দেখবেন এবং প্রশিক্ষণ এবং বৈধতা লুপগুলি সংজ্ঞায়িত করবেন।

মডেল ইনপুট তৈরি করুন

আপনি যে মডেল চেকপয়েন্টটি ব্যবহার করছেন তা ইতিমধ্যেই বিভিন্ন আকৃতির অনুপাতের ছবিগুলির উপর প্রশিক্ষণপ্রাপ্ত হয়েছে যেগুলিকে 224x224 পিক্সেল আকারে পরিবর্তন করা হয়েছে, এবং টোকেনাইজড টেক্সটগুলি পরিচালনা করার জন্য।

নিচের কোডটি তিনটি ফাংশন সংজ্ঞায়িত করে যা আপনি পরবর্তী ধাপে মডেলের ইনপুট তৈরিতে ব্যবহার করবেন:

  • preprocess_image : ছবির ডেটা স্বাভাবিক করে তোলে। এই ক্ষেত্রে, প্রি-প্রসেসিং পাস-ইন করা ছবিটিকে গ্রেস্কেলে রূপান্তর করে, আলফা স্তরটি সরিয়ে দেয় এবং পাস-ইন করা ছবিটিকে মডেলের ইমেজ ইনপুটগুলির জন্য প্রয়োজনীয় আকারে (২২৪x২২৪ পিক্সেল) আকার দেয়।
  • preprocess_tokens : টোকেনগুলিকে বিভক্ত করে এবং একটি টোকেন একটি প্রিফিক্স নাকি সাফিক্স টোকেন তা চিহ্নিত করার জন্য ফ্ল্যাগ যোগ করে। এই ফ্ল্যাগগুলি পরবর্তীতে কোডে, প্রশিক্ষণ ধাপ এবং মূল্যায়ন লুপের সময় ব্যবহার করা হবে।
  • postprocess_tokens : সিকোয়েন্সের শেষের (EOS) টোকেনের সময় এবং/অথবা পরে অবশিষ্ট যেকোনো টোকেন সরিয়ে দেয় এবং অবশিষ্ট ডিকোড করা টোকেনগুলি ফেরত দেয়।
def preprocess_image(image, size=224):
  # Model has been trained to handle images of different aspects ratios
  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
  # options are helpful to improve quality in some tasks.
  image = np.asarray(image)
  if image.ndim == 2:  # Convert image without last channel into greyscale.
    image = np.stack((image,)*3, axis=-1)
  image = image[..., :3]  # Remove alpha layer.
  assert image.shape[-1] == 3

  image = tf.constant(image)
  image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
  return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]

def preprocess_tokens(prefix, suffix=None, seqlen=None):
  # Model has been trained to handle tokenized text composed of a prefix with
  # full attention and a suffix with causal attention.
  separator = "\n"
  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

  if suffix:
    suffix = tokenizer.encode(suffix, add_eos=True)
    tokens += suffix
    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

  mask_input = [1] * len(tokens)    # 1 if it's a token, 0 if padding.
  if seqlen:
    padding = [0] * max(0, seqlen - len(tokens))
    tokens = tokens[:seqlen] + padding
    mask_ar = mask_ar[:seqlen] + padding
    mask_loss = mask_loss[:seqlen] + padding
    mask_input = mask_input[:seqlen] + padding

  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

def postprocess_tokens(tokens):
  tokens = tokens.tolist()  # np.array to list[int]
  try:  # Remove tokens at and after EOS if any.
    eos_pos = tokens.index(tokenizer.eos_id())
    tokens = tokens[:eos_pos]
  except ValueError:
    pass
  return tokenizer.decode(tokens)

প্রশিক্ষণ এবং যাচাইকরণ পুনরাবৃত্তিকারী তৈরি করুন

দুটি ইটারেটর তৈরি করুন:

  • একটি প্রশিক্ষণ ইটারেটর যা প্রশিক্ষণ প্রক্রিয়াটিকে একসাথে সমস্ত ডেটা প্রক্রিয়াকরণের পরিবর্তে খণ্ড খণ্ডে ডেটার মধ্য দিয়ে যেতে দেয়।
    • এটি আপনাকে ব্যবহারের আগে কিছু ডেটা প্রি-প্রসেসিং করার অনুমতি দেয়।
  • একটি ভ্যালিডেশন ইটারেটর যা প্রশিক্ষণ প্রক্রিয়াটিকে ভ্যালিডেশন ডেটাসেটের উপর পুনরাবৃত্তি করতে দেয় যাতে দেখা যায় যে টিউন করা মডেলটি প্রদত্ত ফলাফলের সাথে কতটা ভালোভাবে সামঞ্জস্যপূর্ণ।
SEQLEN = 128

train_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_train90.jsonl"),
    fopen_keys={"image": DATA_DIR})

val_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_val10.jsonl"),
    fopen_keys={"image": DATA_DIR})


def train_data_iterator():
  """Never ending iterator over training examples."""
  # Shuffle examples and repeat so one can train for many epochs.
  dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
  for example in dataset.as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    suffix = example["suffix"].decode().lower()
    tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_loss": np.asarray(mask_loss),
    }


def validation_data_iterator():
  """Single iterator over validation examples."""
  for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input),
    }

প্রশিক্ষণের উদাহরণ দেখুন

এই নোটবুকে, প্রশিক্ষণের তথ্যে 90টি ছবি রয়েছে যা ছবিতে যা দেখানো হয়েছে তার দীর্ঘ বর্ণনার সাথে যুক্ত করা হয়েছে।

নিচের কোডটি প্রশিক্ষণ ডেটা সেট থেকে বর্ণনা সহ ছবিগুলির একটি এলোমেলো নির্বাচন প্রিন্ট করে যাতে আপনি দেখতে পারেন যে আপনার মডেলটি যে ছবি এবং বর্ণনাগুলিতে প্রশিক্ষণ পেয়েছে তা কেমন দেখাচ্ছে। প্রতিটি ছবি 128x128 পিক্সেল JPEG আকারে প্রদর্শিত হয়, ডানদিকে ছবির পাশে বর্ণনা মুদ্রিত থাকে।

def render_inline(image, resize=(128, 128)):
  """Convert image into inline html."""
  image = Image.fromarray(image)
  image.resize(resize)
  with io.BytesIO() as buffer:
    image.save(buffer, format='jpeg')
    image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
    return f"data:image/jpeg;base64,{image_b64}"

def render_example(image, caption):
  image = ((image + 1)/2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]
  return f"""
    <div style="display: inline-flex; align-items: center; justify-content: center;">
        <img style="width:128px; height:128px;" src="{render_inline(image, resize=(64,64))}" />
        <p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
    </div>
    """

html_out = ""
for idx, example in zip(range(8), train_data_iterator()):
  caption = postprocess_tokens(example["text"])  # detokenize model input.
  caption = caption[len("caption en\n"):]        # strip prefix
  html_out += render_example(example["image"], caption)

print("Training examples")
display(HTML(html_out))
Training examples

প্রশিক্ষণ এবং মূল্যায়ন চক্রগুলি সংজ্ঞায়িত করুন

প্রদত্ত ডেটাসেটে মডেলকে প্রশিক্ষণ দেওয়ার জন্য প্রশিক্ষণ লুপ এবং যাচাইকরণ ডেটাসেটের সমস্ত উদাহরণ দেখার এবং এর ভবিষ্যদ্বাণী করার জন্য মূল্যায়ন লুপ সংজ্ঞায়িত করুন।

প্রশিক্ষণ লুপ সংজ্ঞায়িত করা

update_fn ফাংশনটি প্রশিক্ষণ ধাপটি সংজ্ঞায়িত করে। প্রশিক্ষণ ধাপের সময়, প্রতি উদাহরণের ক্ষতি গণনা করা হয় এবং প্রশিক্ষণযোগ্য পরামিতিগুলিতে স্টোকাস্টিক গ্রেডিয়েন্ট ডিসেন্ট (SGD) প্রয়োগ করা হয়।

মনে রাখবেন যে নোটবুকের আগে, আপনি preprocess_tokens ফাংশনে ফ্ল্যাগ অন্তর্ভুক্ত করেছিলেন যার মধ্যে mask_loss অন্তর্ভুক্ত ছিল। আপনি এখানে mask_loss ফ্ল্যাগ ব্যবহার করে ক্ষতি থেকে প্রিফিক্স এবং প্যাডেড টোকেন বাদ দেবেন। এটি ছাড়া, ক্ষতির হিসাব বিকৃত হবে। আপনাকে প্রতিটি উদাহরণকে স্বাভাবিক করতে হবে, কারণ প্রতিটিতে টোকেনের সংখ্যা আলাদা। প্রিফিক্স এবং প্যাডেড টোকেন বাদ দেওয়ার পরে এবং উদাহরণগুলি স্বাভাবিক করার পরে, আপনি প্রতিটি উদাহরণের ক্ষতি গণনা করতে পারেন।

প্রশিক্ষণ ধাপে প্রশিক্ষণকে অপ্টিমাইজ করার জন্য একটি SGD প্রয়োগ করার একটি ফাংশনও অন্তর্ভুক্ত রয়েছে।

মূল্যায়ন লুপ সংজ্ঞায়িত করা

make_predictions ফাংশনটি হল আপনার মূল্যায়ন লুপ। মূল্যায়ন লুপটি মোটামুটি সোজা, একটি উল্লেখযোগ্য পরিবর্তন সহ। যদি আপনি নোটবুকের শুরু থেকে মনে করেন, আপনার প্রশিক্ষণ ডেটা সেটে মাত্র 90টি উদাহরণ রয়েছে। এটি প্রশিক্ষণের উদাহরণের একটি খুব কম সংখ্যা, এবং প্রশিক্ষণ চালানোর সময় আপনার মডেলটিতে ব্যাচের আকারের জন্য পর্যাপ্ত উদাহরণ থাকে না। এর অর্থ হল মূল্যায়ন লুপে, আপনাকে উদাহরণগুলি পুনরাবৃত্তি করে ব্যাচটি প্যাড করতে হবে।

আপনার মূল্যায়ন লুপটি কেবল প্রকৃত উদাহরণগুলিকে গণনা করে এবং প্যাডেড উদাহরণগুলিকে নয় তা নিশ্চিত করার জন্য, আপনাকে প্যাডেড উদাহরণগুলিতে একটি মাস্ক প্রয়োগ করতে হবে যা আউটপুট থেকে তাদের বাদ দেবে।

# The main update_fn using a simple stochastic gradient descent (SGD).
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
  imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]

  def loss_fn(params):
    text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
    logp = jax.nn.log_softmax(text_logits, axis=-1)

    # The model takes as input txts[:, :-1] but the loss is defined as predicting
    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
    # are part of the loss (e.g. prefix and padded tokens are not included).
    mask_loss = batch["mask_loss"][:, 1:]
    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])

    # Compute the loss per example. i.e. the mean of per token pplx.
    # Since each example has a different number of tokens, normalize it.
    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.
    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.
    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.

    # batch_loss: mean of per example loss.
    return jnp.mean(example_loss)

  loss, grads = jax.value_and_grad(loss_fn)(params)

  # Apply gradients to trainable params using SGD.
  def apply_grad(param, gradient, trainable):
    if not trainable: return param
    return param - learning_rate * gradient

  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)

  return params, loss

# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
  outputs = []
  while True:
    # Construct a list of examples in the batch.
    examples = []
    try:
      for _ in range(batch_size):
        examples.append(next(data_iterator))
        examples[-1]["_mask"] = np.array(True)  # Indicates true example.
    except StopIteration:
      if len(examples) == 0:
        return outputs

    # Not enough examples to complete a batch. Pad by repeating last example.
    while len(examples) % batch_size:
      examples.append(dict(examples[-1]))
      examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Make model predictions
    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch model predictions to device and detokenize.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    tokens = tokens[mask]  # remove padding examples.
    responses = [postprocess_tokens(t) for t in tokens]

    # Append to html output.
    for example, response in zip(examples, responses):
      outputs.append((example["image"], response))
      if num_examples and len(outputs) >= num_examples:
        return outputs

মডেলটি টিউন করুন

এখন আপনি সবকিছু সেট আপ করে ফেলেছেন এবং প্রশিক্ষণের ডেটা দেখেছেন, অবশেষে মডেলটি টিউন করার সময় এসেছে। নীচের কোডটি 64টি ধাপের জন্য মডেলের জন্য প্রশিক্ষণ লুপ চালায় এবং প্রতিটি ধাপের জন্য শেখার হার (মুদ্রিত আউটপুটে lr ) এবং ক্ষতির হার প্রিন্ট করে।

প্রতি ১৬ ধাপে, মডেলটি প্রশিক্ষণের সেই ধাপে তার ভবিষ্যদ্বাণীগুলি প্রিন্ট করে। এই কোডটি একই সেটের চিত্রের জন্য ভবিষ্যদ্বাণীগুলি প্রিন্ট করে যাতে আপনি দেখতে পারেন যে সময়ের সাথে সাথে মডেলের বর্ণনাগুলির ভবিষ্যদ্বাণী করার ক্ষমতা উন্নত হয়েছে।

প্রশিক্ষণের প্রথম ধাপগুলিতে, বর্ণনার সাথে সমস্যা হতে পারে, যেমন মডেলটি তার ভবিষ্যদ্বাণীমূলক লুপে আটকে গেলে বা অসম্পূর্ণ বাক্যে আটকে গেলে বারবার বাক্য বলা। প্রশিক্ষণের অগ্রগতির সাথে সাথে মডেলের ভবিষ্যদ্বাণীগুলি ধীরে ধীরে আরও নির্ভুল হয়ে ওঠে। ধাপ 64 অনুসারে, মডেলের ভবিষ্যদ্বাণীগুলি প্রশিক্ষণের তথ্য দ্বারা প্রদত্ত বর্ণনার সাথে ঘনিষ্ঠভাবে সাদৃশ্যপূর্ণ হওয়া উচিত।

T4 TPU-তে এই প্রক্রিয়াটি সম্পন্ন হতে প্রায় 15 মিনিট সময় লাগে।

# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.
#
%%time

BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03

TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4

train_data_it = train_data_iterator()

sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)

for step in range(1, TRAIN_STEPS+1):
  # Make list of N training examples.
  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]

  # Convert list of examples into a dict of np.arrays and load onto devices.
  batch = jax.tree.map(lambda *x: np.stack(x), *examples)
  batch = big_vision.utils.reshard(batch, data_sharding)

  # Training step and report training loss
  learning_rate = sched_fn(step)
  params, loss = update_fn(params, batch, learning_rate)

  loss = jax.device_get(loss)
  print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}")

  if (step % EVAL_STEPS) == 0:
    print(f"Model predictions at step {step}")
    html_out = ""
    for image, caption in make_predictions(
        validation_data_iterator(), num_examples=4, batch_size=4):
      html_out += render_example(image, caption)
    display(HTML(html_out))
step:  1/64   lr: 0.00500   loss: 3.6567
step:  2/64   lr: 0.01000   loss: 1.9762
step:  3/64   lr: 0.01500   loss: 1.6299
step:  4/64   lr: 0.02000   loss: 1.5651
step:  5/64   lr: 0.02500   loss: 1.9813
step:  6/64   lr: 0.03000   loss: 1.9996
step:  7/64   lr: 0.02998   loss: 1.8595
step:  8/64   lr: 0.02992   loss: 1.6479
step:  9/64   lr: 0.02981   loss: 1.3693
step: 10/64   lr: 0.02966   loss: 1.3423
step: 11/64   lr: 0.02947   loss: 1.2122
step: 12/64   lr: 0.02924   loss: 1.0602
step: 13/64   lr: 0.02897   loss: 1.1314
step: 14/64   lr: 0.02866   loss: 1.2612
step: 15/64   lr: 0.02831   loss: 1.0132
step: 16/64   lr: 0.02792   loss: 1.2126
Model predictions at step 16
step: 17/64   lr: 0.02750   loss: 1.0986
step: 18/64   lr: 0.02704   loss: 0.9461
step: 19/64   lr: 0.02655   loss: 1.2098
step: 20/64   lr: 0.02602   loss: 1.0513
step: 21/64   lr: 0.02546   loss: 1.0979
step: 22/64   lr: 0.02488   loss: 0.9739
step: 23/64   lr: 0.02426   loss: 0.9589
step: 24/64   lr: 0.02362   loss: 0.7053
step: 25/64   lr: 0.02296   loss: 0.7347
step: 26/64   lr: 0.02227   loss: 0.6990
step: 27/64   lr: 0.02156   loss: 0.6736
step: 28/64   lr: 0.02083   loss: 0.6642
step: 29/64   lr: 0.02009   loss: 0.6908
step: 30/64   lr: 0.01933   loss: 0.7257
step: 31/64   lr: 0.01856   loss: 0.6902
step: 32/64   lr: 0.01778   loss: 0.7054
Model predictions at step 32
step: 33/64   lr: 0.01699   loss: 0.7709
step: 34/64   lr: 0.01620   loss: 0.6653
step: 35/64   lr: 0.01540   loss: 0.3811
step: 36/64   lr: 0.01460   loss: 0.3104
step: 37/64   lr: 0.01380   loss: 0.4042
step: 38/64   lr: 0.01301   loss: 0.3904
step: 39/64   lr: 0.01222   loss: 0.3339
step: 40/64   lr: 0.01144   loss: 0.4156
step: 41/64   lr: 0.01067   loss: 0.4085
step: 42/64   lr: 0.00991   loss: 0.3083
step: 43/64   lr: 0.00917   loss: 0.3757
step: 44/64   lr: 0.00844   loss: 0.3813
step: 45/64   lr: 0.00773   loss: 0.3381
step: 46/64   lr: 0.00704   loss: 0.2057
step: 47/64   lr: 0.00638   loss: 0.1287
step: 48/64   lr: 0.00574   loss: 0.1711
Model predictions at step 48
step: 49/64   lr: 0.00512   loss: 0.1183
step: 50/64   lr: 0.00454   loss: 0.1154
step: 51/64   lr: 0.00398   loss: 0.1967
step: 52/64   lr: 0.00345   loss: 0.1497
step: 53/64   lr: 0.00296   loss: 0.1688
step: 54/64   lr: 0.00250   loss: 0.1878
step: 55/64   lr: 0.00208   loss: 0.1865
step: 56/64   lr: 0.00169   loss: 0.1655
step: 57/64   lr: 0.00134   loss: 0.0911
step: 58/64   lr: 0.00103   loss: 0.1836
step: 59/64   lr: 0.00076   loss: 0.1242
step: 60/64   lr: 0.00053   loss: 0.0814
step: 61/64   lr: 0.00034   loss: 0.0866
step: 62/64   lr: 0.00019   loss: 0.1295
step: 63/64   lr: 0.00008   loss: 0.1053
step: 64/64   lr: 0.00002   loss: 0.0730
Model predictions at step 64
CPU times: user 2min 18s, sys: 8.98 s, total: 2min 27s
Wall time: 15min 45s

আউটপুট

এই নোটবুকের বৈধতা তথ্যে মাত্র ১০টি ছবি রয়েছে। সাধারণ কোডে, যাচাইকরণের জন্য আপনার কাছে সম্ভবত আরও অনেক ডেটা পয়েন্ট থাকবে, তবে এই নোটবুকের জন্য, সমস্ত ১০টি ছবির জন্য বর্ণনা তৈরি করতে নিম্নলিখিত কোডটি চালান। মডেলটি টিউন করার পরে, এই বিবরণগুলি ফর্ম এবং কন্টেন্ট কভারেজের দিক থেকে এই নোটবুকে আগে দেখা প্রশিক্ষণ তথ্যের সাথে অন্তর্ভুক্ত বর্ণনার সাথে খুব মিল হওয়া উচিত।

যাচাইকরণ ডেটা সেটের জন্য বর্ণনা তৈরি করতে নীচের কোডটি চালান।

# The validation data consists of 10 images in a different domain than training
# data.
%%time

print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
  html_out += render_example(image, caption)
display(HTML(html_out))
Model predictions
CPU times: user 1.87 s, sys: 283 ms, total: 2.15 s
Wall time: 39.3 s