Menyesuaikan model Gemma di Keras menggunakan LoRA

Lihat di ai.google.dev Jalankan di Google Colab Buka di Vertex AI Lihat sumber di GitHub

Ringkasan

Gemma adalah sekumpulan model terbuka yang ringan dan canggih, dibangun dari riset dan teknologi yang digunakan untuk membuat model Gemini.

Model Bahasa Besar (LLM) seperti Gemma telah terbukti efektif dalam berbagai tugas NLP. LLM pertama-tama dilatih sebelumnya menggunakan korpus teks besar dengan cara yang diawasi sendiri. Pra-pelatihan membantu LLM mempelajari pengetahuan umum, seperti hubungan statistik antar-kata. LLM kemudian dapat di-fine-tune dengan data khusus domain untuk melakukan tugas downstream (seperti analisis sentimen).

LLM memiliki ukuran yang sangat besar (parameter dalam urutan miliaran). Penyesuaian penuh (yang memperbarui semua parameter dalam model) tidak diperlukan untuk sebagian besar aplikasi karena set data penyesuaian umum relatif jauh lebih kecil daripada set data pra-pelatihan.

Low Rank Adaptation (LoRA) adalah teknik fine-tuning yang sangat mengurangi jumlah parameter yang dapat dilatih untuk tugas downstream dengan membekukan bobot model dan menyisipkan bobot baru dalam jumlah yang lebih kecil ke dalam model. Hal ini membuat pelatihan dengan LoRA jauh lebih cepat dan lebih hemat memori, serta menghasilkan bobot model yang lebih kecil (beberapa ratus MB), sekaligus mempertahankan kualitas output model.

Tutorial ini akan memandu Anda menggunakan KerasNLP untuk melakukan penyesuaian LoRA pada model Gemma 2B menggunakan set data Databricks Dolly 15k. Set data ini berisi 15.000 pasangan perintah / respons berkualitas tinggi yang dibuat manusia dan dirancang khusus untuk meningkatkan kualitas LLM.

Penyiapan

Mendapatkan akses ke Gemma

Untuk menyelesaikan tutorial ini, Anda harus menyelesaikan petunjuk penyiapan terlebih dahulu di Penyiapan Gemma. Petunjuk penyiapan Gemma menunjukkan cara melakukan hal berikut:

  • Dapatkan akses ke Gemma di kaggle.com.
  • Pilih runtime Colab dengan resource yang memadai untuk menjalankan model Gemma 2B.
  • Buat dan konfigurasikan nama pengguna dan kunci API Kaggle.

Setelah menyelesaikan penyiapan Gemma, lanjutkan ke bagian berikutnya, tempat Anda akan menetapkan variabel lingkungan untuk lingkungan Colab.

Pilih runtime

Untuk menyelesaikan tutorial ini, Anda harus memiliki runtime Colab dengan resource yang memadai untuk menjalankan model Gemma. Dalam hal ini, Anda dapat menggunakan GPU T4:

  1. Di kanan atas jendela Colab, pilih ▾ (Opsi koneksi tambahan).
  2. Pilih Ubah jenis runtime.
  3. Di bagian Hardware accelerator, pilih T4 GPU.

Mengonfigurasi kunci API

Untuk menggunakan Gemma, Anda harus memberikan nama pengguna Kaggle dan kunci API Kaggle.

Untuk membuat kunci API Kaggle, buka tab Account di profil pengguna Kaggle Anda, lalu pilih Create New Token. Tindakan ini akan memicu download file kaggle.json yang berisi kredensial API Anda.

Di Colab, pilih Secrets (🔑) di panel kiri, lalu tambahkan nama pengguna Kaggle dan kunci API Kaggle Anda. Simpan nama pengguna Anda dengan nama KAGGLE_USERNAME dan kunci API Anda dengan nama KAGGLE_KEY.

Menetapkan variabel lingkungan

Tetapkan variabel lingkungan untuk KAGGLE_USERNAME dan KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Menginstal dependensi

Instal Keras, KerasNLP, dan dependensi lainnya.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

Pilih backend

Keras adalah API deep learning multi-framework tingkat tinggi yang dirancang untuk kesederhanaan dan kemudahan penggunaan. Dengan Keras 3, Anda dapat menjalankan alur kerja di salah satu dari tiga backend: TensorFlow, JAX, atau PyTorch.

Untuk tutorial ini, konfigurasikan backend untuk JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Mengimpor paket

Mengimpor Keras dan KerasNLP.

import keras
import keras_nlp

Memuat Set Data

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

Lakukan pra-pemrosesan data. Tutorial ini menggunakan subset dari 1.000 contoh pelatihan untuk menjalankan notebook lebih cepat. Pertimbangkan untuk menggunakan lebih banyak data pelatihan untuk penyesuaian yang lebih berkualitas.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

Memuat Model

KerasNLP menyediakan implementasi dari banyak arsitektur model populer. Dalam tutorial ini, Anda akan membuat model menggunakan GemmaCausalLM, yaitu model Gemma menyeluruh untuk pemodelan bahasa kausal. Model bahasa kausal memprediksi token berikutnya berdasarkan token sebelumnya.

Buat model menggunakan metode from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

Metode from_preset membuat instance model dari arsitektur dan bobot preset. Dalam kode di atas, string "gemma2_2b_en" menentukan arsitektur preset — model Gemma dengan 2 miliar parameter.

Inferen sebelum penyesuaian

Di bagian ini, Anda akan mengkueri model dengan berbagai perintah untuk melihat responsnya.

Perintah Perjalanan Eropa

Buat kueri model untuk mendapatkan saran tentang hal yang harus dilakukan dalam perjalanan ke Eropa.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

Model merespons dengan memberikan tips umum tentang cara merencanakan perjalanan.

Perintah Fotosintesis ELI5

Minta model untuk menjelaskan fotosintesis dalam istilah yang cukup sederhana untuk dipahami oleh anak berusia 5 tahun.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

Respons model berisi kata-kata yang mungkin tidak mudah dipahami oleh anak-anak seperti klorofil.

Penyesuaian LoRA

Untuk mendapatkan respons yang lebih baik dari model, sesuaikan model dengan Low Rank Adaptation (LoRA) menggunakan set data Databricks Dolly 15k.

Peringkat LoRA menentukan dimensi matriks yang dapat dilatih yang ditambahkan ke bobot asli LLM. Parameter ini mengontrol ekspresi dan presisi penyesuaian penyesuaian.

Peringkat yang lebih tinggi berarti perubahan yang lebih mendetail dapat dilakukan, tetapi juga berarti lebih banyak parameter yang dapat dilatih. Peringkat yang lebih rendah berarti overhead komputasi yang lebih sedikit, tetapi adaptasi yang berpotensi kurang presisi.

Tutorial ini menggunakan peringkat LoRA 4. Dalam praktiknya, mulailah dengan peringkat yang relatif kecil (seperti 4, 8, 16). Hal ini efisien secara komputasi untuk eksperimen. Latih model Anda dengan peringkat ini dan evaluasi peningkatan performa pada tugas Anda. Tingkatkan peringkat secara bertahap dalam uji coba berikutnya dan lihat apakah hal itu meningkatkan performa lebih lanjut.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Perhatikan bahwa mengaktifkan LoRA akan mengurangi jumlah parameter yang dapat dilatih secara signifikan (dari 2,6 miliar menjadi 2,9 juta).

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

Catatan tentang penyesuaian presisi campuran di GPU NVIDIA

Presisi penuh direkomendasikan untuk fine-tuning. Saat melakukan penyesuaian di GPU NVIDIA, perhatikan bahwa Anda dapat menggunakan presisi campuran (keras.mixed_precision.set_global_policy('mixed_bfloat16')) untuk mempercepat pelatihan dengan efek minimal pada kualitas pelatihan. Penyesuaian presisi campuran memang menggunakan lebih banyak memori sehingga hanya berguna pada GPU yang lebih besar.

Untuk inferensi, presisi setengah (keras.config.set_floatx("bfloat16")) akan berfungsi dan menghemat memori, sedangkan presisi campuran tidak berlaku.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Inferensia setelah penyesuaian

Setelah penyesuaian, respons akan mengikuti petunjuk yang diberikan dalam perintah.

Dialog Perjalanan Eropa

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

Model ini kini merekomendasikan tempat untuk dikunjungi di Eropa.

Perintah Fotosintesis ELI5

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

Model ini kini menjelaskan fotosintesis dalam istilah yang lebih sederhana.

Perhatikan bahwa untuk tujuan demonstrasi, tutorial ini menyesuaikan model pada sebagian kecil set data hanya untuk satu epoch dan dengan nilai peringkat LoRA yang rendah. Untuk mendapatkan respons yang lebih baik dari model yang telah disesuaikan, Anda dapat bereksperimen dengan:

  1. Meningkatkan ukuran set data penyesuaian
  2. Pelatihan untuk lebih banyak langkah (epoch)
  3. Menetapkan peringkat LoRA yang lebih tinggi
  4. Mengubah nilai hyperparameter seperti learning_rate dan weight_decay.

Ringkasan dan langkah berikutnya

Tutorial ini membahas penyesuaian LoRA pada model Gemma menggunakan KerasNLP. Lihat dokumen berikut: