Gemma ใน PyTorch

ดูใน ai.google.dev เรียกใช้ใน Google Colab ดูแหล่งที่มาใน GitHub

นี่คือการสาธิตสั้นๆ ของการเรียกใช้การอนุมาน Gemma ใน PyTorch ดูรายละเอียดเพิ่มเติมได้ที่ที่เก็บ GitHub ของการใช้งาน PyTorch อย่างเป็นทางการที่นี่

โปรดทราบว่า

  • รันไทม์ Python สำหรับ CPU ของ Colab และรันไทม์ Python สำหรับ GPU รุ่น T4 แบบไม่มีค่าใช้จ่ายเพียงพอที่จะเรียกใช้โมเดล Gemma 2B และโมเดลที่แปลงค่าเป็น int8 7B
  • สำหรับกรณีการใช้งานขั้นสูงของ GPU หรือ TPU อื่นๆ โปรดดู README.md ในรีโปอย่างเป็นทางการ

1. ตั้งค่าการเข้าถึง Kaggle สําหรับ Gemma

หากต้องการจบบทแนะนำนี้ ก่อนอื่นคุณต้องทำตามวิธีการตั้งค่าที่การตั้งค่า Gemma ซึ่งแสดงวิธีดำเนินการต่อไปนี้

  • รับสิทธิ์เข้าถึง Gemma ใน kaggle.com
  • เลือกรันไทม์ Colab ที่มีทรัพยากรเพียงพอที่จะเรียกใช้โมเดล Gemma
  • สร้างและกำหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle

หลังจากตั้งค่า Gemma เสร็จแล้ว ให้ไปยังส่วนถัดไปเพื่อตั้งค่าตัวแปรสภาพแวดล้อมสําหรับสภาพแวดล้อม Colab

2. ตั้งค่าตัวแปรสภาพแวดล้อม

ตั้งค่าตัวแปรสภาพแวดล้อมสําหรับ KAGGLE_USERNAME และ KAGGLE_KEY เมื่อได้รับข้อความ "ให้สิทธิ์เข้าถึงไหม" ให้ยอมรับการให้สิทธิ์เข้าถึงข้อมูลลับ

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

ติดตั้งการอ้างอิง

pip install -q -U torch immutabledict sentencepiece

ดาวน์โหลดน้ำหนักโมเดล

# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '2b':
  CONFIG = '2b-v2'
import os
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

ดาวน์โหลดการใช้งานโมเดล

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch')
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

ตั้งค่าโมเดล

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

เรียกใช้การอนุมาน

ด้านล่างนี้คือตัวอย่างการสร้างในโหมดแชทและการสร้างด้วยคำขอหลายรายการ

โมเดล Gemma ที่ปรับแต่งคำสั่งได้รับการฝึกด้วยเครื่องมือจัดรูปแบบที่เฉพาะเจาะจงซึ่งกำกับเนื้อหาตัวอย่างการปรับแต่งคำสั่งด้วยข้อมูลเพิ่มเติม ทั้งในระหว่างการฝึกและการทำนาย คําอธิบายประกอบ (1) ระบุบทบาทในการสนทนา และ (2) ระบุการผลัดกันพูดในการสนทนา

โทเค็นคำอธิบายประกอบที่เกี่ยวข้อง ได้แก่

  • user: การเปลี่ยนผู้ใช้
  • model: เปลี่ยนรุ่น
  • <start_of_turn>: เริ่มต้นของรอบการสนทนา
  • <end_of_turn><eos>: สิ้นสุดรอบการสนทนา

อ่านข้อมูลเพิ่มเติมได้ที่การจัดรูปแบบข้อความแจ้งสำหรับวิธีการปรับแต่งโมเดล Gemma ได้ที่นี่

ต่อไปนี้คือตัวอย่างข้อมูลโค้ดที่แสดงวิธีจัดรูปแบบพรอมต์สำหรับโมเดล Gemma ที่ปรับตามวิธีการโดยใช้เทมเพลตแชทของผู้ใช้และโมเดลในการสนทนาแบบหลายรอบ

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=128,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

ดูข้อมูลเพิ่มเติม

เมื่อได้ทราบวิธีใช้ Gemma ใน Pytorch แล้ว คุณสามารถสำรวจสิ่งอื่นๆ อีกมากมายที่ Gemma ทำได้ที่ ai.google.dev/gemma โปรดดูแหล่งข้อมูลอื่นๆ ที่เกี่ยวข้องเหล่านี้ด้วย