การอนุมานด้วย RecurrentGemma โดยใช้ JAX และ Flax

ดูใน ai.google.dev เรียกใช้ใน Google Colab เปิดใน Vertex AI ดูแหล่งที่มาใน GitHub

บทแนะนำนี้สาธิตวิธีการสุ่มตัวอย่าง/การอนุมานพื้นฐานด้วยโมเดล RecurrentGemma 2B Instruct โดยใช้ไลบรารีrecurrentgemmaของ Google DeepMind ที่เขียนด้วย JAX (ไลบรารีการประมวลผลตัวเลขที่มีประสิทธิภาพสูง), Flax (ไลบรารีเครือข่ายระบบประสาทแบบ JAX), Orbax (ไลบรารีของ JAX/Pcepointing} {Senttodeken1} ที่ใช้ฝึก+ตรวจสอบ) และยูทิลิตีการฝึกต่างๆ เช่น เครื่องมือตรวจสอบแบบ JAX1}SentencePiece แม้ว่าไม่ได้ใช้ Flax ในสมุดบันทึกนี้โดยตรง แต่ใช้ Flax เพื่อสร้าง Gemma และ RecurrentGemma (โมเดลกริฟฟิน)

สมุดบันทึกนี้ทำงานบน Google Colab ได้ด้วย GPU รุ่น T4 (ไปที่แก้ไข > การตั้งค่าสมุดบันทึก > ใต้ตัวเร่งฮาร์ดแวร์ ให้เลือก GPU T4)

ตั้งค่า

ส่วนต่อไปนี้จะอธิบายขั้นตอนในการเตรียมสมุดบันทึกเพื่อใช้โมเดล RecurrentGemma รวมถึงการเข้าถึงโมเดล การรับคีย์ API และการกำหนดค่ารันไทม์ของสมุดบันทึก

ตั้งค่าการเข้าถึง Kaggle สำหรับ Gemma

หากต้องการจบบทแนะนำนี้ ก่อนอื่นคุณต้องทำตามวิธีการตั้งค่าคล้ายกับการตั้งค่า Gemma โดยมีข้อยกเว้นบางประการดังนี้

  • รับสิทธิ์เข้าถึง RecurrentGemma (แทน Gemma) ใน kaggle.com
  • เลือกรันไทม์ของ Colab ที่มีทรัพยากรเพียงพอที่จะเรียกใช้โมเดล RecurrentGemma
  • สร้างและกำหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle

หลังจากตั้งค่า RecurrentGemma เรียบร้อยแล้ว ให้ข้ามไปที่ส่วนถัดไปซึ่งเป็นชุดตัวแปรสภาพแวดล้อมสำหรับสภาพแวดล้อม Colab

ตั้งค่าตัวแปรสภาพแวดล้อม

ตั้งค่าตัวแปรสภาพแวดล้อมสำหรับ KAGGLE_USERNAME และ KAGGLE_KEY เมื่อมีข้อความแจ้งด้วยข้อความ "ให้สิทธิ์เข้าถึงไหม" ยอมรับการเข้าถึงข้อมูลลับ

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

ติดตั้งไลบรารี recurrentgemma

สมุดบันทึกนี้มุ่งเน้นที่การใช้ Colab GPU ฟรี หากต้องการเปิดใช้การเร่งฮาร์ดแวร์ ให้คลิกแก้ไข > การตั้งค่าสมุดบันทึก > เลือก T4 GPU > บันทึก

ถัดไป คุณต้องติดตั้งไลบรารี Google DeepMind recurrentgemma จาก github.com/google-deepmind/recurrentgemma หากคุณได้รับข้อผิดพลาดเกี่ยวกับ "รีโซลเวอร์ทรัพยากร Dependency ของ PIP" โดยปกติแล้วคุณไม่ต้องสนใจ

pip install git+https://github.com/google-deepmind/recurrentgemma.git

โหลดและเตรียมโมเดล RecurrentGemma

  1. โหลดโมเดล RecurrentGemma ด้วย kagglehub.model_download ซึ่งมีอาร์กิวเมนต์ 3 ตัว ดังนี้
  • handle: แฮนเดิลโมเดลจาก Kaggle
  • path: (สตริงที่ไม่บังคับ) เส้นทางในเครื่อง
  • force_download: (บูลีนที่ไม่บังคับ) บังคับให้ดาวน์โหลดโมเดลอีกครั้ง
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. ตรวจสอบตำแหน่งของน้ำหนักโมเดลและเครื่องมือแปลงข้อมูลเป็นโทเค็น จากนั้นตั้งค่าตัวแปรเส้นทาง ไดเรกทอรีโทเคนไลซ์จะอยู่ในไดเรกทอรีหลักที่คุณดาวน์โหลดโมเดลไป ขณะที่น้ำหนักโมเดลจะอยู่ในไดเรกทอรีย่อย เช่น
  • ไฟล์ tokenizer.model จะอยู่ใน /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1)
  • จุดตรวจสอบโมเดลจะอยู่ใน /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it)
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

ทำการสุ่มตัวอย่าง/การอนุมาน

  1. โหลดจุดเช็คพอยท์ของโมเดล RecurrentGemma ด้วยเมธอด recurrentgemma.jax.load_parameters อาร์กิวเมนต์ sharding ที่ตั้งค่าเป็น "single_device" จะโหลดพารามิเตอร์โมเดลทั้งหมดในอุปกรณ์เดียว
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. โหลดเครื่องมือแปลงโทเค็นโมเดล RecurrentGemma ซึ่งสร้างขึ้นโดยใช้ sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. หากต้องการโหลดการกำหนดค่าที่ถูกต้องจากจุดตรวจสอบโมเดล RecurrentGemma โดยอัตโนมัติ ให้ใช้ recurrentgemma.GriffinConfig.from_flax_params_or_variables จากนั้นสร้างอินสแตนซ์โมเดล Griffin ด้วย recurrentgemma.jax.Griffin
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. สร้าง sampler ด้วย recurrentgemma.jax.Sampler เพิ่มเติมจากจุดตรวจสอบ/น้ำหนักของโมเดล RecurrentGemma และเครื่องมือแปลงข้อมูลเป็นโทเค็น ดังนี้
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. เขียนพรอมต์ใน prompt และดำเนินการอนุมาน คุณสามารถปรับแต่ง total_generation_steps (จำนวนขั้นตอนที่ดำเนินการเมื่อสร้างคำตอบ โดยตัวอย่างนี้ใช้ 50 เพื่อเก็บรักษาหน่วยความจำของโฮสต์)
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

ดูข้อมูลเพิ่มเติม