JAX और Flax का इस्तेमाल करके, CodeGemma का अनुमान लगाना

ai.google.dev पर देखें Google Colab में चलाएं GitHub पर सोर्स देखें

हम CodeGemma के बारे में बताते हैं. यह ओपन कोड मॉडल का कलेक्शन है, जो Google DeepMind के Gemma मॉडल पर आधारित है (Gemma Team et al., 2024). CodeGemma एक लाइटवेट और बेहतरीन ओपन मॉडल है. इसे Gemini मॉडल में इस्तेमाल की गई रिसर्च और टेक्नोलॉजी का इस्तेमाल करके बनाया गया है.

Gemma के पहले से ट्रेन किए गए मॉडल से शुरू किए गए CodeGemma मॉडल को मुख्य तौर पर कोड के 500 से 1, 000 अरब से ज़्यादा टोकन की मदद से ट्रेनिंग दी गई है. इनका इस्तेमाल करके, की बनावट, जेमा मॉडल परिवार की तरह है. इस वजह से, CodeGemma मॉडल, दोनों चरणों में सबसे नई कोड परफ़ॉर्मेंस हासिल करते हैं और जेनरेशन के टास्क बनाने के साथ-साथ, और तर्क के साथ सोचने और समझना आसान हो जाता है.

CodeGemma के तीन वैरिएंट हैं:

  • 7B कोड का पहले से ट्रेन किया गया मॉडल
  • निर्देशों के हिसाब से तैयार किया गया 7B कोड मॉडल
  • 2B मॉडल, जिसे खास तौर पर कोड डालने और ओपन-एंडेड जनरेशन के लिए ट्रेन किया गया है.

इस गाइड में, आपको कोड पूरा करने वाले टास्क के लिए, Flax के साथ CodeGemma मॉडल का इस्तेमाल करने का तरीका बताया गया है.

सेटअप

1. CodeGemma के लिए Kaggle का ऐक्सेस सेट अप करना

इस ट्यूटोरियल को पूरा करने के लिए, आपको सबसे पहले Gemma सेटअप में दिए गए सेटअप के निर्देशों का पालन करना होगा. इसमें बताया गया है कि ये काम कैसे किए जा सकते हैं:

  • kaggle.com पर CodeGemma का ऐक्सेस पाएं.
  • CodeGemma मॉडल को चलाने के लिए, ऐसे Colab रनटाइम चुनें जिसमें ज़रूरी संसाधन हों (T4 जीपीयू के पास ज़रूरत के मुताबिक मेमोरी नहीं है. इसके बजाय, TPU v2 का इस्तेमाल करें).
  • Kaggle उपयोगकर्ता नाम और एपीआई पासकोड को जनरेट और कॉन्फ़िगर करें.

Gemma का सेटअप पूरा करने के बाद, अगले सेक्शन पर जाएं. यहां अपने Colab एनवायरमेंट के लिए, एनवायरमेंट वैरिएबल सेट किए जा सकते हैं.

2. एनवायरमेंट वैरिएबल सेट करना

KAGGLE_USERNAME और KAGGLE_KEY के लिए, एनवायरमेंट वैरिएबल सेट करें. जब "ऐक्सेस दें?" के साथ प्रॉम्प्ट किया जाए मैसेज, सीक्रेट ऐक्सेस देने के लिए सहमत हों.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma लाइब्रेरी इंस्टॉल करें

इस notebook को चलाने के लिए, फ़िलहाल Colab के हार्डवेयर से तेज़ी लाने की सुविधा ज़रूरत के मुताबिक नहीं है. अगर Colab Pay As You Go या Colab Pro का इस्तेमाल किया जा रहा है, तो बदलाव करें पर क्लिक करें > Notebook की सेटिंग > A100 जीपीयू चुनें > हार्डवेयर की मदद से तेज़ी लाने के लिए, सेव करें.

इसके बाद, आपको github.com/google-deepmind/gemma से Google DeepMind gemma लाइब्रेरी इंस्टॉल करनी होगी. अगर आपको "पीआईपी की डिपेंडेंसी रिज़ॉल्वर" से जुड़ी कोई गड़बड़ी मिलती है, तो आम तौर पर उसे अनदेखा किया जा सकता है.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. लाइब्रेरी इंपोर्ट करें

इस notebook में Gemma का इस्तेमाल किया गया है. यह अपनी न्यूरल नेटवर्क लेयर बनाने के लिए, Flax का इस्तेमाल करता है. साथ ही, यह SentencePiece (टोकनाइज़ेशन के लिए) का इस्तेमाल करता है.

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

CodeGemma मॉडल को लोड करें

CodeGemma मॉडल को kagglehub.model_download से लोड करें. इसमें तीन आर्ग्युमेंट इस्तेमाल किए जाते हैं:

  • handle: Kaggle का मॉडल हैंडल
  • path: (वैकल्पिक स्ट्रिंग) लोकल पाथ
  • force_download: (ज़रूरी नहीं बूलियन) मॉडल को फिर से डाउनलोड करना पड़ता है
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

मॉडल वेट और टोकनाइज़र की जगह की जांच करें. इसके बाद, पाथ वैरिएबल सेट करें. टोकनाइज़र डायरेक्ट्री, उस मुख्य डायरेक्ट्री में होगी जिसमें आपने मॉडल डाउनलोड किया है. वहीं, मॉडल का वेट किसी सब-डायरेक्ट्री में होगा. उदाहरण के लिए:

  • spm.model टोकनाइज़र फ़ाइल, /LOCAL/PATH/TO/codegemma/flax/2b-pt/3 में होगी
  • मॉडल चेकपॉइंट /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt में होगा
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

सैंपलिंग/अनुमान लगाना

gemma.params.load_and_format_params तरीके का इस्तेमाल करके, CodeGemma मॉडल चेकपॉइंट को लोड और फ़ॉर्मैट करें:

params = params_lib.load_and_format_params(CKPT_PATH)

sentencepiece.SentencePieceProcessor का इस्तेमाल करके बनाया गया CodeGemma टोकनाइज़र लोड करें:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

CodeGemma मॉडल चेकपॉइंट से सही कॉन्फ़िगरेशन को अपने-आप लोड करने के लिए, gemma.transformer.TransformerConfig का इस्तेमाल करें. cache_size आर्ग्युमेंट, CodeGemma Transformer कैश में मौजूद समय के चरणों की संख्या होता है. इसके बाद, CodeGemma मॉडल को model_2b के तौर पर gemma.transformer.Transformer (जो flax.linen.Module से इनहेरिट होता है) के साथ इंस्टैंशिएट करें.

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

gemma.sampler.Sampler का इस्तेमाल करके sampler बनाएं. यह CodeGemma मॉडल चेकपॉइंट और टोकनाइज़र का इस्तेमाल करता है.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

फ़िल-इन-द-मिडल (fim) टोकन को दिखाने के लिए कुछ वैरिएबल बनाएं. साथ ही, प्रॉम्प्ट और जनरेट किए गए आउटपुट को फ़ॉर्मैट करने के लिए कुछ हेल्पर फ़ंक्शन बनाएं.

उदाहरण के लिए, आइए इस कोड को देखें:

def function(string):
assert function('asdf') == 'fdsa'

हम function को भरना चाहते हैं, ताकि दावा True को होल्ड पर रख सके. इस मामले में, प्रीफ़िक्स यह होगा:

"def function(string):\n"

और सफ़िक्स यह होगा:

"assert function('asdf') == 'fdsa'"

इसके बाद, हम इसे प्रॉम्प्ट के तौर पर PREFIX- सुरक्षा-MIDDLE के तौर पर फ़ॉर्मैट कर देते हैं (बीच का सेक्शन जिसे भरने की ज़रूरत होती है वह हमेशा प्रॉम्प्ट के आखिर में होता है):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

कोई प्रॉम्प्ट बनाएं और अनुमान लगाएं. प्रीफ़िक्स before टेक्स्ट और सफ़िक्स after टेक्स्ट तय करें. इसके बाद, हेल्पर फ़ंक्शन format_completion prompt का इस्तेमाल करके फ़ॉर्मैट किया गया प्रॉम्प्ट जनरेट करें.

total_generation_steps में बदलाव किया जा सकता है. रिस्पॉन्स जनरेट करने के दौरान किए गए चरणों की संख्या में बदलाव किया जा सकता है. इस उदाहरण में, होस्ट की मेमोरी को बनाए रखने के लिए 100 का इस्तेमाल किया गया है.

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

ज़्यादा जानें