| | গুগল কোলাবে চালান | | | GitHub-এ উৎস দেখুন |
এই নোটবুকটি JAX ব্যবহার করে একটি ভিশন-ল্যাঙ্গুয়েজ টাস্কে PaliGemma-কে কীভাবে সূক্ষ্ম-টিউন করতে হয় তা দেখায়। ফাইন-টিউনিং এমন একটি প্রক্রিয়া যা নির্দিষ্ট কাজে আপনার মডেলের কর্মক্ষমতা উন্নত করতে পারে অথবা যখন নির্দেশনা পর্যাপ্ত না হয় এবং আপনার কাছে এমন কিছু উদাহরণ থাকে যা আপনার পছন্দসই আউটপুট প্রদর্শন করে তখন মডেলটিকে নির্দিষ্ট আউটপুট প্রয়োজনীয়তা মেনে চলতে সাহায্য করে। PaliGemma-এর মতো Gemma-ভিত্তিক মডেলগুলিতে প্রত্যাশিত ফলাফল তৈরি করতে ফাইন-টিউনিং প্রয়োজন।
এই নোটবুকে কী আছে?
এই নোটবুকটি big_vision থেকে মডেল রেফারেন্স বাস্তবায়ন ব্যবহার করে এবং দেখায় কিভাবে:
- নির্ভরতা ইনস্টল করুন, এবং PaliGemma মডেল চেকপয়েন্ট এবং প্রশিক্ষণ ডেটা ডাউনলোড করুন
- মডেলটি GPU ডিভাইসে লোড করুন
- প্রশিক্ষণ এবং অনুমানের জন্য মডেলের ইনপুট প্রস্তুত করুন।
- মডেলটি সূক্ষ্ম করুন
- আউটপুট পরীক্ষা করুন
এই নোটবুকের প্রশিক্ষণ ডেটাতে 90 জোড়া ছবি এবং লম্বা ক্যাপশন রয়েছে যা সেগুলি বর্ণনা করে। এটিকে T4 কোল্যাব রানটাইমে চালানোর জন্য, আপনাকে কেবল ভাষা মডেলের মনোযোগ স্তরগুলিকে সূক্ষ্মভাবে সুরক্ষিত করতে হবে এবং অন্যান্য পরামিতিগুলিকে স্থির করতে হবে।
এই উদাহরণটি শুধুমাত্র শেখার উদ্দেশ্যে। বাস্তব ব্যবহারের ক্ষেত্রে, ডেটার পরিমাণ, প্রশিক্ষণযোগ্য পরামিতি, প্রশিক্ষণের ধাপ এবং হাইপার-পরামিতি এবং প্রাপ্ত ফলাফল উল্লেখযোগ্যভাবে ভিন্ন হতে পারে।
শুরু করার আগে
এই নোটবুকটি পড়ার আগে, আপনার পাইথন কোডের সাথে পরিচিত হওয়া উচিত, সেইসাথে বৃহৎ ভাষা মডেল (LLM) কত প্রশিক্ষিত হয়। আপনার JAX এর সাথে পরিচিত হওয়ার প্রয়োজন নেই, তবে JAX (অথবা Keras এর মতো অনুরূপ প্রযুক্তি) সম্পর্কে প্রাথমিক জ্ঞান উদাহরণ কোডটি পড়ার সময় সহায়ক।
সেটআপ
নিম্নলিখিত বিভাগগুলিতে PaliGemma মডেল ব্যবহার করার জন্য একটি নোটবুক পাওয়ার প্রাথমিক পদক্ষেপগুলি ব্যাখ্যা করা হয়েছে, যার মধ্যে রয়েছে মডেল অ্যাক্সেস, একটি API কী পাওয়া এবং নোটবুক রানটাইম কনফিগার করা।
PaliGemma-তে অ্যাক্সেস পান
প্রথমবার PaliGemma ব্যবহার করার আগে, আপনাকে নিম্নলিখিত ধাপগুলি সম্পন্ন করে Kaggle-এর মাধ্যমে মডেলটিতে অ্যাক্সেসের জন্য অনুরোধ করতে হবে:
- Kaggle- এ লগ ইন করুন, অথবা যদি আপনার ইতিমধ্যেই Kaggle অ্যাকাউন্ট না থাকে তবে একটি নতুন Kaggle অ্যাকাউন্ট তৈরি করুন।
- PaliGemma মডেল কার্ডে যান এবং অ্যাক্সেসের অনুরোধ করুন এ ক্লিক করুন।
- সম্মতি ফর্মটি পূরণ করুন এবং শর্তাবলী গ্রহণ করুন।
আপনার API কী কনফিগার করুন
PaliGemma ব্যবহার করার জন্য, আপনাকে আপনার Kaggle ব্যবহারকারীর নাম এবং একটি Kaggle API কী প্রদান করতে হবে।
একটি Kaggle API কী তৈরি করতে, Kaggle-এ আপনার সেটিংস পৃষ্ঠাটি খুলুন এবং Create New Token এ ক্লিক করুন। এটি আপনার API শংসাপত্র ধারণকারী একটি kaggle.json ফাইল ডাউনলোড শুরু করে।
তারপর, Colab-এ, বাম দিকের প্যানে Secrets (🔑) নির্বাচন করুন এবং আপনার Kaggle ব্যবহারকারীর নাম এবং Kaggle API কী যোগ করুন। আপনার ব্যবহারকারীর নাম KAGGLE_USERNAME নামে এবং আপনার API কী KAGGLE_KEY নামে সংরক্ষণ করুন।
রানটাইম নির্বাচন করুন
এই টিউটোরিয়ালটি সম্পূর্ণ করার জন্য, আপনার PaliGemma মডেলটি চালানোর জন্য পর্যাপ্ত রিসোর্স সহ একটি Colab রানটাইম থাকতে হবে। এই ক্ষেত্রে, আপনি একটি T4 GPU ব্যবহার করতে পারেন:
- Colab উইন্ডোর উপরের ডানদিকে, ▾ (অতিরিক্ত সংযোগ বিকল্প) ড্রপডাউন মেনুতে ক্লিক করুন।
- রানটাইম টাইপ পরিবর্তন করুন নির্বাচন করুন।
- হার্ডওয়্যার অ্যাক্সিলারেটরের অধীনে, T4 GPU নির্বাচন করুন।
পাইথন প্যাকেজ ইনস্টল করুন
KaggleHub ইনস্টল করতে নিচের সেলটি চালান।
pip install -U -q kagglehubপরিবেশের ভেরিয়েবল সেট করুন
পরিবেশ ভেরিয়েবল এবং Kaggle লগইন সেট করুন।
import os
# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid out-of-memory due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
# This will prompt you to enter your Kaggle API token
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle… Kaggle credentials set. Kaggle credentials successfully validated.
big_vision সংগ্রহস্থলটি আনুন এবং সম্পর্কিত নির্ভরতা ইনস্টল করুন।
GitHub থেকে আপনার Colab নোটবুকে big_vision রিপোজিটরিটি ডাউনলোড করুন এবং নিম্নলিখিত কোডটি চালিয়ে big_vision সম্পর্কিত নির্ভরতা ইনস্টল করুন।
import os
import sys
# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
raise "It seems you are using Colab with remote TPUs which is not supported."
# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
!git clone --quiet --branch=main --depth=1 \
https://github.com/google-research/big_vision big_vision_repo
# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
sys.path.append("big_vision_repo")
# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 76.7/76.7 kB 2.8 MB/s eta 0:00:00
JAX এবং অন্যান্য নির্ভরতা আমদানি করুন
PaliGemma-এর জন্য প্রয়োজনীয় JAX এবং অন্যান্য নির্ভরতা, যেমন TensorFlow এবং NumPy, আমদানি করুন।
import base64
import functools
import html
import io
import os
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from IPython.core.display import display, HTML
from PIL import Image
# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
backend = jax.extend.backend.get_backend()
print(f"JAX version: {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices: {jax.device_count()}")
JAX version: 0.7.2 JAX platform: gpu JAX devices: 1
মডেলটি ডাউনলোড এবং কনফিগার করুন
এই ধাপে, আপনি মডেল চেকপয়েন্টটি ডাউনলোড করবেন এবং এটিকে কনফিগার করবেন যাতে আপনি পরে এটিকে সূক্ষ্ম-টিউন করতে পারেন। এই ধাপে আপনাকে দেখানো হবে কিভাবে মডেল প্যারামিটারগুলিকে TPU মেমোরিতে স্থানান্তর করতে হয়, যা সীমিত সংস্থান সহ ডিভাইসগুলিতে মডেলগুলিকে সূক্ষ্ম-টিউন করার জন্য কার্যকর।
মডেল চেকপয়েন্টটি ডাউনলোড করুন
PaliGemma মডেলের বেশ কিছু বৈচিত্র্য রয়েছে। এই টিউটোরিয়ালের জন্য, আপনি বেস JAX/FLAX PaliGemma 3B ওজন মডেল ব্যবহার করবেন।
নিম্নলিখিত কোডটি রান করে Kaggle থেকে মডেল চেকপয়েন্টটি ডাউনলোড করুন। এই প্রক্রিয়াটি সম্পন্ন হতে কয়েক মিনিট সময় লাগে।
import os
import kagglehub
# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224" # Path to fetch from Kaggle.
# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"
if not os.path.exists(MODEL_PATH):
print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
print(f"Model path: {MODEL_PATH}")
TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
print("Downloading the model tokenizer...")
!gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
print(f"Tokenizer path: {TOKENIZER_PATH}")
DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
print("Downloading the dataset...")
!gsutil -m -q cp -n -r gs://longcap100/ .
print(f"Data path: {DATA_DIR}")
Downloading the checkpoint from Kaggle, this could take a few minutes.... Downloading to /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz... 100%|██████████| 5.65G/5.65G [00:54<00:00, 112MB/s] Model path: /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz Downloading the model tokenizer... Copying gs://big_vision/paligemma_tokenizer.model... - [1 files][ 4.1 MiB/ 4.1 MiB] Operation completed over 1 objects/4.1 MiB. Tokenizer path: ./paligemma_tokenizer.model Downloading the dataset... Data path: ./longcap100
মডেলটি কনফিগার করুন
আপনি যে মডেলটি ব্যবহার করতে যাচ্ছেন তা আসলে কনফিগার করা শুরু করার সময় এসেছে।
এই নোটবুকের জন্য, আপনার মডেলটি একটি T4 GPU-তে ফিট করতে সক্ষম হতে হবে। সীমিত রিসোর্স যেমন স্থানের সীমাবদ্ধতা থাকার অর্থ হল আপনার মডেলটি কীভাবে কনফিগার করা হয়েছে সে সম্পর্কে আপনাকে সচেতন থাকতে হবে।
যদি আপনি প্রতিটি প্যারামিটার সূক্ষ্মভাবে সুরক্ষিত করেন, তাহলে আপনার মডেলটি নোটবুক পরিবেশে চলতে সক্ষম হবে না। ফলস্বরূপ, নোটবুকের এই অংশে, আপনি আপনার মডেলটিকে এমনভাবে কনফিগার করবেন যাতে এটি কিছু প্যারামিটার হিমায়িত করতে পারে এবং কেবলমাত্র সেই প্যারামিটারগুলিকেই সূক্ষ্মভাবে সুরক্ষিত করতে পারে যেগুলি মডেলটির জন্য আপনাকে সঠিক ফলাফল দেওয়ার জন্য সত্যিই সূক্ষ্মভাবে সুরক্ষিত করা প্রয়োজন। LLM-তে, প্যারামিটারগুলিকে হিমায়িত বলা হয় যখন সেগুলি মডেলটিকে প্রশিক্ষণ দেওয়ার জন্য সক্রিয়ভাবে ব্যবহৃত হয় না।
আপনার মডেলটি কনফিগার করার জন্য, আপনাকে যা করতে হবে:
-
model_configকেFrozenConfigDictহিসেবে আরম্ভ করুন যাতে আপনি কিছু প্যারামিটার ফ্রিজ করতে পারেন এবং মেমোরির ব্যবহার কম রাখতে পারেন। - PaliGemma
Modelক্লাসের একটি ইনস্ট্যান্স শুরু করুন,model_configএর কনফিগারেশন হিসেবে ব্যবহার করে। - মডেল প্যারামিটারগুলি RAM-তে লোড করুন
- মডেল থেকে নমুনা আউটপুটগুলির জন্য একটি
decodeফাংশন সংজ্ঞায়িত করুন
এই কক্ষের এই কোডটি সম্পূর্ণ হতে প্রায় এক মিনিট সময় নেয়।
# Define model
# IMPORTANT: Gemma-2 has a "final_logits_softcap" property. Set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
"llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
"img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)
# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)
# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())
মডেল প্যারামিটারগুলিকে GPU/TPU মেমরিতে সরান
এখন আপনাকে মডেল প্যারামিটারগুলিকে GPU/TPU মেমোরিতে স্থানান্তর করতে হবে। প্রথমে, উপলব্ধ GPU গুলিতে প্যারামিটারগুলিকে খণ্ডন করুন, তারপর প্যারামিটারগুলি লোড করুন। এখানে, আপনি ক্রমানুসারে প্যারামিটারগুলি লোড করবেন। এই প্রক্রিয়াটি একসাথে লোড করার চেয়ে বেশি সময় নেয়, তবে এর জন্য এই নোটবুকে উপলব্ধ RAM এর চেয়ে বেশি RAM প্রয়োজন।
অবশেষে, প্রতিটি প্যারামিটার কোন ধরণের কাস্ট করা হয়েছে তা দেখার জন্য সমস্ত প্যারামিটার প্রিন্ট করুন। Frozen প্যারামিটারগুলি float16 হিসাবে রাখা হয়, যখন trainable প্যারামিটারগুলি float32 এ কাস্ট করা হয়। যখন আপনি তালিকাটি পরিদর্শন করবেন, তখন আপনি দেখতে পাবেন যে বেশিরভাগ প্যারামিটারগুলি হিমায়িত হয়ে গেছে এবং float16 ।
# Create a pytree mask of the trainable params.
def is_trainable_param(name, param): # pylint: disable=unused-argument
if name.startswith("llm/layers/attn/"): return True
if name.startswith("llm/"): return False
if name.startswith("img/"): return False
raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data"))
data_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("data"))
params_sharding = big_vision.sharding.infer_sharding(
params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)
# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
"ignore", message="Some donated buffers were not usable")
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
# Cast others to float16, since some GPUs don't support bf16.
return jax.tree.map(lambda p, m: p.astype(jnp.float32)
if m else p.astype(jnp.float16),
params, trainable)
# Loading all params in simultaneous - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default.
# Instead, do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
params[idx] = big_vision.utils.reshard(params[idx], sharding)
params[idx] = maybe_cast_to_f32(params[idx], trainable)
params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)
# Print params to show what the model is made of.
def parameter_overview(params):
for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")
print(" == Model params == ")
parameter_overview(params)
== Model params == img/Transformer/encoder_norm/bias (1152,) float16 img/Transformer/encoder_norm/scale (1152,) float16 img/Transformer/encoderblock/LayerNorm_0/bias (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_0/scale (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_1/bias (27, 1152) float16 img/Transformer/encoderblock/LayerNorm_1/scale (27, 1152) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias (27, 4304) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel (27, 1152, 4304) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias (27, 1152) float16 img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel (27, 4304, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel (27, 1152, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias (27, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel (27, 16, 72, 1152) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel (27, 1152, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias (27, 16, 72) float16 img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel (27, 1152, 16, 72) float16 img/embedding/bias (1152,) float16 img/embedding/kernel (14, 14, 3, 1152) float16 img/head/bias (2304,) float16 img/head/kernel (1152, 2304) float16 img/pos_embedding (1, 256, 1152) float16 llm/embedder/input_embedding (257152, 2304) float16 llm/final_norm/scale (2304,) float16 llm/layers/attn/attn_vec_einsum/w (26, 8, 256, 2304) float32 llm/layers/attn/kv_einsum/w (26, 2, 4, 2304, 256) float32 llm/layers/attn/q_einsum/w (26, 8, 2304, 256) float32 llm/layers/mlp/gating_einsum (26, 2, 2304, 9216) float16 llm/layers/mlp/linear (26, 9216, 2304) float16 llm/layers/post_attention_norm/scale (26, 2304) float16 llm/layers/post_ffw_norm/scale (26, 2304) float16 llm/layers/pre_attention_norm/scale (26, 2304) float16 llm/layers/pre_ffw_norm/scale (26, 2304) float16
মডেলটি সুর করার জন্য প্রস্তুত হোন
এখন যেহেতু আপনার মডেলটি কনফিগার করা হয়েছে, আপনি এটি টিউন করতে পারেন। এই ধাপে, আপনি আপনার মডেলের ইনপুটগুলির পাশাপাশি প্রশিক্ষণ এবং বৈধতা পুনরাবৃত্তিকারী তৈরি করবেন, প্রশিক্ষণের উদাহরণগুলি দেখবেন এবং প্রশিক্ষণ এবং বৈধতা লুপগুলি সংজ্ঞায়িত করবেন।
মডেল ইনপুট তৈরি করুন
আপনি যে মডেল চেকপয়েন্টটি ব্যবহার করছেন তা ইতিমধ্যেই বিভিন্ন আকৃতির অনুপাতের ছবিগুলির উপর প্রশিক্ষণপ্রাপ্ত হয়েছে যেগুলিকে 224x224 পিক্সেল আকারে পরিবর্তন করা হয়েছে, এবং টোকেনাইজড টেক্সটগুলি পরিচালনা করার জন্য।
নিচের কোডটি তিনটি ফাংশন সংজ্ঞায়িত করে যা আপনি পরবর্তী ধাপে মডেলের ইনপুট তৈরিতে ব্যবহার করবেন:
-
preprocess_image: ছবির ডেটা স্বাভাবিক করে তোলে। এই ক্ষেত্রে, প্রি-প্রসেসিং পাস-ইন করা ছবিটিকে গ্রেস্কেলে রূপান্তর করে, আলফা স্তরটি সরিয়ে দেয় এবং পাস-ইন করা ছবিটিকে মডেলের ইমেজ ইনপুটগুলির জন্য প্রয়োজনীয় আকারে (২২৪x২২৪ পিক্সেল) আকার দেয়। -
preprocess_tokens: টোকেনগুলিকে বিভক্ত করে এবং একটি টোকেন একটি প্রিফিক্স নাকি সাফিক্স টোকেন তা চিহ্নিত করার জন্য ফ্ল্যাগ যোগ করে। এই ফ্ল্যাগগুলি পরবর্তীতে কোডে, প্রশিক্ষণ ধাপ এবং মূল্যায়ন লুপের সময় ব্যবহার করা হবে। -
postprocess_tokens: সিকোয়েন্সের শেষের (EOS) টোকেনের সময় এবং/অথবা পরে অবশিষ্ট যেকোনো টোকেন সরিয়ে দেয় এবং অবশিষ্ট ডিকোড করা টোকেনগুলি ফেরত দেয়।
def preprocess_image(image, size=224):
# Model has been trained to handle images of different aspects ratios
# resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
# options are helpful to improve quality in some tasks.
image = np.asarray(image)
if image.ndim == 2: # Convert image without last channel into greyscale.
image = np.stack((image,)*3, axis=-1)
image = image[..., :3] # Remove alpha layer.
assert image.shape[-1] == 3
image = tf.constant(image)
image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1]
def preprocess_tokens(prefix, suffix=None, seqlen=None):
# Model has been trained to handle tokenized text composed of a prefix with
# full attention and a suffix with causal attention.
separator = "\n"
tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
mask_ar = [0] * len(tokens) # 0 to use full attention for prefix.
mask_loss = [0] * len(tokens) # 0 to not use prefix tokens in the loss.
if suffix:
suffix = tokenizer.encode(suffix, add_eos=True)
tokens += suffix
mask_ar += [1] * len(suffix) # 1 to use causal attention for suffix.
mask_loss += [1] * len(suffix) # 1 to use suffix tokens in the loss.
mask_input = [1] * len(tokens) # 1 if it's a token, 0 if padding.
if seqlen:
padding = [0] * max(0, seqlen - len(tokens))
tokens = tokens[:seqlen] + padding
mask_ar = mask_ar[:seqlen] + padding
mask_loss = mask_loss[:seqlen] + padding
mask_input = mask_input[:seqlen] + padding
return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))
def postprocess_tokens(tokens):
tokens = tokens.tolist() # np.array to list[int]
try: # Remove tokens at and after EOS if any.
eos_pos = tokens.index(tokenizer.eos_id())
tokens = tokens[:eos_pos]
except ValueError:
pass
return tokenizer.decode(tokens)
প্রশিক্ষণ এবং যাচাইকরণ পুনরাবৃত্তিকারী তৈরি করুন
দুটি ইটারেটর তৈরি করুন:
- একটি প্রশিক্ষণ ইটারেটর যা প্রশিক্ষণ প্রক্রিয়াটিকে একসাথে সমস্ত ডেটা প্রক্রিয়াকরণের পরিবর্তে খণ্ড খণ্ডে ডেটার মধ্য দিয়ে যেতে দেয়।
- এটি আপনাকে ব্যবহারের আগে কিছু ডেটা প্রি-প্রসেসিং করার অনুমতি দেয়।
- একটি ভ্যালিডেশন ইটারেটর যা প্রশিক্ষণ প্রক্রিয়াটিকে ভ্যালিডেশন ডেটাসেটের উপর পুনরাবৃত্তি করতে দেয় যাতে দেখা যায় যে টিউন করা মডেলটি প্রদত্ত ফলাফলের সাথে কতটা ভালোভাবে সামঞ্জস্যপূর্ণ।
SEQLEN = 128
train_dataset = big_vision.datasets.jsonl.DataSource(
os.path.join(DATA_DIR, "data_train90.jsonl"),
fopen_keys={"image": DATA_DIR})
val_dataset = big_vision.datasets.jsonl.DataSource(
os.path.join(DATA_DIR, "data_val10.jsonl"),
fopen_keys={"image": DATA_DIR})
def train_data_iterator():
"""Never ending iterator over training examples."""
# Shuffle examples and repeat so one can train for many epochs.
dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
for example in dataset.as_numpy_iterator():
image = Image.open(io.BytesIO(example["image"]))
image = preprocess_image(image)
prefix = "caption en" # Could also be a different prefix per example.
suffix = example["suffix"].decode().lower()
tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)
yield {
"image": np.asarray(image),
"text": np.asarray(tokens),
"mask_ar": np.asarray(mask_ar),
"mask_loss": np.asarray(mask_loss),
}
def validation_data_iterator():
"""Single iterator over validation examples."""
for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
image = Image.open(io.BytesIO(example["image"]))
image = preprocess_image(image)
prefix = "caption en" # Could also be a different prefix per example.
tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)
yield {
"image": np.asarray(image),
"text": np.asarray(tokens),
"mask_ar": np.asarray(mask_ar),
"mask_input": np.asarray(mask_input),
}
প্রশিক্ষণের উদাহরণ দেখুন
এই নোটবুকে, প্রশিক্ষণের তথ্যে 90টি ছবি রয়েছে যা ছবিতে যা দেখানো হয়েছে তার দীর্ঘ বর্ণনার সাথে যুক্ত করা হয়েছে।
নিচের কোডটি প্রশিক্ষণ ডেটা সেট থেকে বর্ণনা সহ ছবিগুলির একটি এলোমেলো নির্বাচন প্রিন্ট করে যাতে আপনি দেখতে পারেন যে আপনার মডেলটি যে ছবি এবং বর্ণনাগুলিতে প্রশিক্ষণ পেয়েছে তা কেমন দেখাচ্ছে। প্রতিটি ছবি 128x128 পিক্সেল JPEG আকারে প্রদর্শিত হয়, ডানদিকে ছবির পাশে বর্ণনা মুদ্রিত থাকে।
def render_inline(image, resize=(128, 128)):
"""Convert image into inline html."""
image = Image.fromarray(image)
image.resize(resize)
with io.BytesIO() as buffer:
image.save(buffer, format='jpeg')
image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
return f"data:image/jpeg;base64,{image_b64}"
def render_example(image, caption):
image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]
return f"""
<div style="display: inline-flex; align-items: center; justify-content: center;">
<img style="width:128px; height:128px;" src="{render_inline(image, resize=(64,64))}" />
<p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
</div>
"""
html_out = ""
for idx, example in zip(range(8), train_data_iterator()):
caption = postprocess_tokens(example["text"]) # detokenize model input.
caption = caption[len("caption en\n"):] # strip prefix
html_out += render_example(example["image"], caption)
print("Training examples")
display(HTML(html_out))
Training examples
প্রশিক্ষণ এবং মূল্যায়ন চক্রগুলি সংজ্ঞায়িত করুন
প্রদত্ত ডেটাসেটে মডেলকে প্রশিক্ষণ দেওয়ার জন্য প্রশিক্ষণ লুপ এবং যাচাইকরণ ডেটাসেটের সমস্ত উদাহরণ দেখার এবং এর ভবিষ্যদ্বাণী করার জন্য মূল্যায়ন লুপ সংজ্ঞায়িত করুন।
প্রশিক্ষণ লুপ সংজ্ঞায়িত করা
update_fn ফাংশনটি প্রশিক্ষণ ধাপটি সংজ্ঞায়িত করে। প্রশিক্ষণ ধাপের সময়, প্রতি উদাহরণের ক্ষতি গণনা করা হয় এবং প্রশিক্ষণযোগ্য পরামিতিগুলিতে স্টোকাস্টিক গ্রেডিয়েন্ট ডিসেন্ট (SGD) প্রয়োগ করা হয়।
মনে রাখবেন যে নোটবুকের আগে, আপনি preprocess_tokens ফাংশনে ফ্ল্যাগ অন্তর্ভুক্ত করেছিলেন যার মধ্যে mask_loss অন্তর্ভুক্ত ছিল। আপনি এখানে mask_loss ফ্ল্যাগ ব্যবহার করে ক্ষতি থেকে প্রিফিক্স এবং প্যাডেড টোকেন বাদ দেবেন। এটি ছাড়া, ক্ষতির হিসাব বিকৃত হবে। আপনাকে প্রতিটি উদাহরণকে স্বাভাবিক করতে হবে, কারণ প্রতিটিতে টোকেনের সংখ্যা আলাদা। প্রিফিক্স এবং প্যাডেড টোকেন বাদ দেওয়ার পরে এবং উদাহরণগুলি স্বাভাবিক করার পরে, আপনি প্রতিটি উদাহরণের ক্ষতি গণনা করতে পারেন।
প্রশিক্ষণ ধাপে প্রশিক্ষণকে অপ্টিমাইজ করার জন্য একটি SGD প্রয়োগ করার একটি ফাংশনও অন্তর্ভুক্ত রয়েছে।
মূল্যায়ন লুপ সংজ্ঞায়িত করা
make_predictions ফাংশনটি হল আপনার মূল্যায়ন লুপ। মূল্যায়ন লুপটি মোটামুটি সোজা, একটি উল্লেখযোগ্য পরিবর্তন সহ। যদি আপনি নোটবুকের শুরু থেকে মনে করেন, আপনার প্রশিক্ষণ ডেটা সেটে মাত্র 90টি উদাহরণ রয়েছে। এটি প্রশিক্ষণের উদাহরণের একটি খুব কম সংখ্যা, এবং প্রশিক্ষণ চালানোর সময় আপনার মডেলটিতে ব্যাচের আকারের জন্য পর্যাপ্ত উদাহরণ থাকে না। এর অর্থ হল মূল্যায়ন লুপে, আপনাকে উদাহরণগুলি পুনরাবৃত্তি করে ব্যাচটি প্যাড করতে হবে।
আপনার মূল্যায়ন লুপটি কেবল প্রকৃত উদাহরণগুলিকে গণনা করে এবং প্যাডেড উদাহরণগুলিকে নয় তা নিশ্চিত করার জন্য, আপনাকে প্যাডেড উদাহরণগুলিতে একটি মাস্ক প্রয়োগ করতে হবে যা আউটপুট থেকে তাদের বাদ দেবে।
# The main update_fn using a simple stochastic gradient descent (SGD).
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]
def loss_fn(params):
text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
logp = jax.nn.log_softmax(text_logits, axis=-1)
# The model takes as input txts[:, :-1] but the loss is defined as predicting
# next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
# are part of the loss (e.g. prefix and padded tokens are not included).
mask_loss = batch["mask_loss"][:, 1:]
targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])
# Compute the loss per example. i.e. the mean of per token pplx.
# Since each example has a different number of tokens, normalize it.
token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.
example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.
example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.
# batch_loss: mean of per example loss.
return jnp.mean(example_loss)
loss, grads = jax.value_and_grad(loss_fn)(params)
# Apply gradients to trainable params using SGD.
def apply_grad(param, gradient, trainable):
if not trainable: return param
return param - learning_rate * gradient
params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)
return params, loss
# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
batch_size=4, seqlen=SEQLEN, sampler="greedy"):
outputs = []
while True:
# Construct a list of examples in the batch.
examples = []
try:
for _ in range(batch_size):
examples.append(next(data_iterator))
examples[-1]["_mask"] = np.array(True) # Indicates true example.
except StopIteration:
if len(examples) == 0:
return outputs
# Not enough examples to complete a batch. Pad by repeating last example.
while len(examples) % batch_size:
examples.append(dict(examples[-1]))
examples[-1]["_mask"] = np.array(False) # Indicates padding example.
# Convert list of examples into a dict of np.arrays and load onto devices.
batch = jax.tree.map(lambda *x: np.stack(x), *examples)
batch = big_vision.utils.reshard(batch, data_sharding)
# Make model predictions
tokens = decode({"params": params}, batch=batch,
max_decode_len=seqlen, sampler=sampler)
# Fetch model predictions to device and detokenize.
tokens, mask = jax.device_get((tokens, batch["_mask"]))
tokens = tokens[mask] # remove padding examples.
responses = [postprocess_tokens(t) for t in tokens]
# Append to html output.
for example, response in zip(examples, responses):
outputs.append((example["image"], response))
if num_examples and len(outputs) >= num_examples:
return outputs
মডেলটি টিউন করুন
এখন আপনি সবকিছু সেট আপ করে ফেলেছেন এবং প্রশিক্ষণের ডেটা দেখেছেন, অবশেষে মডেলটি টিউন করার সময় এসেছে। নীচের কোডটি 64টি ধাপের জন্য মডেলের জন্য প্রশিক্ষণ লুপ চালায় এবং প্রতিটি ধাপের জন্য শেখার হার (মুদ্রিত আউটপুটে lr ) এবং ক্ষতির হার প্রিন্ট করে।
প্রতি ১৬ ধাপে, মডেলটি প্রশিক্ষণের সেই ধাপে তার ভবিষ্যদ্বাণীগুলি প্রিন্ট করে। এই কোডটি একই সেটের চিত্রের জন্য ভবিষ্যদ্বাণীগুলি প্রিন্ট করে যাতে আপনি দেখতে পারেন যে সময়ের সাথে সাথে মডেলের বর্ণনাগুলির ভবিষ্যদ্বাণী করার ক্ষমতা উন্নত হয়েছে।
প্রশিক্ষণের প্রথম ধাপগুলিতে, বর্ণনার সাথে সমস্যা হতে পারে, যেমন মডেলটি তার ভবিষ্যদ্বাণীমূলক লুপে আটকে গেলে বা অসম্পূর্ণ বাক্যে আটকে গেলে বারবার বাক্য বলা। প্রশিক্ষণের অগ্রগতির সাথে সাথে মডেলের ভবিষ্যদ্বাণীগুলি ধীরে ধীরে আরও নির্ভুল হয়ে ওঠে। ধাপ 64 অনুসারে, মডেলের ভবিষ্যদ্বাণীগুলি প্রশিক্ষণের তথ্য দ্বারা প্রদত্ত বর্ণনার সাথে ঘনিষ্ঠভাবে সাদৃশ্যপূর্ণ হওয়া উচিত।
T4 TPU-তে এই প্রক্রিয়াটি সম্পন্ন হতে প্রায় 15 মিনিট সময় লাগে।
# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.
#
%%time
BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03
TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4
train_data_it = train_data_iterator()
sched_fn = big_vision.utils.create_learning_rate_schedule(
total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
decay_type="cosine", warmup_percent=0.10)
for step in range(1, TRAIN_STEPS+1):
# Make list of N training examples.
examples = [next(train_data_it) for _ in range(BATCH_SIZE)]
# Convert list of examples into a dict of np.arrays and load onto devices.
batch = jax.tree.map(lambda *x: np.stack(x), *examples)
batch = big_vision.utils.reshard(batch, data_sharding)
# Training step and report training loss
learning_rate = sched_fn(step)
params, loss = update_fn(params, batch, learning_rate)
loss = jax.device_get(loss)
print(f"step: {step:2d}/{TRAIN_STEPS:2d} lr: {learning_rate:.5f} loss: {loss:.4f}")
if (step % EVAL_STEPS) == 0:
print(f"Model predictions at step {step}")
html_out = ""
for image, caption in make_predictions(
validation_data_iterator(), num_examples=4, batch_size=4):
html_out += render_example(image, caption)
display(HTML(html_out))
step: 1/64 lr: 0.00500 loss: 3.6567 step: 2/64 lr: 0.01000 loss: 1.9762 step: 3/64 lr: 0.01500 loss: 1.6299 step: 4/64 lr: 0.02000 loss: 1.5651 step: 5/64 lr: 0.02500 loss: 1.9813 step: 6/64 lr: 0.03000 loss: 1.9996 step: 7/64 lr: 0.02998 loss: 1.8595 step: 8/64 lr: 0.02992 loss: 1.6479 step: 9/64 lr: 0.02981 loss: 1.3693 step: 10/64 lr: 0.02966 loss: 1.3423 step: 11/64 lr: 0.02947 loss: 1.2122 step: 12/64 lr: 0.02924 loss: 1.0602 step: 13/64 lr: 0.02897 loss: 1.1314 step: 14/64 lr: 0.02866 loss: 1.2612 step: 15/64 lr: 0.02831 loss: 1.0132 step: 16/64 lr: 0.02792 loss: 1.2126 Model predictions at step 16
step: 17/64 lr: 0.02750 loss: 1.0986 step: 18/64 lr: 0.02704 loss: 0.9461 step: 19/64 lr: 0.02655 loss: 1.2098 step: 20/64 lr: 0.02602 loss: 1.0513 step: 21/64 lr: 0.02546 loss: 1.0979 step: 22/64 lr: 0.02488 loss: 0.9739 step: 23/64 lr: 0.02426 loss: 0.9589 step: 24/64 lr: 0.02362 loss: 0.7053 step: 25/64 lr: 0.02296 loss: 0.7347 step: 26/64 lr: 0.02227 loss: 0.6990 step: 27/64 lr: 0.02156 loss: 0.6736 step: 28/64 lr: 0.02083 loss: 0.6642 step: 29/64 lr: 0.02009 loss: 0.6908 step: 30/64 lr: 0.01933 loss: 0.7257 step: 31/64 lr: 0.01856 loss: 0.6902 step: 32/64 lr: 0.01778 loss: 0.7054 Model predictions at step 32
step: 33/64 lr: 0.01699 loss: 0.7709 step: 34/64 lr: 0.01620 loss: 0.6653 step: 35/64 lr: 0.01540 loss: 0.3811 step: 36/64 lr: 0.01460 loss: 0.3104 step: 37/64 lr: 0.01380 loss: 0.4042 step: 38/64 lr: 0.01301 loss: 0.3904 step: 39/64 lr: 0.01222 loss: 0.3339 step: 40/64 lr: 0.01144 loss: 0.4156 step: 41/64 lr: 0.01067 loss: 0.4085 step: 42/64 lr: 0.00991 loss: 0.3083 step: 43/64 lr: 0.00917 loss: 0.3757 step: 44/64 lr: 0.00844 loss: 0.3813 step: 45/64 lr: 0.00773 loss: 0.3381 step: 46/64 lr: 0.00704 loss: 0.2057 step: 47/64 lr: 0.00638 loss: 0.1287 step: 48/64 lr: 0.00574 loss: 0.1711 Model predictions at step 48
step: 49/64 lr: 0.00512 loss: 0.1183 step: 50/64 lr: 0.00454 loss: 0.1154 step: 51/64 lr: 0.00398 loss: 0.1967 step: 52/64 lr: 0.00345 loss: 0.1497 step: 53/64 lr: 0.00296 loss: 0.1688 step: 54/64 lr: 0.00250 loss: 0.1878 step: 55/64 lr: 0.00208 loss: 0.1865 step: 56/64 lr: 0.00169 loss: 0.1655 step: 57/64 lr: 0.00134 loss: 0.0911 step: 58/64 lr: 0.00103 loss: 0.1836 step: 59/64 lr: 0.00076 loss: 0.1242 step: 60/64 lr: 0.00053 loss: 0.0814 step: 61/64 lr: 0.00034 loss: 0.0866 step: 62/64 lr: 0.00019 loss: 0.1295 step: 63/64 lr: 0.00008 loss: 0.1053 step: 64/64 lr: 0.00002 loss: 0.0730 Model predictions at step 64
CPU times: user 2min 18s, sys: 8.98 s, total: 2min 27s Wall time: 15min 45s
আউটপুট
এই নোটবুকের বৈধতা তথ্যে মাত্র ১০টি ছবি রয়েছে। সাধারণ কোডে, যাচাইকরণের জন্য আপনার কাছে সম্ভবত আরও অনেক ডেটা পয়েন্ট থাকবে, তবে এই নোটবুকের জন্য, সমস্ত ১০টি ছবির জন্য বর্ণনা তৈরি করতে নিম্নলিখিত কোডটি চালান। মডেলটি টিউন করার পরে, এই বিবরণগুলি ফর্ম এবং কন্টেন্ট কভারেজের দিক থেকে এই নোটবুকে আগে দেখা প্রশিক্ষণ তথ্যের সাথে অন্তর্ভুক্ত বর্ণনার সাথে খুব মিল হওয়া উচিত।
যাচাইকরণ ডেটা সেটের জন্য বর্ণনা তৈরি করতে নীচের কোডটি চালান।
# The validation data consists of 10 images in a different domain than training
# data.
%%time
print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
html_out += render_example(image, caption)
display(HTML(html_out))
Model predictions
CPU times: user 1.87 s, sys: 283 ms, total: 2.15 s Wall time: 39.3 s
গুগল কোলাবে চালান
GitHub-এ উৎস দেখুন