JAX और Flax का इस्तेमाल करके, RecurrentGemma का अनुमान

ai.google.dev पर देखें Google Colab में चलाएं Vertex AI में खोलें GitHub पर सोर्स देखें

यह ट्यूटोरियल, Google DeepMind की recurrentgemma लाइब्रेरी का इस्तेमाल करके, RecurrentGemma 2B इंस्ट्रक्शन मॉडल की मदद से बेसिक सैंपलिंग/अनुमान लगाने का तरीका बताता है. यह मॉडल JAX (एक बेहतर परफ़ॉर्मेंस वाली कंप्यूटिंग लाइब्रेरी), Flax (JAX-आधारित न्यूरल नेटवर्क लाइब्रेरी), Orbax (एक JAX-आधारित चेकिंग लाइब्रेरी) में लिखा गया है.SentencePiece हालांकि, इस notebook में Flax का इस्तेमाल सीधे तौर पर नहीं किया गया है. हालांकि, Gemma और RecurrentGemma (Giffin मॉडल) को बनाने के लिए, Flax का इस्तेमाल किया गया था.

इस notebook को Google Colab पर T4 जीपीयू के साथ चलाया जा सकता है (बदलाव करें > नोटबुक की सेटिंग पर जाएं > हार्डवेयर ऐक्सेलरेटर में जाकर, T4 जीपीयू चुनें.

सेटअप

नीचे दिए गए सेक्शन में, RecurrentGemma मॉडल का इस्तेमाल करने के लिए नोटबुक को तैयार करने का तरीका बताया गया है. इसमें मॉडल का ऐक्सेस, एपीआई पासकोड पाना, और notebook के रनटाइम को कॉन्फ़िगर करना शामिल है

Gemma के लिए Kaggle का ऐक्सेस सेट अप करें

इस ट्यूटोरियल को पूरा करने के लिए, आपको सबसे पहले कुछ अपवादों के साथ Gemma सेटअप से मिलते-जुलते सेटअप के निर्देशों का पालन करना होगा:

  • kaggle.com पर Gemma के बजाय RecurrentGemma का ऐक्सेस पाएं.
  • RecurrentGemma मॉडल को चलाने के लिए, ऐसे Colab रनटाइम चुनें जिसमें ज़रूरत के मुताबिक संसाधन हों.
  • Kaggle उपयोगकर्ता नाम और एपीआई पासकोड को जनरेट और कॉन्फ़िगर करें.

RecurrentGemma का सेटअप पूरा करने के बाद, अगले सेक्शन पर जाएं. यहां अपने Colab के एनवायरमेंट के लिए, एनवायरमेंट वैरिएबल सेट किए जा सकते हैं.

एनवायरमेंट वैरिएबल सेट करना

KAGGLE_USERNAME और KAGGLE_KEY के लिए, एनवायरमेंट वैरिएबल सेट करें. जब "ऐक्सेस दें?" के साथ प्रॉम्प्ट किया जाए मैसेज, सीक्रेट ऐक्सेस देने के लिए सहमत हों.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

recurrentgemma लाइब्रेरी इंस्टॉल करें

इस notebook में, बिना किसी शुल्क के Colab जीपीयू के इस्तेमाल पर फ़ोकस किया गया है. हार्डवेयर की मदद से तेज़ी लाने के लिए, बदलाव करें > पर क्लिक करें नोटबुक की सेटिंग > T4 जीपीयू चुनें > सेव करें पर टैप करें.

इसके बाद, आपको github.com/google-deepmind/recurrentgemma से Google DeepMind recurrentgemma लाइब्रेरी इंस्टॉल करनी होगी. अगर आपको "पीआईपी की डिपेंडेंसी रिज़ॉल्वर" से जुड़ी कोई गड़बड़ी मिलती है, तो आम तौर पर उसे अनदेखा किया जा सकता है.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

RecurrentGemma मॉडल लोड करें और तैयार करें

  1. kagglehub.model_download की मदद से RecurrentGemma मॉडल लोड करें, जिसमें तीन आर्ग्युमेंट होते हैं:
  • handle: Kaggle का मॉडल हैंडल
  • path: (वैकल्पिक स्ट्रिंग) लोकल पाथ
  • force_download: (वैकल्पिक बूलियन) मॉडल को फिर से डाउनलोड करने के लिए मजबूर करता है
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. मॉडल वेट और टोकनाइज़र की जगह की जांच करें. इसके बाद, पाथ वैरिएबल सेट करें. टोकनाइज़र डायरेक्ट्री, उस मुख्य डायरेक्ट्री में होगी जिसमें आपने मॉडल डाउनलोड किया है. वहीं, मॉडल वेट किसी सब-डायरेक्ट्री में होगा. उदाहरण के लिए:
  • tokenizer.model फ़ाइल /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1 में होगी).
  • मॉडल चेकपॉइंट /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it में होगा).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

सैंपलिंग/अनुमान लगाना

  1. recurrentgemma.jax.load_parameters तरीके की मदद से, RecurrentGemma मॉडल चेकपॉइंट को लोड करें. "single_device" पर सेट किया गया sharding आर्ग्युमेंट, एक ही डिवाइस पर सभी मॉडल पैरामीटर लोड करता है.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. sentencepiece.SentencePieceProcessor का इस्तेमाल करके बनाए गए RecurrentGemma मॉडल टोकनाइज़र को लोड करें:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. RecurrentGemma मॉडल चेकपॉइंट से सही कॉन्फ़िगरेशन को अपने-आप लोड करने के लिए, recurrentgemma.GriffinConfig.from_flax_params_or_variables का इस्तेमाल करें. इसके बाद, ग्रिफ़िन मॉडल को recurrentgemma.jax.Griffin से इंस्टैंशिएट करें.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. RecurrentGemma मॉडल चेकपॉइंट/वेट और टोकनाइज़र के ऊपर, recurrentgemma.jax.Sampler के साथ sampler बनाएं:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. prompt में प्रॉम्प्ट लिखें और अनुमान लगाएं. total_generation_steps में बदलाव किया जा सकता है. रिस्पॉन्स जनरेट करने के दौरान किए गए चरणों की संख्या में बदलाव किया जा सकता है. इस उदाहरण में, होस्ट की मेमोरी को बनाए रखने के लिए 50 का इस्तेमाल किया गया है.
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

ज़्यादा जानें