การอนุมานด้วย Gemma โดยใช้ JAX และ Flax

ดูใน ai.google.dev เรียกใช้ใน Google Colab เปิดใน Vertex AI ดูซอร์สบน GitHub

ภาพรวม

Gemma เป็นกลุ่มโมเดลภาษาแบบเปิดที่ทันสมัยและมีน้ำหนักเบา โดยอิงจากการวิจัยและเทคโนโลยีของ Google DeepMind Gemini บทแนะนำนี้จะสาธิตวิธีการสุ่มตัวอย่าง/อนุมานพื้นฐานด้วยโมเดล Gemma 2B Instruct โดยใช้ไลบรารี gemma ของ Google DeepMind ที่เขียนด้วย JAX (ไลบรารีการประมวลผลเชิงตัวเลขประสิทธิภาพสูง), Flaxizer (ไลบรารีเครือข่ายระบบประสาทที่ใช้ JAX), Orbax (ไลบรารีที่ใช้ JAX/Sentence1Pence1) และSentencePiece แม้ว่าจะไม่มีการใช้ Flax โดยตรงในสมุดบันทึกนี้ แต่มีการใช้ Flax ในการสร้าง Gemma

สมุดบันทึกนี้ทำงานบน Google Colab ได้โดยใช้ T4 GPU ที่ใช้งานฟรี (ไปที่แก้ไข > การตั้งค่าสมุดบันทึก > เลือก GPU T4 ในส่วนตัวเร่งฮาร์ดแวร์)

ตั้งค่า

1. ตั้งค่าการเข้าถึง Kaggle สำหรับ Gemma

หากต้องการทำให้บทแนะนำนี้จบลง ก่อนอื่นคุณต้องทำตามวิธีการตั้งค่าที่การตั้งค่า Gemma ซึ่งแสดงวิธีดำเนินการต่อไปนี้

  • เข้าถึง Gemma บน kaggle.com
  • เลือกรันไทม์ของ Colab ที่มีทรัพยากรเพียงพอที่จะเรียกใช้โมเดล Gemma
  • สร้างและกำหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle

หลังจากตั้งค่า Gemma เรียบร้อยแล้ว ให้ไปยังส่วนถัดไปเพื่อตั้งค่าตัวแปรสภาพแวดล้อมสำหรับสภาพแวดล้อม Colab

2. ตั้งค่าตัวแปรสภาพแวดล้อม

ตั้งค่าตัวแปรสภาพแวดล้อมสำหรับ KAGGLE_USERNAME และ KAGGLE_KEY เมื่อได้รับข้อความแจ้งเกี่ยวกับ "อนุญาตให้เข้าถึงไหม" ให้ยอมรับที่จะให้สิทธิ์เข้าถึงข้อมูลลับ

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. ติดตั้งไลบรารี gemma

สมุดบันทึกนี้เน้นการใช้ Colab GPU ที่ไม่มีค่าใช้จ่าย หากต้องการเปิดใช้การเร่งฮาร์ดแวร์ ให้คลิกแก้ไข > การตั้งค่าสมุดบันทึก > เลือก GPU T4 > บันทึก

ถัดไป คุณต้องติดตั้งไลบรารี Google DeepMind gemma จาก github.com/google-deepmind/gemma หากได้รับข้อผิดพลาดเกี่ยวกับ "รีโซลเวอร์ Dependency ของ PIP" คุณก็ไม่ต้องสนใจข้อผิดพลาดนี้

pip install -q git+https://github.com/google-deepmind/gemma.git

โหลดและเตรียมโมเดล Gemma

  1. โหลดโมเดล Gemma ด้วย kagglehub.model_download ซึ่งจะใช้อาร์กิวเมนต์ 3 รายการดังนี้
  • handle: แฮนเดิลโมเดลจาก Kaggle
  • path: (สตริงที่ไม่บังคับ) เส้นทางภายใน
  • force_download: (บูลีนที่ไม่บังคับ) บังคับให้ดาวน์โหลดโมเดลอีกครั้ง
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. ตรวจสอบตำแหน่งของน้ำหนักโมเดลและตัวแปลงข้อมูลเป็นโทเค็น จากนั้นตั้งค่าตัวแปรเส้นทาง ไดเรกทอรี Tokenizer จะอยู่ในไดเรกทอรีหลักที่คุณดาวน์โหลดโมเดล ส่วนน้ำหนักของโมเดลจะอยู่ในไดเรกทอรีย่อย เช่น
  • ไฟล์ tokenizer.model จะอยู่ใน /LOCAL/PATH/TO/gemma/flax/2b-it/2)
  • จุดตรวจของโมเดลจะอยู่ใน /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it)
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

ทำการสุ่มตัวอย่าง/อนุมาน

  1. โหลดและจัดรูปแบบจุดตรวจสอบโมเดล Gemma ด้วยเมธอด gemma.params.load_and_format_params ดังนี้
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. โหลดเครื่องมือแปลงข้อมูลโทเค็น Gemma ที่สร้างขึ้นโดยใช้ sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. หากต้องการโหลดการกำหนดค่าที่ถูกต้องโดยอัตโนมัติจากจุดตรวจสอบโมเดล Gemma ให้ใช้ gemma.transformer.TransformerConfig อาร์กิวเมนต์ cache_size คือจำนวนขั้นตอนในแคชของ Gemma Transformer หลังจากนั้น ให้สร้างโมเดล Gemma เป็น transformer ด้วย gemma.transformer.Transformer (ซึ่งรับค่ามาจาก flax.linen.Module)
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. สร้าง sampler ด้วย gemma.sampler.Sampler ด้านบนของจุดตรวจสอบ/น้ำหนักของโมเดล Gemma และเครื่องมือแปลงข้อมูลเป็นโทเค็น ดังนี้
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. เขียนพรอมต์ใน input_batch และทำการอนุมาน คุณสามารถปรับแต่ง total_generation_steps (จำนวนขั้นตอนที่ดำเนินการเมื่อสร้างคำตอบ ตัวอย่างนี้ใช้ 100 เพื่อเก็บรักษาหน่วยความจำของโฮสต์)
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (ไม่บังคับ) เรียกใช้เซลล์นี้เพื่อเพิ่มหน่วยความจำหากสร้างสมุดบันทึกเสร็จเรียบร้อยแล้วและต้องการลองใช้พรอมต์อื่น หลังจากนั้น คุณสามารถสร้างอินสแตนซ์ sampler อีกครั้งในขั้นตอนที่ 3 แล้วปรับแต่งและเรียกใช้พรอมต์ในขั้นตอนที่ 4
del sampler

ดูข้อมูลเพิ่มเติม