Questa guida mostra come eseguire Gemma utilizzando il framework PyTorch, incluso come utilizzare i dati delle immagini per richiedere i modelli Gemma release 3 e versioni successive. Per maggiori dettagli sull'implementazione di Gemma PyTorch, consulta il file README del repository del progetto.
Configurazione
Le sezioni seguenti spiegano come configurare l'ambiente di sviluppo, ad esempio come accedere ai modelli Gemma per il download da Kaggle, impostare le variabili di autenticazione, installare le dipendenze e importare i pacchetti.
Requisiti di sistema
Questa libreria Pytorch di Gemma richiede processori GPU o TPU per eseguire il modello Gemma. Il runtime Python per CPU Colab standard e il runtime Python per GPU T4 sono sufficienti per eseguire modelli Gemma di dimensioni 1B, 2B e 4B. Per casi d'uso avanzati per altre GPU o TPU, consulta il file README nel repository Gemma PyTorch.
Accedere a Gemma su Kaggle
Per completare questo tutorial, devi prima seguire le istruzioni di configurazione riportate nella pagina Configurazione di Gemma, che spiegano come svolgere le seguenti operazioni:
- Accedi a Gemma su kaggle.com.
- Seleziona un runtime Colab con risorse sufficienti per eseguire il modello Gemma.
- Genera e configura un nome utente e una chiave API Kaggle.
Dopo aver completato la configurazione di Gemma, vai alla sezione successiva, dove imposterai le variabili di ambiente per il tuo ambiente Colab.
Imposta le variabili di ambiente
Imposta le variabili di ambiente per KAGGLE_USERNAME
e KAGGLE_KEY
. Quando viene visualizzato il messaggio "Concedere l'accesso?", accetta di fornire l'accesso ai secret.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Installa le dipendenze
pip install -q -U torch immutabledict sentencepiece
Scarica i pesi del modello
# Choose variant and machine type
VARIANT = '4b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT[:2]
if CONFIG == '4b':
CONFIG = '4b-v1'
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')
Imposta i percorsi del tokenizzatore e del checkpoint per il modello.
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
Configura l'ambiente di esecuzione
Le sezioni seguenti spiegano come preparare un ambiente PyTorch per l'esecuzione di Gemma.
Prepara l'ambiente di esecuzione di PyTorch
Prepara l'ambiente di esecuzione del modello PyTorch clonando il repository Gemma Pytorch.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM
import os
import torch
Imposta la configurazione del modello
Prima di eseguire il modello, devi impostare alcuni parametri di configurazione, tra cui la variante Gemma, lo tokenizer e il livello di quantizzazione.
# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path
Configura il contesto del dispositivo
Il seguente codice configura il contesto del dispositivo per l'esecuzione del modello:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
Esegui l'inizializzazione e carica il modello
Carica il modello con i relativi pesi per prepararti a eseguire le richieste.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
model = Gemma3ForMultimodalLM(model_config)
model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
model = model.to(device).eval()
print("Model loading done.")
print('Generating requests in chat mode...')
Esegui l'inferenza
Di seguito sono riportati esempi di generazione in modalità chat e con più richieste.
I modelli Gemma ottimizzati per le istruzioni sono stati addestrati con un formattatore specifico che annota gli esempi di ottimizzazione delle istruzioni con informazioni aggiuntive, sia durante l'addestramento sia durante l'inferenza. Le annotazioni (1) indicano i ruoli in una conversazione e (2) delineano i turni in una conversazione.
I token di annotazione pertinenti sono:
user
: turno dell'utentemodel
: turno del modello<start_of_turn>
: inizio del turno di dialogo<start_of_image>
: tag per l'inserimento dei dati delle immagini<end_of_turn><eos>
: fine del turno di dialogo
Per ulteriori informazioni, leggi la pagina sulla formattazione dei prompt per i modelli Gemma ottimizzati per le istruzioni [qui](https://ai.google.dev/gemma/core/prompt-structure
Generare testo con testo
Di seguito è riportato uno snippet di codice di esempio che mostra come formattare un prompt per un modello Gemma ottimizzato per le istruzioni utilizzando i modelli di chat dell'utente e del modello in una conversazione con più turni.
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=256,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
Generare testo con immagini
Con la versione 3 di Gemma e versioni successive, puoi utilizzare le immagini con il prompt. L'esempio seguente mostra come includere dati visivi nel prompt.
print('Chat with images...\n')
def read_image(url):
import io
import requests
import PIL
contents = io.BytesIO(requests.get(url).content)
return PIL.Image.open(contents)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
image = read_image(image_url)
print(model.generate(
[['<start_of_turn>user\n',image, 'What animal is in this image?<end_of_turn>\n', '<start_of_turn>model\n']],
device=device,
output_len=OUTPUT_LEN,
))
Scopri di più
Ora che hai imparato a utilizzare Gemma in Pytorch, puoi esplorare le molte altre funzionalità di Gemma all'indirizzo ai.google.dev/gemma. Consulta anche queste altre risorse correlate: