เริ่มต้นใช้งาน Gemma โดยใช้ KerasNLP

ดูใน ai.google.dev เรียกใช้ใน Google Colab เปิดใน Vertex AI ดูแหล่งที่มาใน GitHub

บทแนะนำนี้จะแสดงวิธีเริ่มต้นใช้งาน Gemma โดยใช้ KerasNLP Gemma เป็นชุดโมเดลเปิดที่ทันสมัยและน้ำหนักเบา สร้างขึ้นจากการวิจัยและเทคโนโลยีเดียวกันกับที่ใช้ในการสร้างโมเดล Gemini KerasNLP คือคอลเล็กชันของโมเดลการประมวลผลภาษาธรรมชาติ (NLP) ที่ใช้งานใน Keras และเรียกใช้บน JAX, PyTorch และ TensorFlow

ในบทแนะนำนี้ คุณจะใช้ Gemma เพื่อสร้างข้อความตอบกลับสำหรับพรอมต์ต่างๆ หากคุณเพิ่งเริ่มใช้ Keras คุณควรอ่านการเริ่มต้นใช้งาน Keras ก่อนเริ่มต้น แต่ก็ไม่จำเป็น คุณดูข้อมูลเพิ่มเติมเกี่ยวกับ Keras ได้เมื่อทำตามบทแนะนำนี้

ตั้งค่า

การตั้งค่า Gemma

หากต้องการจบบทแนะนำนี้ คุณจะต้องทำตามวิธีการตั้งค่าที่การตั้งค่า Gemma ก่อน วิธีการตั้งค่า Gemma จะแสดงวิธีดำเนินการต่อไปนี้

  • เข้าถึง Gemma ใน kaggle.com
  • เลือกรันไทม์ของ Colab ที่มีทรัพยากรเพียงพอที่จะเรียกใช้ โมเดล Gemma 2B
  • สร้างและกำหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle

หลังจากตั้งค่า Gemma เสร็จแล้ว ให้ไปยังส่วนถัดไปซึ่งจะตั้งค่าตัวแปรสภาพแวดล้อมสำหรับสภาพแวดล้อม Colab

ตั้งค่าตัวแปรสภาพแวดล้อม

ตั้งค่าตัวแปรสภาพแวดล้อมสำหรับ KAGGLE_USERNAME และ KAGGLE_KEY

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

ติดตั้งการอ้างอิง

ติดตั้ง Keras และ KerasNLP

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

เลือกแบ็กเอนด์

Keras เป็น API การเรียนรู้เชิงลึกที่มีหลายกรอบและมีระดับสูงซึ่งออกแบบมาให้ใช้งานง่าย Keras 3 ให้คุณเลือกแบ็กเอนด์ ได้แก่ TensorFlow, JAX หรือ PyTorch ทั้ง 3 วิธีนี้เหมาะกับบทแนะนำนี้

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

นำเข้าแพ็กเกจ

นำเข้า Keras และ KerasNLP

import keras
import keras_nlp

สร้างโมเดล

KerasNLP ให้บริการนำสถาปัตยกรรมโมเดลยอดนิยมจำนวนมากมาใช้ ในบทแนะนำนี้ คุณจะได้สร้างโมเดลโดยใช้ GemmaCausalLM ซึ่งเป็นโมเดล Gemma จากต้นทางถึงปลายทางสำหรับการสร้างโมเดลภาษาทั่วไป โมเดลภาษาทั่วไปจะคาดการณ์โทเค็นถัดไปตามโทเค็นก่อนหน้า

สร้างโมเดลโดยใช้เมธอด from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

ฟังก์ชัน GemmaCausalLM.from_preset() จะสร้างอินสแตนซ์โมเดลจากสถาปัตยกรรมและน้ำหนักที่กำหนดไว้ล่วงหน้า ในโค้ดด้านบน สตริง "gemma2_2b_en" ระบุค่าที่กำหนดล่วงหน้าของโมเดล Gemma 2 2B ที่มีพารามิเตอร์ 2 พันล้านรายการ นอกจากนี้ยังมีโมเดล Gemma ที่มีพารามิเตอร์ 7B, 9B และ 27B ให้ใช้งานด้วย คุณดูสตริงโค้ดสำหรับโมเดล Gemma ได้ในรูปแบบโมเดลใน Kaggle

ใช้ summary เพื่อดูข้อมูลเพิ่มเติมเกี่ยวกับโมเดล

gemma_lm.summary()

จากข้อมูลสรุปจะเห็นได้ว่าโมเดลนี้มีพารามิเตอร์ที่ฝึกได้ 2.6 พันล้านรายการ

สร้างข้อความ

ตอนนี้ถึงเวลาสร้างข้อความแล้ว โมเดลมีเมธอด generate ที่สร้างข้อความตามพรอมต์ อาร์กิวเมนต์ max_length ที่ไม่บังคับจะระบุความยาวสูงสุดของลำดับที่สร้างขึ้น

ลองใช้งานด้วยพรอมต์ "what is keras in 3 bullet points?"

gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'

ลองโทรหา generate อีกครั้งโดยใช้ข้อความแจ้งอื่น

gemma_lm.generate("The universe is", max_length=64)
'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'

หากกำลังทำงานบนแบ็กเอนด์ JAX หรือ TensorFlow คุณจะสังเกตเห็นว่าการเรียก generate ครั้งที่ 2 กลับมาเรียกใช้ได้แทบจะในทันที ที่เป็นเช่นนี้เนื่องจากการเรียกไปยัง generate แต่ละครั้งสำหรับขนาดกลุ่มที่ระบุและ max_length คอมไพล์ด้วย XLA การเรียกใช้ครั้งแรกมีราคาสูง แต่การเรียกใช้ครั้งต่อๆ ไปจะเร็วกว่ามาก

นอกจากนี้คุณยังส่งพรอมต์แบบกลุ่มโดยใช้รายการเป็นอินพุตได้ด้วย ดังนี้

gemma_lm.generate(
    ["what is keras in 3 bullet points?",
     "The universe is"],
    max_length=64)
['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n',
 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']

ไม่บังคับ: ลองใช้ตัวอย่างอื่น

คุณสามารถควบคุมกลยุทธ์การสร้างสำหรับ GemmaCausalLM โดยการตั้งค่าอาร์กิวเมนต์ sampler ใน compile() โดยค่าเริ่มต้น ระบบจะใช้การสุ่มตัวอย่าง "greedy"

ในการทดสอบ ให้ลองตั้งค่ากลยุทธ์ "top_k" ดังนี้

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
'The universe is a big place, and there are so many things we do not know or understand about it.\n\nBut we can learn a lot about our world by studying what is known to us.\n\nFor example, if you look at the moon, it has many features that can be seen from the surface.'

แม้ว่าอัลกอริทึมโลภเริ่มต้นจะเลือกโทเค็นที่มีความน่าจะเป็นมากที่สุดเสมอ แต่อัลกอริทึมสูงสุด K จะสุ่มเลือกโทเค็นถัดไปจากโทเค็นของความน่าจะเป็นอันดับต้นๆ K

คุณไม่จำเป็นต้องระบุตัวอย่าง และไม่สนใจข้อมูลโค้ดล่าสุดได้หากไม่เป็นประโยชน์กับกรณีการใช้งานของคุณ ดูข้อมูลเพิ่มเติมเกี่ยวกับตัวอย่างวิดีโอที่พร้อมใช้งานได้ที่วิดีโอตัวอย่าง

ขั้นตอนถัดไป

ในบทแนะนำนี้ คุณได้เรียนรู้วิธีสร้างข้อความโดยใช้ KerasNLP และ Gemma คำแนะนำบางส่วนเกี่ยวกับสิ่งที่ควรเรียนรู้ต่อไปมีดังนี้