Gemma-Modelle in Keras mit LoRA optimieren

Auf ai.google.dev ansehen In Google Colab ausführen In Vertex AI öffnen Quellcode auf GitHub ansehen

Übersicht

Gemma ist eine Familie leichter, hochmoderner offener Modelle, die auf derselben Forschung und Technologie basieren, die auch für die Erstellung der Gemini-Modelle verwendet werden.

Large Language Models (LLMs) wie Gemma haben sich bei einer Vielzahl von NLP-Aufgaben als effektiv erwiesen. Ein LLM wird zuerst selbstüberwacht mit einem großen Textkorpus trainiert. Durch das Vortraining können LLMs allgemeines Wissen erlernen, z. B. statistische Beziehungen zwischen Wörtern. Ein LLM kann dann mit domänenspezifischen Daten optimiert werden, um nachfolgende Aufgaben wie die Sentimentanalyse auszuführen.

LLMs sind extrem groß (Parameter in der Größenordnung von Milliarden). Eine vollständige Feinabstimmung, bei der alle Parameter im Modell aktualisiert werden, ist für die meisten Anwendungen nicht erforderlich, da typische Datasets für die Feinabstimmung relativ viel kleiner sind als die Datasets für das Vortraining.

Low Rank Adaptation (LoRA) ist eine Methode zur Feinabstimmung, mit der die Anzahl der trainierbaren Parameter für Downstream-Aufgaben erheblich reduziert wird. Dazu werden die Gewichte des Modells eingefroren und eine kleinere Anzahl neuer Gewichte in das Modell eingefügt. Dadurch ist das Training mit LoRA viel schneller und speichereffizienter. Außerdem werden kleinere Modellgewichte (einige hundert MB) erzielt, ohne dass die Qualität der Modellausgaben beeinträchtigt wird.

In dieser Anleitung wird beschrieben, wie Sie mit KerasNLP eine LoRA-Feinabstimmung für ein Gemma 2B-Modell mithilfe des Databricks Dolly 15k-Dataset durchführen. Dieser Datensatz enthält 15.000 hochwertige, von Menschen erstellte Prompt-/Antwortpaare, die speziell für die Feinabstimmung von LLMs entwickelt wurden.

Einrichtung

Zugriff auf Gemma erhalten

Um diese Anleitung abzuschließen, müssen Sie zuerst die Schritte unter Gemma-Einrichtung ausführen. In der Anleitung zur Einrichtung von Gemma erfahren Sie, wie Sie Folgendes tun können:

  • Sie können Gemma unter kaggle.com nutzen.
  • Wählen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen für die Ausführung des Gemma 2B-Modells aus.
  • Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.

Nachdem Sie die Gemma-Einrichtung abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort, in dem Sie Umgebungsvariablen für Ihre Colab-Umgebung festlegen.

Laufzeit auswählen

Für diese Anleitung benötigen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen, um das Gemma-Modell auszuführen. In diesem Fall können Sie eine T4-GPU verwenden:

  1. Wählen Sie rechts oben im Colab-Fenster das Dreipunkt-Menü ▾ (Zusätzliche Verbindungsoptionen) aus.
  2. Wählen Sie Laufzeittyp ändern aus.
  3. Wähle unter Hardwarebeschleuniger die Option T4-GPU aus.

API-Schlüssel konfigurieren

Wenn Sie Gemma verwenden möchten, müssen Sie Ihren Kaggle-Nutzernamen und einen Kaggle API-Schlüssel angeben.

Wenn Sie einen Kaggle API-Schlüssel generieren möchten, rufen Sie den Tab Konto Ihres Kaggle-Nutzerprofils auf und wählen Sie Neues Token erstellen aus. Dadurch wird der Download einer kaggle.json-Datei mit Ihren API-Anmeldedaten ausgelöst.

Wählen Sie in Colab im linken Bereich Secrets (🔑) aus und fügen Sie Ihren Kaggle-Nutzernamen und Kaggle API-Schlüssel hinzu. Speichere deinen Nutzernamen unter dem Namen KAGGLE_USERNAME und deinen API-Schlüssel unter dem Namen KAGGLE_KEY.

Umgebungsvariablen festlegen

Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Abhängigkeiten installieren

Installieren Sie Keras, KerasNLP und andere Abhängigkeiten.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

Backend auswählen

Keras ist eine allgemeine Deep-Learning-API mit mehreren Frameworks, die für Einfachheit und Benutzerfreundlichkeit entwickelt wurde. Mit Keras 3 können Sie Workflows auf einem von drei Back-Ends ausführen: TensorFlow, JAX oder PyTorch.

Konfigurieren Sie für diese Anleitung das Back-End für JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Pakete importieren

Importieren Sie Keras und KerasNLP.

import keras
import keras_nlp

Dataset laden

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

Führen Sie eine Vorverarbeitung der Daten durch. In dieser Anleitung wird eine Teilmenge von 1.000 Trainingsbeispielen verwendet, um das Notebook schneller auszuführen. Für eine bessere Feinabstimmung sollten Sie mehr Trainingsdaten verwenden.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

Modell laden

KerasNLP bietet Implementierungen vieler beliebter Modellarchitekturen. In dieser Anleitung erstellen Sie ein Modell mit GemmaCausalLM, einem End-to-End-Gemma-Modell für die kausale Sprachmodellierung. Ein kausales Sprachmodell sagt das nächste Token anhand der vorherigen Tokens voraus.

Erstellen Sie das Modell mit der Methode from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

Mit der Methode from_preset wird das Modell anhand einer voreingestellten Architektur und Gewichte instanziiert. Im obigen Code gibt der String „gemma2_2b_en“ die voreingestellte Architektur an, also ein Gemma-Modell mit 2 Milliarden Parametern.

Inferenz vor der Abstimmung

In diesem Abschnitt fragen Sie das Modell mit verschiedenen Prompts ab, um zu sehen, wie es reagiert.

Aufforderung zu Fahrten in Europa

Fragen Sie das Modell nach Vorschlägen für eine Reise nach Europa.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

Das Modell antwortet mit allgemeinen Tipps zur Reiseplanung.

ELI5 Photosynthesis Prompt

Bitten Sie das Modell, die Photosynthese in einer Sprache zu erklären, die für ein fünfjähriges Kind verständlich ist.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

Die Modellantwort enthält Wörter, die für ein Kind möglicherweise nicht leicht zu verstehen sind, z. B. Chlorophyll.

LoRA-Abstimmung

Optimieren Sie das Modell mit LoRA (Low Rank Adaptation), um bessere Antworten vom Modell zu erhalten. Verwenden Sie dazu das Databricks Dolly 15k-Dataset.

Der LoRA-Rang bestimmt die Dimensionalität der trainierbaren Matrizen, die zu den ursprünglichen Gewichtungen des LLM addiert werden. Damit wird die Ausdruckskraft und Präzision der Feinabstimmung gesteuert.

Ein höherer Rang bedeutet, dass detailliertere Änderungen möglich sind, aber auch mehr trainierbare Parameter. Ein niedrigerer Rang bedeutet weniger Rechenaufwand, aber möglicherweise eine weniger genaue Anpassung.

In dieser Anleitung wird ein LoRa-Rang von 4 verwendet. In der Praxis sollten Sie mit einem relativ kleinen Rang beginnen (z. B. 4, 8, 16). Das ist für Tests effizient. Trainieren Sie Ihr Modell mit diesem Rang und bewerten Sie die Leistungsverbesserung für Ihre Aufgabe. Erhöhen Sie den Rang in den nachfolgenden Tests schrittweise und prüfen Sie, ob sich die Leistung dadurch weiter verbessert.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Hinweis: Wenn Sie LoRA aktivieren, wird die Anzahl der trainierbaren Parameter erheblich reduziert (von 2,6 Milliarden auf 2,9 Millionen).

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

Hinweis zur Feinabstimmung mit gemischter Genauigkeit auf NVIDIA-GPUs

Für die Feinabstimmung wird die volle Genauigkeit empfohlen. Bei der Feinabstimmung auf NVIDIA-GPUs können Sie die gemischte Genauigkeit (keras.mixed_precision.set_global_policy('mixed_bfloat16')) verwenden, um das Training mit minimalen Auswirkungen auf die Trainingsqualität zu beschleunigen. Die Feinabstimmung mit gemischter Genauigkeit verbraucht mehr Arbeitsspeicher und ist daher nur auf größeren GPUs sinnvoll.

Für die Inferenz funktioniert die Halbpräzision (keras.config.set_floatx("bfloat16")) und spart Arbeitsspeicher, während die gemischte Genauigkeit nicht geeignet ist.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Inferenz nach der Feinabstimmung

Nach der Feinabstimmung folgen die Antworten der Anleitung im Prompt.

Prompt: Europareise

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

Das Modell empfiehlt jetzt Sehenswürdigkeiten in Europa.

ELI5 Photosynthesis Prompt

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

Die Photosynthese wird jetzt in einfacheren Worten erklärt.

Hinweis: In dieser Anleitung wird das Modell zu Demonstrationszwecken nur auf einer kleinen Teilmenge des Datasets für eine einzige Epoche und mit einem niedrigen LoRA-Rangwert optimiert. Um bessere Antworten vom optimierten Modell zu erhalten, können Sie Folgendes ausprobieren:

  1. Größe des Datasets für die Feinabstimmung erhöhen
  2. Training für mehr Schritte (Epochen)
  3. Höheren LoRA-Rang festlegen
  4. Ändern Sie die Hyperparameterwerte wie learning_rate und weight_decay.

Zusammenfassung und nächste Schritte

In dieser Anleitung haben wir die LoRA-Feinabstimmung für ein Gemma-Modell mit KerasNLP behandelt. Sehen Sie sich als Nächstes die folgenden Dokumente an: