استنتاج با CodeGemma با استفاده از JAX و Flax

مشاهده در ai.google.dev در Google Colab اجرا شود مشاهده منبع در GitHub

ما CodeGemma را ارائه می‌کنیم، مجموعه‌ای از مدل‌های کد باز مبتنی بر مدل‌های Gemma Google DeepMind (تیم Gemma و همکاران، 2024). CodeGemma خانواده ای از مدل های باز سبک وزن و پیشرفته است که از همان تحقیقات و فناوری استفاده شده برای ایجاد مدل های Gemini ساخته شده است.

با ادامه مدل‌های از پیش آموزش دیده Gemma، مدل‌های CodeGemma بر روی بیش از 500 تا 1000 میلیارد توکن کد اصلی، با استفاده از معماری‌های مشابه خانواده مدل Gemma آموزش داده می‌شوند. در نتیجه، مدل‌های CodeGemma به بهترین عملکرد کد در هر دو کار تکمیل و تولید دست می‌یابند، در حالی که مهارت‌های درک و استدلال قوی را در مقیاس حفظ می‌کنند.

CodeGemma دارای 3 نوع است:

  • یک مدل از پیش آموزش دیده با کد 7B
  • یک مدل کد تنظیم شده با دستورالعمل 7B
  • یک مدل 2B که به طور خاص برای تکمیل کد و تولید پایان باز آموزش داده شده است.

این راهنما شما را با استفاده از مدل CodeGemma با Flax برای یک کار تکمیل کد راهنمایی می کند.

راه اندازی

1. دسترسی Kaggle را برای CodeGemma تنظیم کنید

برای تکمیل این آموزش، ابتدا باید دستورالعمل های راه اندازی را در Gemma setup دنبال کنید، که به شما نشان می دهد چگونه کارهای زیر را انجام دهید:

  • در kaggle.com به CodeGemma دسترسی پیدا کنید.
  • یک زمان اجرا Colab با منابع کافی انتخاب کنید ( GPU T4 حافظه کافی ندارد، به جای آن از TPU v2 استفاده کنید ) برای اجرای مدل CodeGemma.
  • نام کاربری و کلید API Kaggle را ایجاد و پیکربندی کنید.

پس از تکمیل تنظیمات Gemma، به بخش بعدی بروید، جایی که متغیرهای محیطی را برای محیط Colab خود تنظیم خواهید کرد.

2. متغیرهای محیط را تنظیم کنید

متغیرهای محیطی را برای KAGGLE_USERNAME و KAGGLE_KEY تنظیم کنید. وقتی از شما خواسته شد "دسترسی بدهید؟" پیام‌ها، با ارائه دسترسی مخفی موافقت کنید.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. کتابخانه gemma نصب کنید

شتاب سخت‌افزار Free Colab در حال حاضر برای اجرای این نوت بوک کافی نیست . اگر از Colab Pay As You Go یا Colab Pro استفاده می‌کنید، روی Edit > تنظیمات نوت‌بوک > انتخاب A100 GPU > Save کلیک کنید تا شتاب سخت‌افزاری فعال شود.

در مرحله بعد، باید کتابخانه Google DeepMind gemma را از github.com/google-deepmind/gemma نصب کنید. اگر خطای «تحلیل کننده وابستگی پیپ» دریافت کردید، معمولاً می توانید آن را نادیده بگیرید.

pip install -q git+https://github.com/google-deepmind/gemma.git

4. واردات کتابخانه ها

این نوت بوک از Gemma (که از Flax برای ساخت لایه های شبکه عصبی خود استفاده می کند) و SentencePiece (برای توکن سازی) استفاده می کند.

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

مدل CodeGemma را بارگیری کنید

مدل CodeGemma را با kagglehub.model_download بارگیری کنید که سه آرگومان می گیرد:

  • handle : مدل دسته از Kaggle
  • path : (رشته اختیاری) مسیر محلی
  • force_download : (بولی اختیاری) بارگیری مجدد مدل را مجبور می کند
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

محل وزن های مدل و توکنایزر را بررسی کنید، سپس متغیرهای مسیر را تنظیم کنید. دایرکتوری توکنایزر در دایرکتوری اصلی جایی که مدل را دانلود کرده‌اید قرار می‌گیرد، در حالی که وزن‌های مدل در یک دایرکتوری فرعی قرار دارند. به عنوان مثال:

  • فایل توکنایزر spm.model در /LOCAL/PATH/TO/codegemma/flax/2b-pt/3 خواهد بود.
  • نقطه بازرسی مدل در /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt خواهد بود
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

انجام نمونه گیری/استنتاج

با متد gemma.params.load_and_format_params ، نقطه بازرسی مدل CodeGemma را بارگیری و قالب بندی کنید:

params = params_lib.load_and_format_params(CKPT_PATH)

توکنایزر CodeGemma را که با استفاده از sentencepiece.SentencePieceProcessor ساخته شده است بارگیری کنید.SentencePieceProcessor:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

برای بارگیری خودکار پیکربندی صحیح از نقطه بازرسی مدل CodeGemma، از gemma.transformer.TransformerConfig استفاده کنید. آرگومان cache_size تعداد مراحل زمانی در حافظه پنهان CodeGemma Transformer است. پس از آن، مدل CodeGemma را به عنوان model_2b با gemma.transformer.Transformer (که از flax.linen.Module به ارث می برد) نمونه سازی کنید.

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

یک sampler با gemma.sampler.Sampler ایجاد کنید. از نقطه بازرسی مدل CodeGemma و توکنایزر استفاده می کند.

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

چند متغیر برای نشان دادن نشانه های fill-in-the-middle (fim) ایجاد کنید و برخی از توابع کمکی را برای قالب بندی خروجی اعلان و تولید شده ایجاد کنید.

برای مثال، بیایید به کد زیر نگاه کنیم:

def function(string):
assert function('asdf') == 'fdsa'

ما می خواهیم function را به گونه ای پر کنیم که ادعا True باشد. در این مورد، پیشوند این خواهد بود:

"def function(string):\n"

و پسوند این خواهد بود:

"assert function('asdf') == 'fdsa'"

سپس این را به یک دستور به عنوان PREFIX-SUFFIX-MIDDLE قالب بندی می کنیم (بخش میانی که باید پر شود همیشه در انتهای اعلان است):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

یک اعلان ایجاد کنید و استنتاج را انجام دهید. پیشوند را before متن و پسوند after متن را مشخص کنید و با استفاده از format_completion prompt دستور فرمت شده را ایجاد کنید.

می‌توانید total_generation_steps را تغییر دهید (تعداد مراحل انجام شده هنگام ایجاد یک پاسخ - این مثال 100 برای حفظ حافظه میزبان استفاده می‌کند).

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

بیشتر بدانید

  • می‌توانید درباره کتابخانه gemma Google DeepMind در GitHub اطلاعات بیشتری کسب کنید، که شامل رشته‌های مستند ماژول‌هایی است که در این آموزش استفاده کرده‌اید، مانند gemma.params ، gemma.transformer ، و gemma.sampler .
  • کتابخانه‌های زیر سایت‌های مستند خود را دارند: core JAX ، Flax ، و Orbax .
  • برای مستندات توکنایزر/دتوکن‌سازی‌کننده sentencepiece ، از مخزن sentencepiece GitHub Google دیدن کنید.
  • برای مستندات kagglehub ، README.md در مخزن kagglehub GitHub بررسی کنید.
  • نحوه استفاده از مدل‌های Gemma با هوش مصنوعی Google Cloud Vertex را بیاموزید.
  • اگر از Google Cloud TPU (نسخه 3-8 و جدیدتر) استفاده می‌کنید، حتماً به آخرین بسته jax[tpu] نیز به‌روزرسانی کنید ( !pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html )، زمان اجرا را مجدداً راه اندازی کنید و بررسی کنید که نسخه های jax و jaxlib مطابقت دارند ( !pip list | grep jax ). این می تواند از RuntimeError که به دلیل عدم تطابق نسخه jaxlib و jax ایجاد می شود جلوگیری کند. برای دستورالعمل‌های نصب JAX بیشتر، به اسناد JAX مراجعه کنید.