Gemma dalam PyTorch

Lihat di ai.google.dev Jalankan di Google Colab Lihat sumber di GitHub

Ini adalah demo singkat tentang cara menjalankan inferensi Gemma di PyTorch. Untuk detail selengkapnya, lihat repo GitHub untuk implementasi PyTorch resmi di sini.

Perhatikan bahwa:

  • Runtime Python CPU Colab gratis dan runtime Python GPU T4 sudah cukup untuk menjalankan model Gemma 2B dan model kuantisasi int8 7B.
  • Untuk kasus penggunaan lanjutan dengan GPU atau TPU lain, lihat README.md di repo resmi.

1. Menyiapkan akses Kaggle untuk Gemma

Untuk menyelesaikan tutorial ini, Anda harus mengikuti petunjuk penyiapan di Penyiapan Gemma, yang menunjukkan cara melakukan hal berikut:

  • Dapatkan akses ke Gemma di kaggle.com.
  • Pilih runtime Colab dengan resource yang memadai untuk menjalankan model Gemma.
  • Buat dan konfigurasikan nama pengguna dan kunci API Kaggle.

Setelah menyelesaikan penyiapan Gemma, lanjutkan ke bagian berikutnya, tempat Anda akan menetapkan variabel lingkungan untuk lingkungan Colab.

2. Menetapkan variabel lingkungan

Tetapkan variabel lingkungan untuk KAGGLE_USERNAME dan KAGGLE_KEY. Saat diminta dengan pesan "Berikan akses?", setujui untuk memberikan akses secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Menginstal dependensi

pip install -q -U torch immutabledict sentencepiece

Mendownload bobot model

# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '2b':
  CONFIG = '2b-v2'
import os
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Mendownload implementasi model

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch')
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

Menyiapkan model

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

Menjalankan inferensi

Berikut adalah contoh untuk membuat dalam mode chat dan membuat dengan beberapa permintaan.

Model Gemma yang disesuaikan dengan petunjuk dilatih dengan pemformat tertentu yang menambahkan anotasi pada contoh penyesuaian petunjuk dengan informasi tambahan, baik selama pelatihan maupun inferensi. Anotasi (1) menunjukkan peran dalam percakapan, dan (2) menunjukkan perubahan dalam percakapan.

Token anotasi yang relevan adalah:

  • user: giliran pengguna
  • model: model turn
  • <start_of_turn>: awal giliran dialog
  • <end_of_turn><eos>: akhir giliran dialog

Untuk mengetahui informasi lebih lanjut, baca artikel tentang pemformatan prompt untuk model Gemma yang disesuaikan dengan petunjuk di sini.

Berikut adalah contoh cuplikan kode yang menunjukkan cara memformat perintah untuk model Gemma yang disesuaikan dengan petunjuk menggunakan template chat pengguna dan model dalam percakapan multi-giliran.

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=128,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

Pelajari lebih lanjut

Setelah mempelajari cara menggunakan Gemma di Pytorch, Anda dapat menjelajahi banyak hal lain yang dapat dilakukan Gemma di ai.google.dev/gemma. Lihat juga referensi terkait lainnya berikut: