在 ai.google.dev 上查看 | 在 Google Colab 中執行 | 在 Vertex AI 中開啟 | 在 GitHub 上查看來源 |
總覽
Gemma 是一系列先進的輕量級開放式模型,採用與建立 Gemini 模型時相同的研究成果和技術。
Gemma 這類大型語言模型 (LLM) 已證實可有效執行各種 NLP 工作。LLM 會先以自主監督的大量文字資料庫預先訓練,預先訓練有助於 LLM 學習通用知識,例如字詞之間的統計關係。接著,您可以使用特定領域的資料微調大型語言模型,執行後續工作 (例如情緒分析)。
LLM 的規模非常龐大 (參數以數十億人為單位)。大多數應用程式不需要進行完整微調 (更新模型中的所有參數),因為一般微調資料集相較於預先訓練資料集,相對較小。
低秩調整 (LoRA) 是一種微調技巧,可透過凍結模型權重並在模型中插入較少數量的新權重,大幅減少下游任務的可訓練參數數量。這樣一來,使用 LoRA 訓練的速度會大幅提升,記憶體效率也會提高,產生的模型權重也會縮小 (幾百 MB),同時維持模型輸出的品質。
本教學課程將逐步引導您使用 KerasNLP,使用 Databricks Dolly 15k 資料集,對 Gemma 2B 模型進行 LoRA 微調。這份資料集包含 15,000 組由人類產生的高品質提示 / 回覆組合,專門用於微調 LLM。
設定
取得 Gemma 存取權
如要完成本教學課程,您必須先前往 Gemma 設定頁面完成設定。Gemma 設定操作說明會說明如何執行下列操作:
- 前往 kaggle.com 取得 Gemma 存取權。
- 請選取具有足夠資源來執行 Gemma 2B 模型的 Colab 執行階段。
- 產生並設定 Kaggle 使用者名稱和 API 金鑰。
完成 Gemma 設定後,請繼續閱讀下一節,瞭解如何設定 Colab 環境的環境變數。
選取執行階段
如要完成本教學課程,您必須擁有 Colab 執行階段,並具備足夠的資源來執行 Gemma 模型。在這種情況下,您可以使用 T4 GPU:
- 在 Colab 視窗的右上方,選取 ▾ (其他連結選項)。
- 選取「變更執行階段類型」。
- 在「硬體加速器」下方,選取「T4 GPU」。
設定 API 金鑰
如要使用 Gemma,您必須提供 Kaggle 使用者名稱和 Kaggle API 金鑰。
如要產生 Kaggle API 金鑰,請前往 Kaggle 使用者個人資料的「帳戶」分頁,然後選取「建立新的權杖」。這會觸發下載內含 API 憑證的 kaggle.json
檔案。
在 Colab 中,選取左側窗格中的「Secrets」 (🔑?),然後新增 Kaggle 使用者名稱和 Kaggle API 金鑰。將使用者名稱儲存為 KAGGLE_USERNAME
,將 API 金鑰儲存為 KAGGLE_KEY
。
設定環境變數
設定 KAGGLE_USERNAME
和 KAGGLE_KEY
的環境變數。
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
安裝依附元件
安裝 Keras、KerasNLP 和其他依附元件。
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"
選取後端
Keras 是高階的多架構深度學習 API,專為簡化使用而設計。使用 Keras 3 時,您可以在三個後端 (TensorFlow、JAX 或 PyTorch) 中選擇一個執行工作流程。
在本教學課程中,請設定 JAX 的後端。
os.environ["KERAS_BACKEND"] = "jax" # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
匯入套件
匯入 Keras 和 KerasNLP。
import keras
import keras_nlp
載入資料集
wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39-- https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ... Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected. HTTP request sent, awaiting response... 302 Found Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following] --2024-07-31 01:56:39-- https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ... Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 13085339 (12M) [text/plain] Saving to: ‘databricks-dolly-15k.jsonl’ databricks-dolly-15 100%[===================>] 12.48M 73.7MB/s in 0.2s 2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]
預先處理資料。本教學課程會使用 1000 個訓練範例的子集,以便加快執行 Notebook 的速度。建議您使用更多訓練資料,以便進行更精確的微調。
import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
for line in file:
features = json.loads(line)
# Filter out examples with context, to keep it simple.
if features["context"]:
continue
# Format the entire example as a single string.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
data.append(template.format(**features))
# Only use 1000 training examples, to keep it fast.
data = data[:1000]
載入模型
KerasNLP 提供許多熱門模型架構的實作方式。在本教學課程中,您將使用 GemmaCausalLM
建立模型,這是用於因果語言建模的端對端 Gemma 模型。因果語言模型會根據先前的符記預測下一個符記。
使用 from_preset
方法建立模型:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()
from_preset
方法會根據預設的架構和權重將模型例項化。在上述程式碼中,「gemma2_2b_en」字串會指定預設架構,也就是含有 20 億個參數的 Gemma 模型。
在微調前進行推論
在本節中,您將使用各種提示查詢模型,瞭解模型的回應方式。
歐洲旅遊提示
查詢模型,取得前往歐洲旅遊時可做的建議活動。
prompt = template.format(
instruction="What should I do on a trip to Europe?",
response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction: What should I do on a trip to Europe? Response: If you have any special needs, you should contact the embassy of the country that you are visiting. You should contact the embassy of the country that I will be visiting. What are my responsibilities when I go on a trip? Response: If you are going to Europe, you should make sure to bring all of your documents. If you are going to Europe, make sure that you have all of your documents. When do you travel abroad? Response: The most common reason to travel abroad is to go to school or work. The most common reason to travel abroad is to work. How can I get a visa to Europe? Response: If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy. If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy. When should I go to Europe? Response: You should go to Europe when the weather is nice. You should go to Europe when the weather is bad. How can I make a reservation for a trip?
模型會回覆有關如何規劃行程的一般提示。
ELI5 光合作用提示
請模型以 5 歲兒童能理解的簡單用語,說明光合作用。
prompt = template.format(
instruction="Explain the process of photosynthesis in a way that a child could understand.",
response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction: Explain the process of photosynthesis in a way that a child could understand. Response: Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis. Instruction: What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration? Response: The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide. Instruction: Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration. Response: Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen. Instruction: How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell? Response: Photosynthesis occurs in the cells of a plant. The purpose of
模型回覆內容包含「葉綠素」等兒童可能不易理解的字詞。
LoRA 微調
為了從模型取得更準確的回應,請使用 Databricks Dolly 15k 資料集,使用低評級調整 (LoRA) 來微調模型。
LoRA 秩會決定可訓練矩陣的維度,這些矩陣會加到 LLM 的原始權重。控制微調功能的表現方式和精確度。
排名越高,可進行的細微變更就越多,但也代表可訓練的參數越多。排名越低,運算成本越低,但可能導致調整不夠精確。
本教學課程使用的 LoRA 等級為 4。實務上,請從相對較低的排名開始 (例如 4、8、16)。這可為實驗提供高效的運算效能。使用這個排名訓練模型,並評估任務的效能提升情形。在後續試驗中逐步提高排名,看看是否能進一步提升成效。
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
請注意,啟用 LoRA 會大幅減少可訓練參數的數量 (從 26 億減少至 290 萬參數)。
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251 <keras.src.callbacks.history.History at 0x799d04393c40>
關於在 NVIDIA GPU 上使用混合精確度微調的注意事項
建議您使用完整精確度進行微調。在 NVIDIA GPU 上進行微調時,請注意,您可以使用混合精確度 (keras.mixed_precision.set_global_policy('mixed_bfloat16')
) 加快訓練速度,同時盡可能不影響訓練品質。混合精確度精細調整確實會耗用更多記憶體,因此只適用於較大的 GPU。
在推論時,半精確度 (keras.config.set_floatx("bfloat16")
) 會運作並節省記憶體,而混合精確度則不適用。
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
微調後的推論
精細調整後,回應會依照提示中的指示進行。
歐洲旅遊提示
prompt = template.format(
instruction="What should I do on a trip to Europe?",
response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction: What should I do on a trip to Europe? Response: When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.
模型現在會推薦歐洲景點。
ELI5 光合作用提示
prompt = template.format(
instruction="Explain the process of photosynthesis in a way that a child could understand.",
response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction: Explain the process of photosynthesis in a way that a child could understand. Response: The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.
模型現在會以簡單易懂的用語說明光合作用。
請注意,為了示範目的,本教學課程只會在資料集的一小部分上,針對單一 epoch 和低 LoRA 排名值進行模型精調。如要透過經過微調的模型取得更實用的回應,您可以嘗試以下做法:
- 增加微調資料集的大小
- 訓練更多步驟 (週期)
- 設定較高的 LoRA 等級
- 修改
learning_rate
和weight_decay
等超參數值。
總結與後續步驟
本教學課程說明如何使用 KerasNLP 對 Gemma 模型進行微調。接下來,請參閱下列文件: