JAX এবং Flax ব্যবহার করে RecurrentGemma এর সাথে অনুমান

ai.google.dev-এ দেখুন Google Colab-এ চালান Vertex AI-তে খুলুন GitHub-এ উৎস দেখুন

এই টিউটোরিয়ালটি দেখায় যে কীভাবে JAX (একটি উচ্চ-পারফরম্যান্স সংখ্যাসূচক কম্পিউটিং লাইব্রেরি), ফ্ল্যাক্স (JAX-ভিত্তিক নিউরাল নেটওয়ার্ক লাইব্রেরি), Orbax (a) দিয়ে লেখা Google DeepMind-এর recurrentgemma লাইব্রেরি ব্যবহার করে RecurrentGemma 2B Instruct মডেলের সাথে মৌলিক নমুনা/অনুমান সম্পাদন করতে হয়। চেকপয়েন্টিংয়ের মতো প্রশিক্ষণের জন্য JAX-ভিত্তিক লাইব্রেরি, এবং SentencePiece (একটি টোকেনাইজার/ডিটোকেনাইজার লাইব্রেরি)। যদিও এই নোটবুকে ফ্ল্যাক্স সরাসরি ব্যবহার করা হয়নি, ফ্ল্যাক্স জেমা এবং রিকারেন্টজেমা (গ্রিফিন মডেল) তৈরি করতে ব্যবহৃত হয়েছিল।

এই নোটবুকটি T4 GPU সহ Google Colab-এ চলতে পারে ( এডিট > নোটবুক সেটিংস > হার্ডওয়্যার এক্সিলারেটরের অধীনে T4 GPU নির্বাচন করুন)।

সেটআপ

নিম্নলিখিত বিভাগগুলি একটি RecurrentGemma মডেল ব্যবহার করার জন্য একটি নোটবুক প্রস্তুত করার পদক্ষেপগুলি ব্যাখ্যা করে, যার মধ্যে মডেল অ্যাক্সেস, একটি API কী পাওয়া এবং নোটবুক রানটাইম কনফিগার করা

জেমার জন্য কাগল অ্যাক্সেস সেট আপ করুন

এই টিউটোরিয়ালটি সম্পূর্ণ করতে, আপনাকে প্রথমে কয়েকটি ব্যতিক্রম সহ জেমা সেটআপের অনুরূপ সেটআপ নির্দেশাবলী অনুসরণ করতে হবে:

  • kaggle.com- এ RecurrentGemma (Gemma এর পরিবর্তে) অ্যাক্সেস পান।
  • RecurrentGemma মডেল চালানোর জন্য পর্যাপ্ত সম্পদ সহ একটি Colab রানটাইম বেছে নিন।
  • একটি Kaggle ব্যবহারকারীর নাম এবং API কী তৈরি এবং কনফিগার করুন।

আপনি RecurrentGemma সেটআপ সম্পূর্ণ করার পরে, পরবর্তী বিভাগে যান, যেখানে আপনি আপনার Colab পরিবেশের জন্য পরিবেশের ভেরিয়েবল সেট করবেন।

পরিবেশের ভেরিয়েবল সেট করুন

KAGGLE_USERNAME এবং KAGGLE_KEY জন্য পরিবেশ ভেরিয়েবল সেট করুন। যখন "অ্যাক্সেস মঞ্জুর করুন?" বার্তা, গোপন অ্যাক্সেস প্রদান করতে সম্মত হন।

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

recurrentgemma লাইব্রেরি ইনস্টল করুন

এই নোটবুকটি একটি ফ্রি Colab GPU ব্যবহার করার উপর ফোকাস করে। হার্ডওয়্যার ত্বরণ সক্ষম করতে, সম্পাদনা > নোটবুক সেটিংস > T4 GPU নির্বাচন করুন > সংরক্ষণ করুন এ ক্লিক করুন।

এরপরে, আপনাকে github.com/google-deepmind/recurrentgemma থেকে Google DeepMind recurrentgemma লাইব্রেরি ইনস্টল করতে হবে। আপনি যদি "পিপের নির্ভরতা সমাধানকারী" সম্পর্কে একটি ত্রুটি পান তবে আপনি সাধারণত এটি উপেক্ষা করতে পারেন।

pip install git+https://github.com/google-deepmind/recurrentgemma.git

RecurrentGemma মডেল লোড করুন এবং প্রস্তুত করুন

  1. kagglehub.model_download দিয়ে RecurrentGemma মডেল লোড করুন, যা তিনটি আর্গুমেন্ট নেয়:
  • handle : কাগল থেকে মডেল হ্যান্ডেল
  • path : (ঐচ্ছিক স্ট্রিং) স্থানীয় পথ
  • force_download : (ঐচ্ছিক বুলিয়ান) মডেলটিকে পুনরায় ডাউনলোড করতে বাধ্য করে
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. মডেল ওজন এবং টোকেনাইজারের অবস্থান পরীক্ষা করুন, তারপর পাথ ভেরিয়েবল সেট করুন। টোকেনাইজার ডিরেক্টরিটি মূল ডিরেক্টরিতে থাকবে যেখানে আপনি মডেলটি ডাউনলোড করেছেন, যখন মডেলের ওজনগুলি একটি সাব-ডিরেক্টরিতে থাকবে। যেমন:
  • tokenizer.model ফাইলটি /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1 ) এ থাকবে।
  • মডেল চেকপয়েন্টটি হবে /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it )।
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

নমুনা/অনুমান সম্পাদন করুন

  1. recurrentgemma.jax.load_parameters পদ্ধতির সাথে RecurrentGemma মডেল চেকপয়েন্ট লোড করুন। "single_device" তে সেট করা sharding আর্গুমেন্ট একটি একক ডিভাইসে সমস্ত মডেল প্যারামিটার লোড করে।
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. RecurrentGemma মডেল টোকেনাইজার লোড করুন, sentencepiece.SentencePieceProcessor ব্যবহার করে নির্মিত।SentencePieceProcessor :
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. RecurrentGemma মডেল চেকপয়েন্ট থেকে স্বয়ংক্রিয়ভাবে সঠিক কনফিগারেশন লোড করতে, recurrentgemma.GriffinConfig.from_flax_params_or_variables ব্যবহার করুন। তারপর, recurrentgemma.jax.Griffin দিয়ে Griffin মডেলটি ইনস্ট্যান্টিয়েট করুন।
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. recurrentgemma.jax.Sampler দিয়ে RecurrentGemma মডেলের চেকপয়েন্ট/ওজন এবং টোকেনাইজারের উপরে একটি sampler তৈরি করুন:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. prompt একটি প্রম্পট লিখুন এবং অনুমান সম্পাদন করুন। আপনি total_generation_steps tweak করতে পারেন (একটি প্রতিক্রিয়া তৈরি করার সময় সঞ্চালিত পদক্ষেপের সংখ্যা — এই উদাহরণটি হোস্ট মেমরি সংরক্ষণ করতে 50 ব্যবহার করে)।
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

আরও জানুন