Ajuster des modèles Gemma dans Keras à l'aide de LoRA

Afficher sur ai.google.dev Exécuter dans Google Colab Ouvrir dans Vertex AI Afficher la source sur GitHub

Présentation

Gemma est une famille de modèles ouverts légers et de pointe, élaborés à partir des recherches et des technologies utilisées pour créer les modèles Gemini.

Les grands modèles de langage (LLM) comme Gemma ont été démontrés comme efficaces pour diverses tâches de PLN. Un LLM est d'abord pré-entraîné sur un grand corpus de texte de manière autosupervisée. Le pré-entraînement aide les LLM à apprendre des connaissances générales, telles que les relations statistiques entre les mots. Un LLM peut ensuite être affiné avec des données spécifiques au domaine pour effectuer des tâches en aval (telles que l'analyse des sentiments).

Les LLM sont extrêmement volumineux (paramètres de l'ordre de milliards). L'ajustement fin complet (qui met à jour tous les paramètres du modèle) n'est pas nécessaire pour la plupart des applications, car les ensembles de données d'ajustement fin typiques sont relativement beaucoup plus petits que les ensembles de données de pré-entraînement.

L'adaptation de faible classement (LoRA) est une technique d'affinage qui réduit considérablement le nombre de paramètres pouvant être entraînés pour les tâches en aval en gelant les pondérations du modèle et en insérant un plus petit nombre de nouvelles pondérations dans le modèle. L'entraînement avec LoRA est ainsi beaucoup plus rapide et plus efficace en termes de mémoire, et produit des poids de modèle plus petits (quelques centaines de mégaoctets), tout en conservant la qualité des sorties du modèle.

Ce tutoriel vous explique comment utiliser KerasNLP pour effectuer un réglage LoRA sur un modèle Gemma 2B à l'aide de l'ensemble de données Databricks Dolly 15k. Cet ensemble de données contient 15 000 paires de requêtes / réponses de haute qualité générées par l'humain,spécialement conçues pour affiner les LLM.

Configuration

Accéder à Gemma

Pour suivre ce tutoriel, vous devez d'abord suivre les instructions de configuration dans Configuration de Gemma. Les instructions de configuration de Gemma vous expliquent comment procéder comme suit:

  • Accédez à Gemma sur kaggle.com.
  • Sélectionnez un environnement d'exécution Colab disposant de ressources suffisantes pour exécuter le modèle Gemma 2B.
  • Générez et configurez un nom d'utilisateur et une clé API Kaggle.

Une fois la configuration de Gemma terminée, passez à la section suivante, où vous définirez des variables d'environnement pour votre environnement Colab.

Sélectionner l'environnement d'exécution

Pour suivre ce tutoriel, vous devez disposer d'un environnement d'exécution Colab avec suffisamment de ressources pour exécuter le modèle Gemma. Dans ce cas, vous pouvez utiliser un GPU T4:

  1. En haut à droite de la fenêtre Colab, sélectionnez ▾ (Options de connexion supplémentaires).
  2. Sélectionnez Modifier le type d'environnement d'exécution.
  3. Sous Hardware accelerator (Accélérateur matériel), sélectionnez T4 GPU (GPU T4).

Configurer votre clé API

Pour utiliser Gemma, vous devez fournir votre nom d'utilisateur Kaggle et une clé d'API Kaggle.

Pour générer une clé API Kaggle, accédez à l'onglet Compte de votre profil utilisateur Kaggle et sélectionnez Créer un jeton. Cette opération déclenchera le téléchargement d'un fichier kaggle.json contenant vos identifiants pour l'API.

Dans Colab, sélectionnez Secrets (🔑) dans le volet de gauche, puis ajoutez votre nom d'utilisateur Kaggle et votre clé API Kaggle. Stockez votre nom d'utilisateur sous le nom KAGGLE_USERNAME et votre clé API sous le nom KAGGLE_KEY.

Définir des variables d'environnement

Définissez les variables d'environnement pour KAGGLE_USERNAME et KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installer des dépendances

Installez Keras, KerasNLP et d'autres dépendances.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

Sélectionnez un backend

Keras est une API de deep learning multi-framework de haut niveau conçue pour être simple et facile à utiliser. Avec Keras 3, vous pouvez exécuter des workflows sur l'un des trois backends suivants: TensorFlow, JAX ou PyTorch.

Pour ce tutoriel, configurez le backend pour JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Importer des packages

Importez Keras et KerasNLP.

import keras
import keras_nlp

Charger l'ensemble de données

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

Prétraitez les données. Ce tutoriel utilise un sous-ensemble de 1 000 exemples d'entraînement pour exécuter le notebook plus rapidement. Envisagez d'utiliser davantage de données d'entraînement pour un réglage plus précis.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

Charger le modèle

KerasNLP fournit des implémentations de nombreuses architectures de modèles populaires. Dans ce tutoriel, vous allez créer un modèle à l'aide de GemmaCausalLM, un modèle Gemma de bout en bout pour la modélisation causale du langage. Un modèle de langage causal prédit le jeton suivant en fonction des jetons précédents.

Créez le modèle à l'aide de la méthode from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

La méthode from_preset instancie le modèle à partir d'une architecture et de poids prédéfinis. Dans le code ci-dessus, la chaîne "gemma2_2b_en" spécifie l'architecture prédéfinie, un modèle Gemma avec deux milliards de paramètres.

Inférence avant l'ajustement

Dans cette section, vous allez interroger le modèle à l'aide de différentes requêtes pour voir comment il répond.

Requête de voyage en Europe

Interrogez le modèle pour obtenir des suggestions sur ce qu'il faut faire lors d'un voyage en Europe.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

Le modèle répond par des conseils génériques sur la façon de planifier un trajet.

Invite de photosynthèse ELI5

Demandez au modèle d'expliquer la photosynthèse de façon suffisamment simple pour qu'un enfant de cinq ans puisse les comprendre.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

La réponse du modèle contient des mots qui peuvent être difficiles à comprendre pour un enfant, comme "chlorophylle".

Affinage LoRA

Pour obtenir de meilleures réponses du modèle, affinez-le avec l'adaptation à faible rang (LoRA) à l'aide de l'ensemble de données Dolly 15 000 de Databricks.

Le rang LoRA détermine la dimensionnalité des matrices entraînables qui sont ajoutées aux pondérations d'origine du LLM. Il contrôle l'expressivité et la précision des ajustements de réglage fin.

Un rang plus élevé signifie que des modifications plus détaillées sont possibles, mais aussi que vous pouvez entraîner plus de paramètres. Un rang inférieur signifie moins de frais de calcul, mais une adaptation potentiellement moins précise.

Ce tutoriel utilise un niveau de priorité LoRA de 4. Dans la pratique, commencez par un rang relativement faible (par exemple, 4, 8 ou 16). Cette approche est efficace pour les tests. Entraînez votre modèle avec ce classement et évaluez l'amélioration des performances pour votre tâche. Augmentez progressivement le classement dans les essais suivants pour voir si cela améliore encore les performances.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Notez que l'activation de LoRA réduit considérablement le nombre de paramètres enregistrables (de 2,6 milliards à 2,9 millions).

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

Remarque concernant l'ajustement fin à précision mixte sur les GPU NVIDIA

Nous vous recommandons d'utiliser une précision complète pour l'ajustement. Lorsque vous effectuez des réglages sur des GPU NVIDIA, notez que vous pouvez utiliser la précision mixte (keras.mixed_precision.set_global_policy('mixed_bfloat16')) pour accélérer l'entraînement avec un impact minimal sur la qualité. L'ajustement de précision mixte consomme plus de mémoire et n'est donc utile que sur les GPU plus volumineux.

Pour l'inférence, la demi-précision (keras.config.set_floatx("bfloat16")) fonctionne et économise de la mémoire, tandis que la précision mixte n'est pas applicable.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Inférence après ajustement

Une fois le paramétrage précis effectué, les réponses suivent les instructions fournies dans l'invite.

Requête de voyage en Europe

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

Le modèle recommande désormais des lieux à visiter en Europe.

Invite de photosynthèse ELI5

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

Le modèle explique à présent la photosynthèse de façon plus simple.

Notez qu'à des fins de démonstration, ce tutoriel affine le modèle sur un petit sous-ensemble de l'ensemble de données pour une seule époque et avec une valeur de classement LoRA faible. Pour obtenir de meilleures réponses du modèle affiné, vous pouvez tester les éléments suivants:

  1. Augmenter la taille de l'ensemble de données d'affinage
  2. Entraînement pour plus d'étapes (époques)
  3. Définir un niveau de priorité LoRA plus élevé
  4. Modifier les valeurs des hyperparamètres tels que learning_rate et weight_decay.

Résumé et étapes suivantes

Ce tutoriel a abordé l'ajustement fin de LoRA sur un modèle Gemma à l'aide de KerasNLP. Consultez ensuite les documents suivants: