Ottimizzare i modelli Gemma in Keras utilizzando LoRA

Visualizza su ai.google.dev Esegui in Google Colab Apri in Vertex AI Visualizza il codice sorgente su GitHub

Panoramica

Gemma è una famiglia di modelli aperti leggeri e all'avanguardia creati sulla base della stessa ricerca e tecnologia utilizzata per creare i modelli Gemini.

I modelli linguistici di grandi dimensioni (LLM) come Gemma si sono dimostrati efficaci in una serie di attività di NLP. Un LLM viene prima preaddestrato su un ampio corpus di testo in modo autosupervisionato. Il preaddestramento aiuta gli LLM ad apprendere conoscenze generali, come le relazioni statistiche tra le parole. Un modello LLM può quindi essere ottimizzato con dati specifici del dominio per eseguire attività a valle (come l'analisi del sentiment).

Gli LLM sono estremamente grandi (parametri nell'ordine di miliardi). L'ottimizzazione completa (che aggiorna tutti i parametri del modello) non è necessaria per la maggior parte delle applicazioni perché i set di dati di ottimizzazione completa tipici sono relativamente molto più piccoli dei set di dati di preaddestramento.

L'adattamento a basso ranking (LoRA) è una tecnica di ottimizzazione fine che riduce notevolmente il numero di parametri addestrabili per le attività a valle bloccando i pesi del modello e inserendo un numero minore di nuovi pesi. In questo modo, l'addestramento con LoRA è molto più veloce ed efficiente in termini di memoria e produce pesi del modello più piccoli (poche centinaia di MB), il tutto mantenendo la qualità degli output del modello.

Questo tutorial illustra come utilizzare KerasNLP per eseguire la messa a punto fine di LoRA su un modello Gemma 2B utilizzando il set di dati Dolly 15k di Databricks. Questo set di dati contiene 15.000 coppie di prompt / risposta di alta qualità generate da persone e progettate specificamente per la messa a punto degli LLM.

Configurazione

Accedere a Gemma

Per completare questo tutorial, devi prima seguire le istruzioni di configurazione riportate nella pagina Configurazione di Gemma. Le istruzioni di configurazione di Gemma mostrano come:

  • Accedi a Gemma su kaggle.com.
  • Seleziona un runtime Colab con risorse sufficienti per eseguire il modello Gemma 2B.
  • Genera e configura un nome utente e una chiave API Kaggle.

Dopo aver completato la configurazione di Gemma, vai alla sezione successiva, dove imposterai le variabili di ambiente per l'ambiente Colab.

Seleziona il runtime

Per completare questo tutorial, devi disporre di un runtime Colab con risorse sufficienti per eseguire il modello Gemma. In questo caso, puoi utilizzare una GPU T4:

  1. Nell'angolo in alto a destra della finestra Colab, seleziona ▾ (Opzioni di connessione aggiuntive).
  2. Seleziona Cambia tipo di runtime.
  3. In Acceleratore hardware, seleziona GPU T4.

Configura la chiave API

Per utilizzare Gemma, devi fornire il tuo nome utente Kaggle e una chiave API Kaggle.

Per generare una chiave API Kaggle, vai alla scheda Account del tuo profilo utente Kaggle e seleziona Crea nuovo token. Verrà attivato il download di un file kaggle.json contenente le tue credenziali API.

In Colab, seleziona Secrets (🔑) nel riquadro a sinistra e aggiungi il tuo nome utente e la chiave API Kaggle. Memorizza il tuo nome utente sotto il nome KAGGLE_USERNAME e la chiave API sotto il nome KAGGLE_KEY.

Imposta le variabili di ambiente

Imposta le variabili di ambiente per KAGGLE_USERNAME e KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installa le dipendenze

Installa Keras, KerasNLP e altre dipendenze.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

Seleziona un servizio di backend

Keras è un'API di deep learning multi-framework ad alto livello progettata per la semplicità e la facilità d'uso. Con Keras 3, puoi eseguire flussi di lavoro su uno dei tre backend: TensorFlow, JAX o PyTorch.

Per questo tutorial, configura il backend per JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Importa pacchetti

Importa Keras e KerasNLP.

import keras
import keras_nlp

Carica set di dati

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

Esegui la preelaborazione dei dati. Questo tutorial utilizza un sottoinsieme di 1000 esempi di addestramento per eseguire il notebook più velocemente. Valuta la possibilità di utilizzare più dati di addestramento per una messa a punto di qualità superiore.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

Carica modello

KerasNLP fornisce implementazioni di molte architetture di modelli molto diffuse. In questo tutorial, creerai un modello utilizzando GemmaCausalLM, un modello Gemma end-to-end per la creazione di modelli linguistici causali. Un modello linguistico causale prevede il token successivo in base ai token precedenti.

Crea il modello utilizzando il metodo from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

Il metodo from_preset esegue l'inizializzazione del modello da un'architettura e pesi preimpostati. Nel codice riportato sopra, la stringa "gemma2_2b_en" specifica l'architettura preimpostata, ovvero un modello Gemma con 2 miliardi di parametri.

Inferenza prima dell'ottimizzazione

In questa sezione, eseguirai query sul modello con vari prompt per vedere come risponde.

Richiesta di viaggio in Europa

Esegui una query sul modello per avere suggerimenti su cosa fare durante un viaggio in Europa.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

Il modello risponde con suggerimenti generici su come pianificare un viaggio.

Prompt ELI5 sulla fotosintesi

Chiedi al modello di spiegare la fotosintesi in termini abbastanza semplici da essere compresi da un bambino di 5 anni.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

La risposta del modello contiene parole che potrebbero non essere facili da capire per un bambino, come clorofilla.

Ottimizzazione LoRA

Per ottenere risposte migliori dal modello, perfezionalo con l'adattamento a basso ranking (LoRA) utilizzando il set di dati Dolly 15k di Databricks.

Il ranking LoRA determina la dimensionalità delle matrici addestrabili che vengono aggiunte ai pesi originali dell'LLM. controlla l'espressività e la precisione delle regolazioni.

Un ranking più elevato comporta la possibilità di apportare modifiche più dettagliate, ma anche parametri più addestrabili. Un ranking più basso significa meno overhead computazionale, ma un adattamento potenzialmente meno preciso.

Questo tutorial utilizza un ranking LoRa di 4. In pratica, inizia con un ranking relativamente piccolo (ad esempio 4, 8, 16). Questo approccio è efficiente dal punto di vista computazionale per la sperimentazione. Addestra il modello con questo ranking e valuta il miglioramento delle prestazioni dell'attività. Aumenta gradualmente il ranking nelle prove successive e verifica se questo migliora ulteriormente le prestazioni.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Si noti che l'abilitazione di LoRA riduce significativamente il numero di parametri addestrabili (da 2,6 miliardi a 2,9 milioni).

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

Nota sulla messa a punto fine con precisione mista sulle GPU NVIDIA

Per la messa a punto consigliamo la precisione completa. Quando esegui il perfezionamento su GPU NVIDIA, tieni presente che puoi utilizzare la precisione mista (keras.mixed_precision.set_global_policy('mixed_bfloat16')) per velocizzare l'addestramento con un impatto minimo sulla qualità dell'addestramento. La messa a punto con precisione mista consuma più memoria, quindi è utile solo su GPU più grandi.

Per l'inferenza, la mezza precisione (keras.config.set_floatx("bfloat16")) funzionerà e risparmierà memoria, mentre la precisione mista non è applicabile.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Interruzione dopo l'ottimizzazione

Dopo l'ottimizzazione, le risposte seguono le istruzioni fornite nel prompt.

Prompt per i viaggi in Europa

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

Il modello ora consiglia i luoghi da visitare in Europa.

Prompt ELI5 sulla fotosintesi

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

Ora il modello spiega la fotosintesi in termini più semplici.

Tieni presente che, a scopo dimostrativo, questo tutorial ottimizza il modello su un piccolo sottoinsieme del set di dati per una sola epoca e con un valore di ranking LoRA basso. Per ottenere risposte migliori dal modello ottimizzato, puoi fare esperimenti con:

  1. Aumento delle dimensioni del set di dati di ottimizzazione fine
  2. Addestramento per più passaggi (epoche)
  3. Impostazione di un ranking LoRA più elevato
  4. Modifica dei valori degli iperparametri, ad esempio learning_rate e weight_decay.

Riepilogo e passaggi successivi

Questo tutorial illustra la messa a punto fine di LoRA su un modello Gemma utilizzando KerasNLP. Dai un'occhiata ai seguenti documenti: