استنتاج با RecurrentGemma با استفاده از JAX و Flax

مشاهده در ai.google.dev در Google Colab اجرا شود در Vertex AI باز کنید مشاهده منبع در GitHub

این آموزش نحوه انجام نمونه برداری/استنتاج اولیه با مدل RecurrentGemma 2B Instruct را با استفاده از کتابخانه recurrentgemma Google DeepMind که با JAX (یک کتابخانه محاسباتی عددی با کارایی بالا)، Flax (کتابخانه شبکه عصبی مبتنی بر JAX)، Orbax (یک) نوشته شده را نشان می دهد. کتابخانه مبتنی بر JAX برای ابزارهای آموزشی مانند checkpointing) و SentencePiece (یک کتابخانه توکنایزر/دتوکنیزر). اگرچه از Flax مستقیماً در این نوت بوک استفاده نمی شود، از Flax برای ایجاد Gemma و RecurrentGemma (مدل گریفین) استفاده شد.

این نوت بوک می تواند در Google Colab با پردازنده گرافیکی T4 اجرا شود (به Edit > تنظیمات نوت بوک > زیر شتاب دهنده سخت افزار، T4 GPU را انتخاب کنید).

راه اندازی

بخش‌های زیر مراحل آماده‌سازی یک نوت‌بوک برای استفاده از مدل RecurrentGemma، از جمله دسترسی به مدل، دریافت کلید API و پیکربندی زمان اجرا نوت‌بوک را توضیح می‌دهند.

دسترسی Kaggle را برای Gemma تنظیم کنید

برای تکمیل این آموزش، ابتدا باید دستورالعمل های راه اندازی مشابه راه اندازی Gemma را با چند استثنا دنبال کنید:

  • در kaggle.com به RecurrentGemma (به جای Gemma) دسترسی پیدا کنید.
  • یک زمان اجرا Colab با منابع کافی برای اجرای مدل RecurrentGemma انتخاب کنید.
  • نام کاربری و کلید API Kaggle را ایجاد و پیکربندی کنید.

پس از تکمیل تنظیمات RecurrentGemma، به بخش بعدی بروید، جایی که متغیرهای محیطی را برای محیط Colab خود تنظیم خواهید کرد.

تنظیم متغیرهای محیطی

متغیرهای محیطی را برای KAGGLE_USERNAME و KAGGLE_KEY تنظیم کنید. وقتی از شما خواسته شد "دسترسی بدهید؟" پیام‌ها، با ارائه دسترسی مخفی موافقت کنید.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

کتابخانه recurrentgemma را نصب کنید

این نوت بوک بر روی استفاده از یک GPU رایگان Colab تمرکز دارد. برای فعال کردن شتاب سخت افزاری، روی ویرایش > تنظیمات نوت بوک > انتخاب T4 GPU > ذخیره کلیک کنید.

در مرحله بعد، باید کتابخانه Google DeepMind recurrentgemma را از github.com/google-deepmind/recurrentgemma نصب کنید. اگر خطای «تحلیل کننده وابستگی پیپ» دریافت کردید، معمولاً می توانید آن را نادیده بگیرید.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

مدل RecurrentGemma را بارگیری و آماده کنید

  1. مدل RecurrentGemma را با kagglehub.model_download بارگیری کنید که سه آرگومان می گیرد:
  • handle : مدل دسته از Kaggle
  • path : (رشته اختیاری) مسیر محلی
  • force_download : (بولی اختیاری) بارگیری مجدد مدل را مجبور می کند
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. محل وزن های مدل و توکنایزر را بررسی کنید، سپس متغیرهای مسیر را تنظیم کنید. دایرکتوری توکنایزر در دایرکتوری اصلی جایی که مدل را دانلود کرده‌اید قرار می‌گیرد، در حالی که وزن‌های مدل در یک دایرکتوری فرعی قرار دارند. به عنوان مثال:
  • فایل tokenizer.model در /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1 خواهد بود.
  • نقطه بازرسی مدل در /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it خواهد بود.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

انجام نمونه گیری/استنتاج

  1. بازرسی مدل RecurrentGemma را با روش recurrentgemma.jax.load_parameters بارگیری کنید. آرگومان sharding روی "single_device" همه پارامترهای مدل را در یک دستگاه بارگیری می کند.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. توکنایزر مدل RecurrentGemma را بارگیری کنید که با استفاده از sentencepiece.SentencePieceProcessor ساخته شده است.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. برای بارگیری خودکار پیکربندی صحیح از نقطه بازرسی مدل RecurrentGemma، از recurrentgemma.GriffinConfig.from_flax_params_or_variables استفاده کنید. سپس، مدل Griffin را با recurrentgemma.jax.Griffin نمونه سازی کنید.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. یک sampler با recurrentgemma.jax.Sampler در بالای چکپوینت/وزن‌های مدل RecurrentGemma و توکنایزر ایجاد کنید:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. یک درخواست در prompt بنویسید و استنتاج انجام دهید. می‌توانید total_generation_steps تغییر دهید (تعداد مراحل انجام شده هنگام ایجاد پاسخ - این مثال از 50 برای حفظ حافظه میزبان استفاده می‌کند).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

بیشتر بدانید