Inferencia con RecurrentGemma mediante JAX y Flax

Ver en ai.google.dev Ejecutar en Google Colab Abrir en Vertex AI Ver el código fuente en GitHub

En este instructivo, se muestra cómo realizar muestreos o inferencias básicos con el modelo RecurrentGemma 2B Instruct mediante la biblioteca recurrentgemma de Google DeepMind que se escribió con JAX (una biblioteca de computación numérica de alto rendimiento), Flax (la biblioteca de red neuronal basada en JAX), Orbax (una biblioteca basada en JAX para usar la biblioteca recurrentgemma de Google DeepMind), que se escribió con JAX (una biblioteca de computación numérica de alto rendimiento), Flax (la biblioteca de red neuronal basada en JAX), Orbax (una biblioteca basada en JAX para usar la biblioteca de JAX(una biblioteca basada en JAX para herramientas de entrenamiento, como checkpointing de Pieaceent}).SentencePiece Aunque Flax no se usa directamente en este notebook, se usó para crear Gemma y RecurrentGemma (el modelo de Griffin).

Este notebook se puede ejecutar en Google Colab con la GPU T4 (ve a Editar > Configuración del notebook > en Acelerador de hardware, selecciona GPU T4).

Configuración

En las siguientes secciones, se explican los pasos para preparar un notebook para usar un modelo RecurrentGemma, incluido el acceso al modelo, la obtención de una clave de API y la configuración del entorno de ejecución del notebook.

Configurar acceso a Kaggle para Gemma

Para completar este instructivo, primero debes seguir las instrucciones de configuración similares a la configuración de Gemma, con algunas excepciones:

  • Obtén acceso a RecurrentGemma (en lugar de Gemma) en kaggle.com.
  • Selecciona un entorno de ejecución de Colab con recursos suficientes para ejecutar el modelo RecurrentGemma.
  • Generar y configurar un nombre de usuario Kaggle y una clave de API.

Después de completar la configuración de RecurrentGemma, continúa con la siguiente sección, en la que establecerás variables de entorno para tu entorno de Colab.

Configure las variables de entorno

Configura variables de entorno para KAGGLE_USERNAME y KAGGLE_KEY. Cuando aparezca el mensaje “¿Quieres otorgar acceso?” mensajes, acepta proporcionar acceso al Secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Instala la biblioteca recurrentgemma

Este notebook se enfoca en el uso de una GPU de Colab gratuita. Para habilitar la aceleración de hardware, haz clic en Editar > Configuración del notebook > Selecciona GPU T4 > Guardar.

A continuación, debes instalar la biblioteca recurrentgemma de Google DeepMind desde github.com/google-deepmind/recurrentgemma. Si recibes un error sobre el “agente de resolución de dependencias de pip”, por lo general, puedes ignorarlo.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

Carga y prepara el modelo RecurrentGemma

  1. Carga el modelo RecurrentGemma con kagglehub.model_download, que toma tres argumentos:
  • handle: el controlador del modelo de Kaggle
  • path: (Cadena opcional) es la ruta de acceso local.
  • force_download: (booleano opcional) Obliga a volver a descargar el modelo.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Verifica la ubicación de los pesos del modelo y el tokenizador, luego establece las variables de la ruta de acceso. El directorio del tokenizador estará en el directorio principal donde descargaste el modelo, mientras que los pesos del modelo estarán en un subdirectorio. Por ejemplo:
  • El archivo tokenizer.model estará en /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1).
  • El punto de control del modelo estará en /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Realizar muestreo o inferencia

  1. Carga el punto de control del modelo RecurrentGemma con el método recurrentgemma.jax.load_parameters. El argumento sharding establecido en "single_device" carga todos los parámetros del modelo en un solo dispositivo.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Carga el tokenizador del modelo RecurrentGemma, construido con sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Para cargar automáticamente la configuración correcta desde el punto de control del modelo RecurrentGemma, usa recurrentgemma.GriffinConfig.from_flax_params_or_variables. Luego, crea una instancia del modelo Griffin con recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Crea un sampler con recurrentgemma.jax.Sampler sobre el punto de control o los pesos del modelo RecurrentGemma y el tokenizador:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Escribe una instrucción en prompt y realiza inferencias. Puedes modificar total_generation_steps (la cantidad de pasos que se realizan cuando se genera una respuesta; en este ejemplo, se usa 50 para preservar la memoria del host).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

Más información