Ajustar modelos de Gemma en Keras con LoRA

Ver en ai.google.dev Ejecutar en Google Colab Abrir en Vertex AI Ver código fuente en GitHub

Descripción general

Gemma es una familia de modelos abiertos, ligeros y de vanguardia creados a partir de la misma investigación y tecnología que se utilizaron para crear los modelos Gemini.

Se demostró que los modelos grandes de lenguaje (LLM), como Gemma, son eficaces en diversas tareas de PLN. Un LLM se entrena previamente con un gran corpus de texto de forma autosupervisada. El entrenamiento previo ayuda a los LLM a aprender conocimiento de uso general, como las relaciones estadísticas entre palabras. Un LLM se puede ajustar con datos específicos del dominio para realizar tareas downstream (como el análisis de opiniones).

Los LLM son de gran tamaño (parámetros del orden de miles de millones). No se requiere el ajuste fino completo (que actualiza todos los parámetros del modelo) para la mayoría de las aplicaciones, ya que los conjuntos de datos de ajuste fino típicos son relativamente mucho más pequeños que los conjuntos de datos de entrenamiento previo.

La Adaptación de rango bajo (LoRA) es una técnica de ajuste que reduce en gran medida la cantidad de parámetros entrenables para tareas downstream, ya que inmoviliza los pesos del modelo y, luego, inserta una cantidad menor de pesos nuevos en el modelo. Esto hace que el entrenamiento con LoRA sea mucho más rápido y eficiente en la memoria, y produce pesos de modelo más pequeños (cientos de MB), todo mientras se mantiene la calidad de los resultados del modelo.

En este instructivo, se explica cómo usar KerasNLP para realizar el ajuste fino de LoRA en un modelo Gemma 2B con el conjunto de datos Dolly 15K de Databricks. Este conjunto de datos contiene 15,000 pares de instrucciones y respuestas de alta calidad generados por humanos y diseñados específicamente para ajustar los LLM.

Configuración

Obtén acceso a Gemma

Para completar este instructivo, primero deberás completar las instrucciones de configuración de Gemma. En las instrucciones de configuración de Gemma, se muestra cómo hacer lo siguiente:

  • Obtén acceso a Gemma en kaggle.com.
  • Selecciona un entorno de ejecución de Colab con recursos suficientes para ejecutar el modelo de Gemma 2B.
  • Generar y configurar un nombre de usuario Kaggle y una clave de API.

Después de completar la configuración de Gemma, continúa con la siguiente sección, en la que configurarás las variables de entorno de tu entorno de Colab.

Selecciona el entorno de ejecución

Para completar este instructivo, deberás tener un entorno de ejecución de Colab con recursos suficientes para ejecutar el modelo de Gemma. En este caso, puedes usar una GPU T4:

  1. En la esquina superior derecha de la ventana de Colab, selecciona ▾ (Opciones de conexión adicionales).
  2. Selecciona Cambiar el tipo de entorno de ejecución.
  3. En Acelerador de hardware, selecciona GPU T4.

Cómo configurar tu clave de API

Para usar Gemma, debes proporcionar tu nombre de usuario y una clave de API de Kaggle.

Para generar una clave de API de Kaggle, dirígete a la pestaña Cuenta de tu perfil de usuario de Kaggle y selecciona Crear token nuevo. Esto activará la descarga de un archivo kaggle.json que contiene tus credenciales de API.

En Colab, selecciona Secrets (Compose) en el panel izquierdo y agrega tu nombre de usuario de Kaggle y tu clave de API de Kaggle. Almacena tu nombre de usuario con el nombre KAGGLE_USERNAME y tu clave de API con el nombre KAGGLE_KEY.

Configure las variables de entorno

Configura las variables de entorno para KAGGLE_USERNAME y KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Instala dependencias

Instala Keras, KerasNLP y otras dependencias.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

Seleccionar un backend

Keras es una API de aprendizaje profundo de alto nivel y varios frameworks diseñada para brindar simplicidad y facilidad de uso. Con Keras 3, puedes ejecutar flujos de trabajo en uno de los tres backends: TensorFlow, JAX o PyTorch.

En este instructivo, configura el backend para JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Importa paquetes

Importar Keras y KerasNLP.

import keras
import keras_nlp

Cargar conjunto de datos

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

Procesa previamente los datos. En este instructivo, se usa un subconjunto de 1,000 ejemplos de entrenamiento para ejecutar el notebook más rápido. Considera usar más datos de entrenamiento para lograr un ajuste más preciso de alta calidad.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

Carga el modelo

KerasNLP proporciona implementaciones de muchas arquitecturas de modelos populares. En este instructivo, crearás un modelo con GemmaCausalLM, un modelo de Gemma de extremo a extremo para el modelado de lenguaje causal. Un modelo de lenguaje causal predice el siguiente token en función de los tokens anteriores.

Crea el modelo con el método from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

El método from_preset crea una instancia del modelo a partir de una arquitectura y pesos predeterminados. En el código anterior, la cadena "gemma2_2b_en" especifica la arquitectura predeterminada: un modelo de Gemma con 2,000 millones de parámetros.

Inferencia antes del ajuste fino

En esta sección, consultarás el modelo con varias instrucciones para ver cómo responde.

Instrucción de viaje por Europa

Consulta el modelo para obtener sugerencias sobre qué hacer en un viaje a Europa.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

El modelo responde con sugerencias genéricas sobre cómo planificar un viaje.

Instrucción de fotosíntesis para explicar como a un niño

Indícale al modelo que explique la fotosíntesis en términos lo suficientemente sencillos como para que un niño de 5 años pueda comprenderlos.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

La respuesta del modelo contiene palabras que podrían no ser fáciles de entender para un niño, como clorofila.

Ajuste fino de LoRA

Para obtener mejores respuestas del modelo, ajusta el modelo con la adaptación de clasificación baja (LoRA) con el conjunto de datos Dolly 15K de Databricks.

La clasificación de LoRA determina la dimensionalidad de las matrices entrenables que se agregan a las ponderaciones originales del LLM. Controla la expresividad y la precisión de los ajustes de ajuste fino.

Una clasificación más alta significa posibles cambios más detallados, pero también significa más parámetros entrenables. Una clasificación más baja implica menos sobrecarga computacional, pero una adaptación potencialmente menos precisa.

En este instructivo, se usa una clasificación de LoRA de 4. En la práctica, comienza con un rango relativamente pequeño (como 4, 8 o 16). Esto es eficiente en términos de procesamiento para la experimentación. Entrena tu modelo con esta clasificación y evalúa la mejora del rendimiento en tu tarea. Aumenta gradualmente la clasificación en las pruebas posteriores y comprueba si eso mejora aún más el rendimiento.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Ten en cuenta que habilitar LoRA reduce significativamente la cantidad de parámetros entrenables (de 2,600 millones a 2,9 millones).

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

Nota sobre el ajuste fino de precisión mixta en GPUs NVIDIA

Se recomienda la precisión completa para el ajuste fino. Cuando ajustes las GPU de NVIDIA, ten en cuenta que puedes usar la precisión mixta (keras.mixed_precision.set_global_policy('mixed_bfloat16')) para acelerar el entrenamiento con un efecto mínimo en su calidad. El ajuste fino de precisión mixta consume más memoria, por lo que solo es útil en GPUs más grandes.

Para la inferencia, la precisión media (keras.config.set_floatx("bfloat16")) funcionará y ahorrará memoria, mientras que la precisión mixta no es aplicable.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Inferencia después del ajuste

Después de ajustarlas, las respuestas siguen las instrucciones proporcionadas en la instrucción.

Instrucción de viaje por Europa

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

El modelo ahora recomienda lugares para visitar en Europa.

Instrucción de fotosíntesis para explicar como a un niño

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

Ahora, el modelo explica la fotosíntesis de forma más sencilla.

Ten en cuenta que, a modo de demostración, este instructivo ajusta el modelo en un subconjunto pequeño del conjunto de datos para una sola época y con un valor de clasificación de LoRA bajo. Para obtener mejores respuestas del modelo ajustado, puedes experimentar con lo siguiente:

  1. Aumenta el tamaño del conjunto de datos de ajuste
  2. Entrenamiento para más pasos (períodos)
  3. Cómo establecer un rango de LoRA más alto
  4. Modificar los valores de hiperparámetro, como learning_rate y weight_decay

Resumen y próximos pasos

En este instructivo, se explicó el ajuste de LoRA en un modelo de Gemma con KerasNLP. Consulta los siguientes documentos: