Inferencia con Gemma mediante JAX y Flax

Ver en ai.google.dev Ejecutar en Google Colab Abrir en Vertex AI Ver código fuente en GitHub

Descripción general

Gemma es una familia de modelos grandes de lenguaje, ligeros y de vanguardia, basados en la investigación y tecnología de Google DeepMind Gemini. En este instructivo, se muestra cómo realizar inferencias o muestras básicas con el modelo Gemma 2B Instruct a través de la biblioteca gemma de Google DeepMind que se escribió con JAX (una biblioteca de computación numérica de alto rendimiento), Flax (la biblioteca de redes neuronales basada en JAX), Orbax (una biblioteca basada en JAXizer/detoken1 para las utilidades de entrenamiento como token de control) y SentencePiece Aunque Flax no se usa directamente en este cuaderno, se usó para crear Gemma.

Este notebook se puede ejecutar en Google Colab con la GPU T4 gratuita (ve a Editar > Configuración del notebook > En Acelerador de hardware, selecciona GPU T4).

Instalar

1. Configurar el acceso a Kaggle para Gemma

Para completar este instructivo, primero debes seguir las instrucciones de configuración en la configuración de Gemma, que te muestran cómo hacer lo siguiente:

  • Obtén acceso a Gemma en kaggle.com.
  • Selecciona un entorno de ejecución de Colab con recursos suficientes para ejecutar el modelo de Gemma.
  • Genera y configura un nombre de usuario y una clave de API Kaggle.

Después de completar la configuración de Gemma, continúa con la siguiente sección, en la que establecerás variables de entorno para tu entorno de Colab.

2. Configura las variables de entorno

Configura las variables de entorno para KAGGLE_USERNAME y KAGGLE_KEY. Cuando aparezca el mensaje "¿Quieres otorgar acceso?", acepta proporcionar acceso al secreto.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Cómo instalar la biblioteca gemma

Este notebook se enfoca en el uso de una GPU de Colab gratuita. Para habilitar la aceleración de hardware, haz clic en Editar > Configuración del notebook > selecciona GPU T4 > Guardar.

A continuación, debes instalar la biblioteca gemma de Google DeepMind desde github.com/google-deepmind/gemma. Si recibes un error sobre el "agente de resolución de dependencia de pip", generalmente puedes ignorarlo.

pip install -q git+https://github.com/google-deepmind/gemma.git

Carga y prepara el modelo de Gemma

  1. Carga el modelo de Gemma con kagglehub.model_download, que toma tres argumentos:
  • handle: El controlador del modelo de Kaggle
  • path: Es la ruta local (cadena opcional).
  • force_download: Fuerza a volver a descargar el modelo (booleano opcional).
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. Verifica la ubicación de las ponderaciones del modelo y del tokenizador, y luego configura las variables de la ruta de acceso. El directorio del tokenizador estará en el directorio principal en el que descargaste el modelo, mientras que los pesos del modelo estarán en un subdirectorio. Por ejemplo:
  • El archivo tokenizer.model estará en /LOCAL/PATH/TO/gemma/flax/2b-it/2.
  • El punto de control del modelo estará en /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

Realizar inferencias o muestreos

  1. Carga y formatea el punto de control del modelo Gemma con el método gemma.params.load_and_format_params:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. Carga el tokenizador de Gemma, construido con sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Para cargar automáticamente la configuración correcta desde el punto de control del modelo de Gemma, usa gemma.transformer.TransformerConfig. El argumento cache_size es la cantidad de pasos de tiempo en la caché Transformer de Gemma. Luego, crea una instancia del modelo de Gemma como transformer con gemma.transformer.Transformer (que se hereda de flax.linen.Module).
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. Crea un sampler con gemma.sampler.Sampler por encima de los puntos de control o los pesos del modelo de Gemma y el tokenizador:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. Escribe una instrucción en input_batch y realiza inferencias. Puedes modificar total_generation_steps (la cantidad de pasos que se realizan cuando se genera una respuesta; en este ejemplo, se usa 100 para preservar la memoria del host).
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (Opcional) Ejecuta esta celda para liberar memoria si completaste el notebook y quieres probar con otra instrucción. Después, puedes volver a crear una instancia de sampler en el paso 3, y personalizar y ejecutar la instrucción en el paso 4.
del sampler

Más información