Точная настройка моделей Gemma в Keras с использованием LoRA

Посмотреть на ai.google.dev Запустить в Google Colab Открыть в Vertex AI Посмотреть исходный код на GitHub

Обзор

Gemma — это семейство легких современных открытых моделей, созданных на основе тех же исследований и технологий, которые использовались при создании моделей Gemini.

Было доказано, что модели большого языка (LLM), такие как Gemma, эффективны при решении различных задач НЛП. LLM сначала проходит предварительное обучение на большом массиве текста в режиме самоконтроля. Предварительное обучение помогает магистрам права освоить знания общего назначения, такие как статистические связи между словами. Затем LLM можно настроить с использованием данных, специфичных для предметной области, для выполнения последующих задач (например, анализа настроений).

LLM чрезвычайно велики по размеру (параметры порядка миллиардов). Полная точная настройка (которая обновляет все параметры модели) не требуется для большинства приложений, поскольку типичные наборы данных тонкой настройки относительно намного меньше, чем наборы данных перед обучением.

Адаптация низкого ранга (LoRA) — это метод тонкой настройки, который значительно сокращает количество обучаемых параметров для последующих задач путем замораживания весов модели и добавления в модель меньшего количества новых весов. Это делает обучение с помощью LoRA намного более быстрым и более эффективным с использованием памяти, а также дает меньший вес модели (несколько сотен МБ), сохраняя при этом качество выходных данных модели.

В этом руководстве рассказывается, как использовать KerasNLP для точной настройки LoRA модели Gemma 2B с использованием набора данных Databricks Dolly 15k . Этот набор данных содержит 15 000 высококачественных пар подсказок и ответов, созданных человеком и специально разработанных для точной настройки LLM.

Настраивать

Получить доступ к Джемме

Чтобы выполнить это руководство, вам сначала необходимо выполнить инструкции по настройке на странице настройки Gemma . В инструкциях по настройке Gemma показано, как сделать следующее:

  • Получите доступ к Джемме на kaggle.com .
  • Выберите среду выполнения Colab с достаточными ресурсами для запуска модели Gemma 2B.
  • Создайте и настройте имя пользователя Kaggle и ключ API.

После завершения настройки Gemma перейдите к следующему разделу, где вы установите переменные среды для вашей среды Colab.

Выберите время выполнения

Для выполнения этого руководства вам потребуется среда выполнения Colab с достаточными ресурсами для запуска модели Gemma. В этом случае вы можете использовать графический процессор T4:

  1. В правом верхнем углу окна Colab выберите ▾ ( Дополнительные параметры подключения ).
  2. Выберите Изменить тип среды выполнения .
  3. В разделе «Аппаратный ускоритель» выберите T4 GPU .

Настройте свой ключ API

Чтобы использовать Gemma, вы должны предоставить свое имя пользователя Kaggle и ключ API Kaggle.

Чтобы сгенерировать ключ API Kaggle, перейдите на вкладку «Учетная запись» вашего профиля пользователя Kaggle и выберите « Создать новый токен» . Это приведет к загрузке файла kaggle.json , содержащего ваши учетные данные API.

В Colab выберите «Секреты» (🔑) на левой панели и добавьте свое имя пользователя Kaggle и ключ API Kaggle. Сохраните свое имя пользователя под именем KAGGLE_USERNAME и ключ API под именем KAGGLE_KEY .

Установить переменные среды

Установите переменные среды для KAGGLE_USERNAME и KAGGLE_KEY .

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Установить зависимости

Установите Keras, KerasNLP и другие зависимости.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

Выберите серверную часть

Keras — это высокоуровневый многоплатформенный API глубокого обучения, разработанный для простоты и удобства использования. Используя Keras 3, вы можете запускать рабочие процессы на одном из трех бэкэндов: TensorFlow, JAX или PyTorch.

В этом руководстве настройте серверную часть для JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Импортировать пакеты

Импортируйте Keras и KerasNLP.

import keras
import keras_nlp

Загрузить набор данных

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

Предварительная обработка данных. В этом руководстве используется подмножество из 1000 обучающих примеров для более быстрого выполнения блокнота. Рассмотрите возможность использования большего количества обучающих данных для более качественной точной настройки.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

Загрузить модель

KerasNLP предоставляет реализации многих популярных модельных архитектур . В этом руководстве вы создадите модель, используя GemmaCausalLM , комплексную модель Gemma для моделирования причинного языка. Модель причинного языка прогнозирует следующий токен на основе предыдущих токенов.

Создайте модель, используя метод from_preset :

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

Метод from_preset создает экземпляр модели на основе предустановленной архитектуры и весов. В приведенном выше коде строка «gemma2_2b_en» определяет предустановленную архитектуру — модель Gemma с 2 миллиардами параметров.

Выводы перед точной настройкой

В этом разделе вы будете запрашивать модель с помощью различных подсказок, чтобы увидеть, как она отреагирует.

Подсказка о поездке в Европу

Запросите у модели предложения о том, чем заняться во время поездки в Европу.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

Модель отвечает общими советами о том, как спланировать поездку.

ELI5 Фотосинтез Подсказка

Предложите модели объяснить фотосинтез достаточно простыми словами, чтобы их мог понять пятилетний ребенок.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

Модельный ответ содержит слова, которые ребенку может быть нелегко понять, например хлорофилл.

LoRA Тонкая настройка

Чтобы получить более точные ответы от модели, настройте модель с помощью низкоранговой адаптации (LoRA), используя набор данных Databricks Dolly 15k.

Ранг LoRA определяет размерность обучаемых матриц, которые добавляются к исходным весам LLM. Он контролирует выразительность и точность тонкой настройки.

Более высокий ранг означает, что возможны более детальные изменения, но также означает и более обучаемые параметры. Более низкий ранг означает меньшие вычислительные затраты, но потенциально менее точную адаптацию.

В этом руководстве используется ранг LoRA, равный 4. На практике начните с относительно небольшого ранга (например, 4, 8, 16). Это вычислительно эффективно для экспериментов. Обучите свою модель с этим рангом и оцените улучшение производительности при выполнении вашей задачи. Постепенно повышайте ранг в последующих испытаниях и посмотрите, повысит ли это производительность.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Обратите внимание, что включение LoRA значительно сокращает количество обучаемых параметров (с 2,6 миллиарда до 2,9 миллиона).

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

Примечание о точной настройке смешанной точности на графических процессорах NVIDIA.

Для точной настройки рекомендуется полная точность. При точной настройке графических процессоров NVIDIA обратите внимание, что вы можете использовать смешанную точность ( keras.mixed_precision.set_global_policy('mixed_bfloat16') ) для ускорения обучения с минимальным влиянием на качество обучения. Точная настройка смешанной точности потребляет больше памяти, поэтому полезна только на больших графических процессорах.

Для вывода, половинная точность ( keras.config.set_floatx("bfloat16") ) будет работать и экономить память, тогда как смешанная точность неприменима.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Вывод после тонкой настройки

После тонкой настройки ответы будут следовать инструкциям, приведенным в подсказке.

Подсказка о поездке в Европу

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

Модель теперь рекомендует места для посещения в Европе.

ELI5 Фотосинтез Подсказка

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

Теперь модель объясняет фотосинтез более простыми словами.

Обратите внимание, что в демонстрационных целях в этом руководстве модель настраивается на небольшом подмножестве набора данных всего за одну эпоху и с низким значением ранга LoRA. Чтобы получить лучшие ответы от точно настроенной модели, вы можете поэкспериментировать с:

  1. Увеличение размера набора данных тонкой настройки
  2. Обучение большему количеству шагов (эпох)
  3. Установка более высокого ранга LoRA
  4. Изменение значений гиперпараметров, таких как learning_rate и weight_decay .

Резюме и следующие шаги

В этом руководстве рассматривается тонкая настройка LoRA модели Gemma с использованием KerasNLP. Далее ознакомьтесь со следующими документами: