Ver en ai.google.dev | Ejecutar en Google Colab | Ver código fuente en GitHub |
Esta es una demostración rápida de la ejecución de la inferencia de Gemma en PyTorch. Para obtener más detalles, consulta el repositorio de GitHub de la implementación oficial de PyTorch aquí.
Ten en cuenta lo siguiente:
- El entorno de ejecución gratuito de Python para la CPU de Colab y el entorno de ejecución de Python T4 en GPU son suficientes para ejecutar los modelos Gemma 2B y los modelos cuantizados 7B int8.
- Para ver casos de uso avanzados de otras GPUs o TPU, consulta README.md en el repositorio oficial.
1. Configura el acceso de Kaggle para Gemma
Para completar este instructivo, primero debes seguir las instrucciones de configuración en Configuración de Gemma, que te muestran cómo hacer lo siguiente:
- Obtén acceso a Gemma en kaggle.com.
- Selecciona un entorno de ejecución de Colab con recursos suficientes para ejecutar el modelo de Gemma.
- Genera y configura un nombre de usuario y una clave de API de Kaggle.
Después de completar la configuración de Gemma, continúa con la siguiente sección, en la que configurarás las variables de entorno de tu entorno de Colab.
2. Configure las variables de entorno
Configura las variables de entorno para KAGGLE_USERNAME
y KAGGLE_KEY
. Cuando se te solicite con los mensajes "¿Quieres otorgar acceso?", acepta proporcionar acceso secreto.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Instala dependencias
pip install -q -U torch immutabledict sentencepiece
Descarga los pesos del modelo
# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT[:2]
if CONFIG == '2b':
CONFIG = '2b-v2'
import os
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
Descarga la implementación del modelo
# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch')
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch
Configura el modelo
# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT
# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()
Ejecuta la inferencia
A continuación, se incluyen ejemplos para generar en modo de chat y generar con varias solicitudes.
Los modelos de Gemma ajustados a instrucciones se entrenaron con un formateador específico que anota ejemplos de ajuste de instrucciones con información adicional, tanto durante el entrenamiento como la inferencia. Las anotaciones (1) indican los roles en una conversación y (2) delinean turnos en una conversación.
Los tokens de anotación relevantes son los siguientes:
user
: turno de usuariomodel
: giro de modelo<start_of_turn>
: turno de inicio del diálogo<end_of_turn><eos>
: Fin del turno de diálogo
Para obtener más información, lee sobre el formato de instrucciones para modelos de Gemma ajustados a instrucciones aquí.
El siguiente es un fragmento de código de muestra que demuestra cómo dar formato a una instrucción para un modelo de Gemma ajustado a instrucciones con plantillas de chat de usuario y modelo en una conversación de varios turnos.
# Generate with one request in chat mode
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=128,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
Más información
Ahora que aprendiste a usar Gemma en Pytorch, puedes explorar las muchas otras funciones que puede realizar en ai.google.dev/gemma. Consulta también estos otros recursos relacionados: