Ver en ai.google.dev | Ejecutar en Google Colab | Abrir en Vertex AI | Ver el código fuente en GitHub |
En este instructivo, se muestra cómo realizar muestreos o inferencias básicos con el modelo RecurrentGemma 2B Instruct mediante la biblioteca recurrentgemma
de Google DeepMind que se escribió con JAX (una biblioteca de computación numérica de alto rendimiento), Flax (la biblioteca de red neuronal basada en JAX), Orbax (una biblioteca basada en JAX para usar la biblioteca recurrentgemma
de Google DeepMind), que se escribió con JAX (una biblioteca de computación numérica de alto rendimiento), Flax (la biblioteca de red neuronal basada en JAX), Orbax (una biblioteca basada en JAX para usar la biblioteca de JAX(una biblioteca basada en JAX para herramientas de entrenamiento, como checkpointing de Pieaceent}).SentencePiece Aunque Flax no se usa directamente en este notebook, se usó para crear Gemma y RecurrentGemma (el modelo de Griffin).
Este notebook se puede ejecutar en Google Colab con la GPU T4 (ve a Editar > Configuración del notebook > en Acelerador de hardware, selecciona GPU T4).
Configuración
En las siguientes secciones, se explican los pasos para preparar un notebook para usar un modelo RecurrentGemma, incluido el acceso al modelo, la obtención de una clave de API y la configuración del entorno de ejecución del notebook.
Configurar acceso a Kaggle para Gemma
Para completar este instructivo, primero debes seguir las instrucciones de configuración similares a la configuración de Gemma, con algunas excepciones:
- Obtén acceso a RecurrentGemma (en lugar de Gemma) en kaggle.com.
- Selecciona un entorno de ejecución de Colab con recursos suficientes para ejecutar el modelo RecurrentGemma.
- Generar y configurar un nombre de usuario Kaggle y una clave de API.
Después de completar la configuración de RecurrentGemma, continúa con la siguiente sección, en la que establecerás variables de entorno para tu entorno de Colab.
Configure las variables de entorno
Configura variables de entorno para KAGGLE_USERNAME
y KAGGLE_KEY
. Cuando aparezca el mensaje “¿Quieres otorgar acceso?” mensajes, acepta proporcionar acceso al Secret.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Instala la biblioteca recurrentgemma
Este notebook se enfoca en el uso de una GPU de Colab gratuita. Para habilitar la aceleración de hardware, haz clic en Editar > Configuración del notebook > Selecciona GPU T4 > Guardar.
A continuación, debes instalar la biblioteca recurrentgemma
de Google DeepMind desde github.com/google-deepmind/recurrentgemma
. Si recibes un error sobre el “agente de resolución de dependencias de pip”, por lo general, puedes ignorarlo.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
Carga y prepara el modelo RecurrentGemma
- Carga el modelo RecurrentGemma con
kagglehub.model_download
, que toma tres argumentos:
handle
: el controlador del modelo de Kagglepath
: (Cadena opcional) es la ruta de acceso local.force_download
: (booleano opcional) Obliga a volver a descargar el modelo.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Verifica la ubicación de los pesos del modelo y el tokenizador, luego establece las variables de la ruta de acceso. El directorio del tokenizador estará en el directorio principal donde descargaste el modelo, mientras que los pesos del modelo estarán en un subdirectorio. Por ejemplo:
- El archivo
tokenizer.model
estará en/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
). - El punto de control del modelo estará en
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Realizar muestreo o inferencia
- Carga el punto de control del modelo RecurrentGemma con el método
recurrentgemma.jax.load_parameters
. El argumentosharding
establecido en"single_device"
carga todos los parámetros del modelo en un solo dispositivo.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Carga el tokenizador del modelo RecurrentGemma, construido con
sentencepiece.SentencePieceProcessor
:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Para cargar automáticamente la configuración correcta desde el punto de control del modelo RecurrentGemma, usa
recurrentgemma.GriffinConfig.from_flax_params_or_variables
. Luego, crea una instancia del modelo Griffin conrecurrentgemma.jax.Griffin
.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Crea un
sampler
conrecurrentgemma.jax.Sampler
sobre el punto de control o los pesos del modelo RecurrentGemma y el tokenizador:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Escribe una instrucción en
prompt
y realiza inferencias. Puedes modificartotal_generation_steps
(la cantidad de pasos que se realizan cuando se genera una respuesta; en este ejemplo, se usa50
para preservar la memoria del host).
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
Más información
- Puedes obtener más información sobre la biblioteca
recurrentgemma
de Google DeepMind en GitHub, que contiene docstrings de métodos y módulos que usaste en este instructivo, comorecurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
yrecurrentgemma.jax.Sampler
. - Las siguientes bibliotecas tienen sus propios sitios de documentación: core JAX, Flax y Orbax.
- Para ver la documentación del tokenizador/detokenizador
sentencepiece
, consulta el repositorio de GitHubsentencepiece
de Google. - Para ver la documentación de
kagglehub
, consultaREADME.md
en el repositorio de GitHubkagglehub
de Kaggle. - Aprende a usar modelos de Gemma con Vertex AI de Google Cloud.
- Consulta RecurrentGemma: Moving Past Transformers para modelos de lenguaje abierto eficientes de Google DeepMind.
- Lee el Griffin: Cómo mezclar recurrencias lineales cerradas con Atención local para modelos de lenguaje eficientes de GoogleDeepMind para obtener más información sobre la arquitectura de modelos que usa RecurrentGemma.