En esta guía, se muestra cómo ejecutar Gemma con el framework de PyTorch, incluido cómo usar los datos de imagen para solicitar modelos de Gemma versión 3 y posteriores. Para obtener más detalles sobre la implementación de Gemma PyTorch, consulta el archivo README del repositorio del proyecto.
Configuración
En las siguientes secciones, se explica cómo configurar tu entorno de desarrollo, lo que incluye cómo obtener acceso a los modelos de Gemma para descargarlos de Kaggle, establecer variables de autenticación, instalar dependencias y, también, importar paquetes.
Requisitos del sistema
Esta biblioteca de Pytorch de Gemma requiere procesadores de GPU o TPU para ejecutar el modelo de Gemma. El entorno de ejecución de Python estándar de la CPU de Colab y el entorno de ejecución de Python de la GPU T4 son suficientes para ejecutar modelos de Gemma de 1,000 millones, 2,000 millones y 4,000 millones de parámetros. Para ver casos de uso avanzados de otras GPUs o TPU, consulta el archivo README en el repositorio de Gemma PyTorch.
Obtén acceso a Gemma en Kaggle
Para completar este instructivo, primero debes seguir las instrucciones de configuración que se indican en Configuración de Gemma, en las que se muestra cómo hacer lo siguiente:
- Obtén acceso a Gemma en kaggle.com.
- Selecciona un entorno de ejecución de Colab con recursos suficientes para ejecutar el modelo de Gemma.
- Genera y configura un nombre de usuario y una clave de API de Kaggle.
Después de completar la configuración de Gemma, continúa con la siguiente sección, en la que configurarás las variables de entorno de tu entorno de Colab.
Configure las variables de entorno
Configura las variables de entorno para KAGGLE_USERNAME
y KAGGLE_KEY
. Cuando se te solicite con los mensajes "¿Quieres otorgar acceso?", acepta proporcionar acceso secreto.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Instala dependencias
pip install -q -U torch immutabledict sentencepiece
Descarga los pesos del modelo
# Choose variant and machine type
VARIANT = '4b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT[:2]
if CONFIG == '4b':
CONFIG = '4b-v1'
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')
Establece las rutas de acceso del tokenizador y del punto de control para el modelo.
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
Configura el entorno de ejecución
En las siguientes secciones, se explica cómo preparar un entorno de PyTorch para ejecutar Gemma.
Prepara el entorno de ejecución de PyTorch
Clona el repositorio de Pytorch de Gemma para preparar el entorno de ejecución del modelo de PyTorch.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM
import os
import torch
Establece la configuración del modelo
Antes de ejecutar el modelo, debes establecer algunos parámetros de configuración, como la variante de Gemma, el tokenizador y el nivel de cuantificación.
# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path
Configura el contexto del dispositivo
El siguiente código configura el contexto del dispositivo para ejecutar el modelo:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
Crea una instancia del modelo y cárgalo
Carga el modelo con sus pesos para prepararte para ejecutar solicitudes.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
model = Gemma3ForMultimodalLM(model_config)
model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
model = model.to(device).eval()
print("Model loading done.")
print('Generating requests in chat mode...')
Ejecuta la inferencia
A continuación, se incluyen ejemplos para generar en modo de chat y generar con varias solicitudes.
Los modelos de Gemma ajustados por instrucciones se entrenaron con un formato específico que anota ejemplos de ajuste de instrucciones con información adicional, tanto durante el entrenamiento como la inferencia. Las anotaciones (1) indican los roles en una conversación y (2) demarcan los turnos en una conversación.
Los tokens de anotación relevantes son los siguientes:
user
: turno del usuariomodel
: giro del modelo<start_of_turn>
: Inicio del turno de diálogo<start_of_image>
: Etiqueta para la entrada de datos de imagen<end_of_turn><eos>
: Fin del turno de diálogo
Para obtener más información, lee sobre el formato de instrucciones para modelos de Gemma ajustados a instrucciones [aquí](https://ai.google.dev/gemma/core/prompt-structure
Genera texto con texto
El siguiente es un fragmento de código de muestra que muestra cómo dar formato a una instrucción para un modelo de Gemma ajustado con instrucciones mediante plantillas de chat del usuario y del modelo en una conversación de varias intervenciones.
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=256,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
Genera texto con imágenes
Con la versión 3 de Gemma y versiones posteriores, puedes usar imágenes con tu instrucción. En el siguiente ejemplo, se muestra cómo incluir datos visuales con tu instrucción.
print('Chat with images...\n')
def read_image(url):
import io
import requests
import PIL
contents = io.BytesIO(requests.get(url).content)
return PIL.Image.open(contents)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
image = read_image(image_url)
print(model.generate(
[['<start_of_turn>user\n',image, 'What animal is in this image?<end_of_turn>\n', '<start_of_turn>model\n']],
device=device,
output_len=OUTPUT_LEN,
))
Más información
Ahora que aprendiste a usar Gemma en Pytorch, puedes explorar las muchas otras funciones que puede realizar en ai.google.dev/gemma. Consulta también estos otros recursos relacionados: