الاستنتاج باستخدام RecurrentGemma باستخدام JAX وFlex

العرض على ai.google.dev التنفيذ في Google Colab الفتح في Vertex AI عرض المصدر على GitHub

يشرح هذا البرنامج التعليمي كيفية إجراء عملية أخذ العينات/الاستنتاج الأساسي باستخدام نموذج RecurrentGemma 2B Instructing باستخدام مكتبة recurrentgemma من Google DeepMind التي تمت كتابتها باستخدام JAX (مكتبة حوسبة رقمية عالية الأداء) وFlax (مكتبة الشبكة العصبية المستندة إلى JAX) وOrbax (مكتبة الرموز البرمجية المستندة إلى JAXSentencePiece على الرغم من عدم استخدام Flax مباشرةً في هذا الدفتر، فقد تم استخدامه لإنشاء Gemma وRecurrentGemma (نموذج Griffin).

يمكن تشغيل ورقة الملاحظات هذه على Google Colab باستخدام وحدة معالجة الرسومات T4 (انتقِل إلى تعديل > إعدادات ورقة الملاحظات > ضِمن مسرِّع الأجهزة، اختَر وحدة معالجة الرسومات T4).

ضبط إعدادات الجهاز

توضّح الأقسام التالية خطوات إعداد دفتر ملاحظات لاستخدام نموذج RecurrentGemma، بما في ذلك الوصول إلى النموذج والحصول على مفتاح واجهة برمجة التطبيقات وضبط وقت تشغيل ورقة الملاحظات

إعداد وصول Kaggle لـ Gemma

لإكمال هذا البرنامج التعليمي، عليك أولاً اتّباع تعليمات الإعداد المشابهة لإعداد Gemma، مع بعض الاستثناءات:

  • يمكنك الوصول إلى RecurrentGemma (بدلاً من Gemma) على kaggle.com.
  • اختَر بيئة تشغيل Colab بها موارد كافية لتشغيل نموذج RecurrentGemma.
  • إنشاء وتكوين اسم مستخدم ومفتاح واجهة برمجة تطبيقات Kaggle.

بعد الانتهاء من إعداد RecurrentGemma، انتقِل إلى القسم التالي، حيث يمكنك ضبط متغيّرات البيئة لبيئة Colab.

ضبط متغيرات البيئة

ضبط متغيّرات البيئة لكل من KAGGLE_USERNAME وKAGGLE_KEY عندما يُطلب منك عرض رسالة "هل تريد منح إمكانية الوصول؟" أو الموافقة على توفير الوصول السري.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

تثبيت مكتبة recurrentgemma

تركّز ورقة الملاحظات هذه على استخدام وحدة معالجة رسومات مجانية في Colab. لتفعيل ميزة "تسريع الأجهزة"، انقر على تعديل >. إعدادات ورقة الملاحظات > اختَر وحدة معالجة الرسومات T4 > انقر على حفظ.

بعد ذلك، عليك تثبيت مكتبة Google DeepMind recurrentgemma من github.com/google-deepmind/recurrentgemma. إذا ظهرت لك رسالة خطأ بشأن "أداة حل التبعية في pip"، يمكنك عادةً تجاهلها.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

تحميل نموذج RecurrentGemma وإعداده

  1. حمِّل نموذج RecurrentGemma باستخدام السمة kagglehub.model_download، وهي ثلاث وسيطات:
  • handle: مؤشر النموذج من Kaggle
  • path: (سلسلة اختيارية) المسار المحلي
  • force_download: (قيمة منطقية اختيارية) تفرض إعادة تنزيل النموذج
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. تحقَّق من الموقع الجغرافي لقيم ترجيح النموذج وأداة إنشاء الرموز المميّزة، ثم اضبط متغيّرات المسار. سيكون دليل أداة إنشاء الرموز المميّزة في الدليل الرئيسي الذي نزّلت فيه النموذج، بينما ستكون القيم التقديرية للنموذج في دليل فرعي. على سبيل المثال:
  • سيكون الملف tokenizer.model باللغة /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1).
  • ستكون النقطة المرجعية للنموذج في /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

إجراء أخذ العينات/الاستنتاج

  1. حمِّل نقطة مراجعة نموذج RecurrentGemma باستخدام الطريقة recurrentgemma.jax.load_parameters. تُحمِّل الوسيطة sharding المضبوطة على "single_device" جميع مَعلمات النموذج على جهاز واحد.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. تحميل أداة إنشاء الرموز المميّزة لنموذج RecurrentGemma الذي تم إنشاؤه باستخدام sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. لتحميل الإعدادات الصحيحة تلقائيًا من نقطة تفتيش نموذج RecurrentGemma، استخدِم recurrentgemma.GriffinConfig.from_flax_params_or_variables. بعد ذلك، يمكنك إنشاء مثيل لنموذج Griffin باستخدام recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. أنشِئ sampler باستخدام recurrentgemma.jax.Sampler أعلى نقطة فحص نموذج RecurrentGemma وأداة إنشاء الرموز المميّزة:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. اكتب طلبًا باللغة prompt ونفِّذ الاستنتاج. يمكنك تعديل total_generation_steps (عدد الخطوات التي يتم تنفيذها عند إنشاء ردّ، إذ يستخدم هذا المثال 50 للاحتفاظ بذاكرة المضيف).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

مزيد من المعلومات