เรียกใช้ Gemma โดยใช้ PyTorch

ดูที่ ai.google.dev เรียกใช้ใน Google Colab ดูแหล่งข้อมูลใน GitHub

คู่มือนี้จะแสดงวิธีเรียกใช้ Gemma โดยใช้เฟรมเวิร์ก PyTorch รวมถึงวิธี ใช้ข้อมูลรูปภาพเพื่อแจ้งให้โมเดล Gemma รุ่น 3 ขึ้นไป ดูรายละเอียดเพิ่มเติมเกี่ยวกับการติดตั้งใช้งาน Gemma PyTorch ได้ที่ที่เก็บโปรเจ็กต์ README

ตั้งค่า

ส่วนต่อไปนี้จะอธิบายวิธีตั้งค่าสภาพแวดล้อมการพัฒนา รวมถึงวิธีรับสิทธิ์เข้าถึงโมเดล Gemma เพื่อดาวน์โหลดจาก Kaggle การตั้งค่า ตัวแปรการตรวจสอบสิทธิ์ การติดตั้งการอ้างอิง และการนำเข้าแพ็กเกจ

ข้อกำหนดของระบบ

ไลบรารี Gemma Pytorch นี้ต้องใช้โปรเซสเซอร์ GPU หรือ TPU เพื่อเรียกใช้โมเดล Gemma รันไทม์ Python ของ CPU ใน Colab แบบมาตรฐานและรันไทม์ Python ของ T4 GPU เพียงพอต่อการเรียกใช้โมเดลขนาด 1B, 2B และ 4B ของ Gemma สำหรับกรณีการใช้งานขั้นสูง สำหรับ GPU หรือ TPU อื่นๆ โปรดดู README ใน ที่เก็บ Gemma PyTorch

รับสิทธิ์เข้าถึง Gemma ใน Kaggle

หากต้องการทำตามบทแนะนำนี้ให้เสร็จสมบูรณ์ คุณต้องทำตามวิธีการตั้งค่าที่ การตั้งค่า Gemma ก่อน ซึ่งจะแสดงวิธีทำสิ่งต่อไปนี้

  • รับสิทธิ์เข้าถึง Gemma ใน Kaggle
  • เลือกรันไทม์ Colab ที่มีทรัพยากรเพียงพอที่จะเรียกใช้โมเดล Gemma
  • สร้างและกำหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle

หลังจากตั้งค่า Gemma เสร็จแล้ว ให้ไปที่ส่วนถัดไปเพื่อตั้งค่าตัวแปรสภาพแวดล้อมสำหรับสภาพแวดล้อม Colab

ตั้งค่าตัวแปรสภาพแวดล้อม

ตั้งค่าตัวแปรสภาพแวดล้อมสำหรับ KAGGLE_USERNAME และ KAGGLE_KEY เมื่อได้รับแจ้ง ด้วยข้อความ "ให้สิทธิ์เข้าถึงไหม" ให้ยอมรับที่จะให้สิทธิ์เข้าถึงข้อมูลลับ

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

ติดตั้งการอ้างอิง

pip install -q -U torch immutabledict sentencepiece

ดาวน์โหลดน้ำหนักของโมเดล

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

ตั้งค่าเส้นทางของโทเค็นไนเซอร์และจุดตรวจสอบสำหรับโมเดล

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

กำหนดค่าสภาพแวดล้อมการเรียกใช้

ส่วนต่อไปนี้จะอธิบายวิธีเตรียมสภาพแวดล้อม PyTorch เพื่อเรียกใช้ Gemma

เตรียมสภาพแวดล้อมการเรียกใช้ PyTorch

เตรียมสภาพแวดล้อมการเรียกใช้โมเดล PyTorch โดยการโคลนที่เก็บ Gemma Pytorch

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

ตั้งค่ารูปแบบ

ก่อนที่จะเรียกใช้โมเดล คุณต้องตั้งค่าพารามิเตอร์การกำหนดค่าบางอย่าง ซึ่งรวมถึง ตัวแปร Gemma, ตัวแยกคำ และระดับการวัดปริมาณ

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

กำหนดค่าบริบทของอุปกรณ์

โค้ดต่อไปนี้จะกำหนดค่าบริบทของอุปกรณ์สำหรับการเรียกใช้โมเดล

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

สร้างอินสแตนซ์และโหลดโมเดล

โหลดโมเดลพร้อมน้ำหนักเพื่อเตรียมพร้อมที่จะเรียกใช้คำขอ

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

เรียกใช้การอนุมาน

ด้านล่างนี้คือตัวอย่างการสร้างในโหมดแชทและการสร้างด้วยคำขอหลายรายการ

โมเดล Gemma ที่ได้รับการปรับแต่งตามคำสั่งได้รับการฝึกด้วยโปรแกรมจัดรูปแบบเฉพาะที่ ใส่คำอธิบายประกอบตัวอย่างการปรับแต่งตามคำสั่งด้วยข้อมูลเพิ่มเติม ทั้งในระหว่าง การฝึกและอนุมาน คำอธิบายประกอบ (1) จะระบุบทบาทในการสนทนา และ (2) จะระบุลำดับการสนทนา

โทเค็นคำอธิบายประกอบที่เกี่ยวข้องมีดังนี้

  • user: เทิร์นของผู้ใช้
  • model: รอบของโมเดล
  • <start_of_turn>: จุดเริ่มต้นของรอบการสนทนา
  • <start_of_image>: แท็กสำหรับการป้อนข้อมูลรูปภาพ
  • <end_of_turn><eos>: สิ้นสุดการโต้ตอบ

ดูข้อมูลเพิ่มเติมได้ที่หัวข้อการจัดรูปแบบพรอมต์สำหรับโมเดล Gemma ที่ปรับแต่งตามคำสั่งที่นี่

สร้างข้อความด้วยข้อความ

ตัวอย่างต่อไปนี้คือตัวอย่างโค้ดที่แสดงวิธีจัดรูปแบบพรอมต์สำหรับ โมเดล Gemma ที่ปรับแต่งตามคำสั่งโดยใช้เทมเพลตแชทของผู้ใช้และโมเดลใน การสนทนาแบบหลายรอบ

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

สร้างข้อความพร้อมรูปภาพ

ตั้งแต่ Gemma รุ่น 3 เป็นต้นไป คุณจะใช้รูปภาพกับพรอมต์ได้ ตัวอย่างต่อไปนี้แสดงวิธีรวมข้อมูลภาพกับพรอมต์

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image = read_image(
    'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)

print(model.generate(
    [
        [
            '<start_of_turn>user\n',
            image,
            'What animal is in this image?<end_of_turn>\n',
            '<start_of_turn>model\n'
        ]
    ],
    device=device,
    output_len=256,
))

ดูข้อมูลเพิ่มเติม

ตอนนี้คุณได้เรียนรู้วิธีใช้ Gemma ใน Pytorch แล้ว คุณสามารถสำรวจสิ่งอื่นๆ อีกมากมายที่ Gemma ทำได้ใน ai.google.dev/gemma

ดูแหล่งข้อมูลอื่นๆ ที่เกี่ยวข้องต่อไปนี้ด้วย