เริ่มต้นใช้งาน Gemma โดยใช้ KerasNLP

บทแนะนำนี้จะแสดงวิธีเริ่มต้นใช้งาน Gemma โดยใช้ KerasNLP Gemma เป็นกลุ่มผลิตภัณฑ์โมเดลแบบเปิดที่ทันสมัยและน้ำหนักเบาซึ่งสร้างขึ้นจากงานวิจัยและเทคโนโลยีเดียวกับที่ใช้สร้างโมเดล Gemini KerasNLP คือคอลเล็กชันโมเดลการประมวลผลภาษาธรรมชาติ (NLP) ที่ติดตั้งใช้งานใน Keras และเรียกใช้ได้ใน JAX, PyTorch และ TensorFlow

ในบทแนะนํานี้ คุณจะใช้ Gemma เพื่อสร้างคําตอบที่เป็นข้อความสําหรับพรอมต์หลายรายการ หากคุณเพิ่งเริ่มใช้ Keras คุณอาจต้องอ่านการเริ่มต้นใช้งาน Keras ก่อนเริ่มใช้งาน แต่ก็ไม่จําเป็น คุณจะได้เรียนรู้เพิ่มเติมเกี่ยวกับ Keras เมื่อทำตามบทแนะนำนี้

ตั้งค่า

การตั้งค่า Gemma

หากต้องการดูบทแนะนำนี้ให้จบ คุณจะต้องทําตามวิธีการตั้งค่าที่หัวข้อการตั้งค่า Gemma ก่อน วิธีการตั้งค่า Gemma จะแสดงวิธีดำเนินการต่อไปนี้

  • รับสิทธิ์เข้าถึง Gemma ใน kaggle.com
  • เลือกรันไทม์ Colab ที่มีทรัพยากรเพียงพอที่จะเรียกใช้โมเดล Gemma 2B
  • สร้างและกําหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle

หลังจากตั้งค่า Gemma เสร็จแล้ว ให้ไปยังส่วนถัดไปเพื่อตั้งค่าตัวแปรสภาพแวดล้อมสําหรับสภาพแวดล้อม Colab

ตั้งค่าตัวแปรสภาพแวดล้อม

ตั้งค่าตัวแปรสภาพแวดล้อมสําหรับ KAGGLE_USERNAME และ KAGGLE_KEY

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

ติดตั้งการอ้างอิง

ติดตั้ง Keras และ KerasNLP

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

เลือกแบ็กเอนด์

Keras เป็น API การเรียนรู้เชิงลึกระดับสูงที่ใช้ได้กับหลายเฟรมเวิร์ก ซึ่งออกแบบมาเพื่อความเรียบง่ายและใช้งานง่าย Keras 3 ให้คุณเลือกแบ็กเอนด์ได้ ได้แก่ TensorFlow, JAX หรือ PyTorch ทั้ง 3 รายการนี้ใช้ได้กับบทแนะนำนี้

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

นำเข้าแพ็กเกจ

นําเข้า Keras และ KerasNLP

import keras
import keras_nlp

สร้างโมเดล

KerasNLP มีการใช้งานสถาปัตยกรรมโมเดลยอดนิยมมากมาย ในบทแนะนำนี้ คุณจะได้สร้างโมเดลโดยใช้ GemmaCausalLM ซึ่งเป็นโมเดล Gemma แบบครบวงจรสําหรับการประมาณภาษาเชิงสาเหตุ โมเดลภาษาเชิงสาเหตุจะคาดคะเนโทเค็นถัดไปตามโทเค็นก่อนหน้า

สร้างโมเดลโดยใช้เมธอด from_preset ดังนี้

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

ฟังก์ชัน GemmaCausalLM.from_preset() จะสร้างอินสแตนซ์ของโมเดลจากสถาปัตยกรรมและน้ำหนักที่กำหนดไว้ล่วงหน้า ในโค้ดด้านบน สตริง "gemma2_2b_en" จะระบุรูปแบบ Gemma 2 2B ที่กำหนดล่วงหน้าซึ่งมีพารามิเตอร์ 2, 000 ล้านรายการ นอกจากนี้ ยังมี Gemma รุ่นที่มีพารามิเตอร์ 7B, 9B และ 27B ด้วย คุณดูสตริงโค้ดสำหรับโมเดล Gemma ได้ในข้อมูลรูปแบบโมเดลบน Kaggle

ใช้ summary เพื่อดูข้อมูลเพิ่มเติมเกี่ยวกับรูปแบบ

gemma_lm.summary()

ดังที่เห็นจากข้อมูลสรุป โมเดลมีพารามิเตอร์ที่ฝึกได้ 2.6 พันล้านรายการ

สร้างข้อความ

ตอนนี้ถึงเวลาสร้างข้อความแล้ว โมเดลมีเมธอด generate ที่สร้างข้อความตามพรอมต์ อาร์กิวเมนต์ max_length (ไม่บังคับ) จะระบุความยาวสูงสุดของลําดับที่สร้างขึ้น

ลองใช้กับพรอมต์ "what is keras in 3 bullet points?"

gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'

ลองโทรหา generate อีกครั้งโดยใช้พรอมต์อื่น

gemma_lm.generate("The universe is", max_length=64)
'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'

หากคุณใช้แบ็กเอนด์ JAX หรือ TensorFlow คุณจะเห็นว่าการเรียก generate ครั้งที่สองแสดงผลเกือบจะทันที เนื่องจากแต่ละการเรียก generate สำหรับขนาดการประมวลผลเป็นกลุ่มหนึ่งๆ และ max_length จะคอมไพล์ด้วย XLA การเรียกใช้ครั้งแรกมีค่าใช้จ่ายสูง แต่การเรียกใช้ครั้งต่อๆ ไปจะเร็วขึ้นมาก

นอกจากนี้ คุณยังระบุพรอมต์แบบเป็นกลุ่มโดยใช้รายการเป็นอินพุตได้ด้วย โดยทำดังนี้

gemma_lm.generate(
    ["what is keras in 3 bullet points?",
     "The universe is"],
    max_length=64)
['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n',
 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']

ไม่บังคับ: ลองใช้เครื่องสุ่มตัวอย่างอื่น

คุณควบคุมกลยุทธ์การสร้างสําหรับ GemmaCausalLM ได้โดยการตั้งค่าอาร์กิวเมนต์ sampler ใน compile() โดยค่าเริ่มต้น ระบบจะใช้การสุ่มตัวอย่าง "greedy"

ลองตั้งค่ากลยุทธ์ "top_k" เป็นการทดสอบ โดยทำดังนี้

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
'The universe is a big place, and there are so many things we do not know or understand about it.\n\nBut we can learn a lot about our world by studying what is known to us.\n\nFor example, if you look at the moon, it has many features that can be seen from the surface.'

ในขณะที่อัลกอริทึมการค้นหาแบบโลภเริ่มต้นจะเลือกโทเค็นที่มีความน่าจะเป็นมากที่สุดเสมอ แต่อัลกอริทึม Top-K จะสุ่มเลือกโทเค็นถัดไปจากโทเค็นที่มีความน่าจะเป็นสูงสุด K รายการ

คุณไม่จำเป็นต้องระบุเครื่องสุ่มตัวอย่าง และสามารถละเว้นข้อมูลโค้ดสุดท้ายได้หากไม่เป็นประโยชน์ต่อกรณีการใช้งาน หากต้องการดูข้อมูลเพิ่มเติมเกี่ยวกับเครื่องสุ่มตัวอย่างที่ใช้ได้ โปรดดูเครื่องสุ่มตัวอย่าง

ขั้นตอนถัดไป

ในบทแนะนํานี้ คุณได้เรียนรู้วิธีสร้างข้อความโดยใช้ KerasNLP และ Gemma ต่อไปนี้เป็นคำแนะนำเกี่ยวกับสิ่งที่ควรเรียนรู้ต่อไป