हम CodeGemma पेश करते हैं. यह Google DeepMind के Gemma मॉडल (Gemma Team et al., 2024). CodeGemma एक लाइटवेट और बेहतरीन ओपन मॉडल है. इसे Gemini मॉडल में इस्तेमाल की गई रिसर्च और तकनीक का इस्तेमाल करके बनाया गया है.
Gemma के पहले से ट्रेन किए गए मॉडल के बाद, CodeGemma मॉडल को मुख्य रूप से कोड के 500 से 1,000 अरब से ज़्यादा टोकन पर ट्रेन किया जाता है. इसके लिए, Gemma मॉडल फ़ैमिली के जैसे ही आर्किटेक्चर का इस्तेमाल किया जाता है. इस वजह से, CodeGemma मॉडल कोड को पूरा करने और जनरेट करने, दोनों टास्क में बेहतरीन परफ़ॉर्मेंस देते हैं. साथ ही, बड़े पैमाने पर बेहतर समझ और तर्क करने की क्षमता बनाए रखते हैं.
CodeGemma के तीन वैरिएंट हैं:
- 7 अरब कोड वाला प्रीट्रेन किया गया मॉडल
- 7B निर्देशों के हिसाब से ट्यून किया गया कोड मॉडल
- यह एक 2B मॉडल है, जिसे खास तौर पर कोड में जानकारी भरने और ओपन-एंडेड जनरेशन के लिए ट्रेन किया गया है.
इस गाइड में, कोड पूरा करने के टास्क के लिए, Flax के साथ CodeGemma मॉडल का इस्तेमाल करने का तरीका बताया गया है.
सेटअप
1. CodeGemma के लिए Kaggle का ऐक्सेस सेट अप करना
इस ट्यूटोरियल को पूरा करने के लिए, आपको पहले Gemma सेटअप पर जाकर, सेटअप करने के निर्देशों का पालन करना होगा. इन निर्देशों में, आपको ये काम करने का तरीका बताया गया है:
- kaggle.com पर जाकर, CodeGemma का ऐक्सेस पाएं.
- CodeGemma मॉडल को चलाने के लिए, ज़रूरत के हिसाब से संसाधनों वाला Colab रनटाइम चुनें. T4 जीपीयू में ज़रूरत के हिसाब से मेमोरी नहीं है. इसके बजाय, TPU v2 का इस्तेमाल करें.
- Kaggle उपयोगकर्ता नाम और एपीआई पासकोड जनरेट और कॉन्फ़िगर करें.
Gemma का सेटअप पूरा करने के बाद, अगले सेक्शन पर जाएं. यहां आपको अपने Colab एनवायरमेंट के लिए, एनवायरमेंट वैरिएबल सेट करने होंगे.
2. एनवायरमेंट वैरिएबल सेट करना
KAGGLE_USERNAME
और KAGGLE_KEY
के लिए एनवायरमेंट वैरिएबल सेट करें. "ऐक्सेस दें?" मैसेज मिलने पर, सीक्रेट का ऐक्सेस देने के लिए सहमति दें.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. gemma
लाइब्रेरी इंस्टॉल करना
फ़िलहाल, Colab के बिना शुल्क वाले हार्डवेयर से तेज़ी लाने की सुविधा, इस नोटबुक को चलाने के लिए insufficient. अगर Colab के 'जितना इस्तेमाल करें, सिर्फ़ उतने पैसे चुकाएं' या Colab Pro वर्शन का इस्तेमाल किया जा रहा है, तो हार्डवेयर एक्सेलरेशन की सुविधा चालू करने के लिए, बदलाव करें > नोटबुक की सेटिंग > A100 जीपीयू चुनें > सेव करें पर क्लिक करें.
इसके बाद, आपको github.com/google-deepmind/gemma
से Google DeepMind gemma
लाइब्रेरी इंस्टॉल करनी होगी. अगर आपको "pip की डिपेंडेंसी रिज़ॉल्वर" के बारे में गड़बड़ी का कोई मैसेज मिलता है, तो आम तौर पर उसे अनदेखा किया जा सकता है.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. लाइब्रेरी इंपोर्ट करना
यह नोटबुक, Gemma (जो अपने न्यूरल नेटवर्क लेयर बनाने के लिए Flax का इस्तेमाल करता है) और SentencePiece (टोकनाइज़ेशन के लिए) का इस्तेमाल करती है.
import os
from gemma.deprecated import params as params_lib
from gemma.deprecated import sampler as sampler_lib
from gemma.deprecated import transformer as transformer_lib
import sentencepiece as spm
CodeGemma मॉडल लोड करना
kagglehub.model_download
की मदद से CodeGemma मॉडल लोड करें. इसमें तीन आर्ग्युमेंट होते हैं:
handle
: Kaggle से मॉडल का हैंडलpath
: (ज़रूरी नहीं) स्थानीय पाथforce_download
: (ज़रूरी नहीं है, बूलियन) मॉडल को फिर से डाउनलोड करने के लिए मजबूर करता है
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
मॉडल के वेट और टॉकेनेटर की जगह देखें. इसके बाद, पाथ वैरिएबल सेट करें. tokenizer डायरेक्ट्री उस मुख्य डायरेक्ट्री में होगी जहां आपने मॉडल डाउनलोड किया था. वहीं, मॉडल के वेट किसी सब-डायरेक्ट्री में होंगे. उदाहरण के लिए:
spm.model
टॉकेनेटर फ़ाइल,/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
में होगी- मॉडल का चेकपॉइंट
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
में होगा
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
सैंपलिंग/अनुमान लगाना
gemma.params.load_and_format_params
तरीके से, CodeGemma मॉडल के चेकपॉइंट को लोड और फ़ॉर्मैट करें:
params = params_lib.load_and_format_params(CKPT_PATH)
sentencepiece.SentencePieceProcessor
का इस्तेमाल करके बनाया गया CodeGemma टोकनेटर लोड करें:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
CodeGemma मॉडल के चेकपॉइंट से सही कॉन्फ़िगरेशन अपने-आप लोड करने के लिए, gemma.deprecated.transformer.TransformerConfig
का इस्तेमाल करें. cache_size
आर्ग्युमेंट, CodeGemma Transformer
कैश मेमोरी में मौजूद टाइम स्टेप की संख्या है. इसके बाद, gemma.deprecated.transformer.Transformer
(जो flax.linen.Module
से इनहेरिट करता है) के साथ CodeGemma मॉडल को model_2b
के तौर पर इंस्टैंशिएट करें.
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
gemma.sampler.Sampler
की मदद से sampler
बनाएं. यह CodeGemma मॉडल के चेकपॉइंट और टोकनेटर का इस्तेमाल करता है.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
बीच में भरने के लिए दिए गए (फ़िल-इन-द-मिडल) टोकन दिखाने के लिए कुछ वैरिएबल बनाएं. साथ ही, प्रॉम्प्ट और जनरेट किए गए आउटपुट को फ़ॉर्मैट करने के लिए कुछ हेल्पर फ़ंक्शन बनाएं.
उदाहरण के लिए, यहां दिया गया कोड देखें:
def function(string):
assert function('asdf') == 'fdsa'
हम function
भरना चाहते हैं, ताकि दावा True
हो. इस मामले में, प्रीफ़िक्स यह होगा:
"def function(string):\n"
और सफ़िक्स यह होगा:
"assert function('asdf') == 'fdsa'"
इसके बाद, हम इसे प्रीफ़िक्स-सर्फ़िक्स-मिडल के तौर पर प्रॉम्प्ट में फ़ॉर्मैट करते हैं. प्रॉम्प्ट के आखिर में, वह मिडल सेक्शन होता है जिसे भरना होता है:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
प्रॉम्प्ट बनाना और अनुमान लगाना. प्रीफ़िक्स before
टेक्स्ट और सफ़िक्स after
टेक्स्ट डालें. साथ ही, हेल्पर फ़ंक्शन format_completion prompt
का इस्तेमाल करके, फ़ॉर्मैट किया गया प्रॉम्प्ट जनरेट करें.
total_generation_steps
में बदलाव किया जा सकता है. यह जवाब जनरेट करते समय किए गए चरणों की संख्या होती है. इस उदाहरण में, होस्ट मेमोरी को सुरक्षित रखने के लिए 100
का इस्तेमाल किया गया है.
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
ज़्यादा जानें
- GitHub पर Google DeepMind
gemma
लाइब्रेरी के बारे में ज़्यादा जानें. इसमें इस ट्यूटोरियल में इस्तेमाल किए गए मॉड्यूल के दस्तावेज़ शामिल हैं, जैसे किgemma.params
,gemma.deprecated.transformer
, औरgemma.sampler
. - इन लाइब्रेरी के दस्तावेज़ों की साइटें अलग-अलग हैं: core JAX, Flax, और Orbax.
sentencepiece
टॉकेनेटर/डेटॉकेनेटर के दस्तावेज़ के लिए, Google काsentencepiece
GitHub डेटा स्टोर करने की जगह देखें.kagglehub
के दस्तावेज़ के लिए, Kaggle केkagglehub
GitHub रेपो परREADME.md
देखें.- Google Cloud Vertex AI के साथ Gemma मॉडल इस्तेमाल करने का तरीका जानें.
- अगर Google Cloud TPUs (v3-8 और उसके बाद के वर्शन) का इस्तेमाल किया जा रहा है, तो पक्का करें कि आपने
jax[tpu]
के नए पैकेज (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
) पर भी अपडेट किया हो. साथ ही, रनटाइम को रीस्टार्ट करें और देखें किjax
औरjaxlib
वर्शन मैच करते हैं या नहीं (!pip list | grep jax
). इससे,jaxlib
औरjax
वर्शन के मैच न होने की वजह से होने वालीRuntimeError
से बचा जा सकता है. JAX इंस्टॉल करने के बारे में ज़्यादा जानने के लिए, JAX दस्तावेज़ देखें.