ดูใน ai.google.dev | เรียกใช้ใน Google Colab | เปิดใน Vertex AI | ดูแหล่งที่มาใน GitHub |
บทแนะนำนี้สาธิตวิธีการสุ่มตัวอย่าง/การอนุมานพื้นฐานด้วยโมเดล RecurrentGemma 2B Instruct โดยใช้ไลบรารีrecurrentgemma
ของ Google DeepMind ที่เขียนด้วย JAX (ไลบรารีการประมวลผลตัวเลขที่มีประสิทธิภาพสูง), Flax (ไลบรารีเครือข่ายระบบประสาทแบบ JAX), Orbax (ไลบรารีของ JAX/Pcepointing} {Senttodeken1} ที่ใช้ฝึก+ตรวจสอบ) และยูทิลิตีการฝึกต่างๆ เช่น เครื่องมือตรวจสอบแบบ JAX1}SentencePiece แม้ว่าไม่ได้ใช้ Flax ในสมุดบันทึกนี้โดยตรง แต่ใช้ Flax เพื่อสร้าง Gemma และ RecurrentGemma (โมเดลกริฟฟิน)
สมุดบันทึกนี้ทำงานบน Google Colab ได้ด้วย GPU รุ่น T4 (ไปที่แก้ไข > การตั้งค่าสมุดบันทึก > ใต้ตัวเร่งฮาร์ดแวร์ ให้เลือก GPU T4)
ตั้งค่า
ส่วนต่อไปนี้จะอธิบายขั้นตอนในการเตรียมสมุดบันทึกเพื่อใช้โมเดล RecurrentGemma รวมถึงการเข้าถึงโมเดล การรับคีย์ API และการกำหนดค่ารันไทม์ของสมุดบันทึก
ตั้งค่าการเข้าถึง Kaggle สำหรับ Gemma
หากต้องการจบบทแนะนำนี้ ก่อนอื่นคุณต้องทำตามวิธีการตั้งค่าคล้ายกับการตั้งค่า Gemma โดยมีข้อยกเว้นบางประการดังนี้
- รับสิทธิ์เข้าถึง RecurrentGemma (แทน Gemma) ใน kaggle.com
- เลือกรันไทม์ของ Colab ที่มีทรัพยากรเพียงพอที่จะเรียกใช้โมเดล RecurrentGemma
- สร้างและกำหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle
หลังจากตั้งค่า RecurrentGemma เรียบร้อยแล้ว ให้ข้ามไปที่ส่วนถัดไปซึ่งเป็นชุดตัวแปรสภาพแวดล้อมสำหรับสภาพแวดล้อม Colab
ตั้งค่าตัวแปรสภาพแวดล้อม
ตั้งค่าตัวแปรสภาพแวดล้อมสำหรับ KAGGLE_USERNAME
และ KAGGLE_KEY
เมื่อมีข้อความแจ้งด้วยข้อความ "ให้สิทธิ์เข้าถึงไหม" ยอมรับการเข้าถึงข้อมูลลับ
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
ติดตั้งไลบรารี recurrentgemma
สมุดบันทึกนี้มุ่งเน้นที่การใช้ Colab GPU ฟรี หากต้องการเปิดใช้การเร่งฮาร์ดแวร์ ให้คลิกแก้ไข > การตั้งค่าสมุดบันทึก > เลือก T4 GPU > บันทึก
ถัดไป คุณต้องติดตั้งไลบรารี Google DeepMind recurrentgemma
จาก github.com/google-deepmind/recurrentgemma
หากคุณได้รับข้อผิดพลาดเกี่ยวกับ "รีโซลเวอร์ทรัพยากร Dependency ของ PIP" โดยปกติแล้วคุณไม่ต้องสนใจ
pip install git+https://github.com/google-deepmind/recurrentgemma.git
โหลดและเตรียมโมเดล RecurrentGemma
- โหลดโมเดล RecurrentGemma ด้วย
kagglehub.model_download
ซึ่งมีอาร์กิวเมนต์ 3 ตัว ดังนี้
handle
: แฮนเดิลโมเดลจาก Kagglepath
: (สตริงที่ไม่บังคับ) เส้นทางในเครื่องforce_download
: (บูลีนที่ไม่บังคับ) บังคับให้ดาวน์โหลดโมเดลอีกครั้ง
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- ตรวจสอบตำแหน่งของน้ำหนักโมเดลและเครื่องมือแปลงข้อมูลเป็นโทเค็น จากนั้นตั้งค่าตัวแปรเส้นทาง ไดเรกทอรีโทเคนไลซ์จะอยู่ในไดเรกทอรีหลักที่คุณดาวน์โหลดโมเดลไป ขณะที่น้ำหนักโมเดลจะอยู่ในไดเรกทอรีย่อย เช่น
- ไฟล์
tokenizer.model
จะอยู่ใน/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
) - จุดตรวจสอบโมเดลจะอยู่ใน
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
)
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
ทำการสุ่มตัวอย่าง/การอนุมาน
- โหลดจุดเช็คพอยท์ของโมเดล RecurrentGemma ด้วยเมธอด
recurrentgemma.jax.load_parameters
อาร์กิวเมนต์sharding
ที่ตั้งค่าเป็น"single_device"
จะโหลดพารามิเตอร์โมเดลทั้งหมดในอุปกรณ์เดียว
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- โหลดเครื่องมือแปลงโทเค็นโมเดล RecurrentGemma ซึ่งสร้างขึ้นโดยใช้
sentencepiece.SentencePieceProcessor
:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- หากต้องการโหลดการกำหนดค่าที่ถูกต้องจากจุดตรวจสอบโมเดล RecurrentGemma โดยอัตโนมัติ ให้ใช้
recurrentgemma.GriffinConfig.from_flax_params_or_variables
จากนั้นสร้างอินสแตนซ์โมเดล Griffin ด้วยrecurrentgemma.jax.Griffin
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- สร้าง
sampler
ด้วยrecurrentgemma.jax.Sampler
เพิ่มเติมจากจุดตรวจสอบ/น้ำหนักของโมเดล RecurrentGemma และเครื่องมือแปลงข้อมูลเป็นโทเค็น ดังนี้
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- เขียนพรอมต์ใน
prompt
และดำเนินการอนุมาน คุณสามารถปรับแต่งtotal_generation_steps
(จำนวนขั้นตอนที่ดำเนินการเมื่อสร้างคำตอบ โดยตัวอย่างนี้ใช้50
เพื่อเก็บรักษาหน่วยความจำของโฮสต์)
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
ดูข้อมูลเพิ่มเติม
- คุณดูข้อมูลเพิ่มเติมเกี่ยวกับไลบรารี Google DeepMind
recurrentgemma
ใน GitHub ได้ ซึ่งมีชุดเอกสารของเมธอดและโมดูลที่คุณใช้ในบทแนะนำนี้ เช่นrecurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
และrecurrentgemma.jax.Sampler
- ไลบรารีต่อไปนี้มีเว็บไซต์เอกสารประกอบของตนเอง ได้แก่ JAX หลัก, Flax และ Orbax
- ดูเอกสารประกอบเกี่ยวกับเครื่องมือแปลงข้อมูลเป็นโทเค็น/เครื่องมือถอดรหัสของ
sentencepiece
ได้ที่ที่เก็บsentencepiece
GitHub ของ Google - ดูเอกสารประกอบเกี่ยวกับ
kagglehub
ได้ที่README.md
ในที่เก็บ GitHub ของkagglehub
ของ Kaggle - ดูวิธีใช้โมเดล Gemma กับ Vertex AI ของ Google Cloud
- รับชม RecurrentGemma: Move Through Transformers บทความเกี่ยวกับ Efficient Open Language Models โดย Google DeepMind
- อ่าน Griffin: Mixing Gated Linear Recurrences with บทความ Local Attention for Efficient Language Models โดย GoogleDeepMind เพื่อดูข้อมูลเพิ่มเติมเกี่ยวกับสถาปัตยกรรมโมเดลที่ RecurrentGemma ใช้