Gemma in PyTorch

Visualizza su ai.google.dev Esegui in Google Colab Visualizza il codice sorgente su GitHub

Questa è una breve demo dell'esecuzione dell'inferenza di Gemma in PyTorch. Per ulteriori dettagli, consulta il repository GitHub dell'implementazione ufficiale di PyTorch qui.

Tieni presente che:

  • Il runtime Python della CPU Colab senza costi e il runtime Python della GPU T4 sono sufficienti per eseguire i modelli Gemma 2B e i 7 miliardi di modelli quantizzati int8.
  • Per casi d'uso avanzati per altre GPU o TPU, consulta il file README.md nel repository ufficiale.

1. Configurare l'accesso a Kaggle per Gemma

Per completare questo tutorial, devi prima seguire le istruzioni di configurazione nella configurazione di Gemma, che mostrano come fare:

  • Accedi a Gemma su kaggle.com.
  • Seleziona un runtime Colab con risorse sufficienti per eseguire il modello Gemma.
  • Genera e configura un nome utente e una chiave API Kaggle.

Dopo aver completato la configurazione di Gemma, vai alla sezione successiva, dove imposterai le variabili di ambiente per l'ambiente Colab.

2. Imposta le variabili di ambiente

Imposta le variabili di ambiente per KAGGLE_USERNAME e KAGGLE_KEY. Quando viene visualizzato il messaggio "Vuoi concedere l'accesso?", accetta di fornire l'accesso al secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installa le dipendenze

pip install -q -U torch immutabledict sentencepiece

Scarica i pesi del modello

# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '2b':
  CONFIG = '2b-v2'
import os
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Scarica l'implementazione del modello

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch')
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

Configura il modello

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

Esegui l'inferenza

Di seguito sono riportati esempi di generazione in modalità di chat e con più richieste.

I modelli Gemma ottimizzati per le istruzioni sono stati addestrati con uno specifico formattatore che annota esempi di ottimizzazione delle istruzioni con informazioni aggiuntive, sia durante l'addestramento che l'inferenza. Le annotazioni (1) indicano i ruoli in una conversazione e (2) delimitano i turni in una conversazione.

I token di annotazione pertinenti sono:

  • user: turno dell'utente
  • model: turno del modello
  • <start_of_turn>: inizio del turno di dialogo
  • <end_of_turn><eos>: fine del turno di dialogo

Per ulteriori informazioni, consulta la pagina relativa alla formattazione dei prompt per i modelli Gemma ottimizzati per le istruzioni qui.

Di seguito è riportato uno snippet di codice di esempio che mostra come formattare un prompt per un modello Gemma ottimizzato in base alle istruzioni utilizzando modelli di chat per utente e modello in una conversazione con più turni.

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=128,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

Scopri di più

Ora che hai imparato a utilizzare Gemma in Pytorch, puoi esplorare le molte altre funzionalità di Gemma all'indirizzo ai.google.dev/gemma. Consulta anche queste altre risorse correlate: