Точная настройка моделей Gemma в Keras с использованием LoRA

Посмотреть на ai.google.dev Запустить в Google Colab Открыть в Vertex AI Посмотреть исходный код на GitHub

Обзор

Gemma — это семейство легких современных открытых моделей, созданных на основе тех же исследований и технологий, которые использовались при создании моделей Gemini.

Было доказано, что модели большого языка (LLM), такие как Gemma, эффективны при решении различных задач НЛП. LLM сначала проходит предварительное обучение на большом массиве текста в режиме самоконтроля. Предварительное обучение помогает магистрам права освоить знания общего назначения, такие как статистические связи между словами. Затем LLM можно настроить с использованием данных, специфичных для предметной области, для выполнения последующих задач (например, анализа настроений).

LLM чрезвычайно велики по размеру (параметры порядка миллиардов). Полная точная настройка (которая обновляет все параметры модели) не требуется для большинства приложений, поскольку типичные наборы данных тонкой настройки относительно намного меньше, чем наборы данных перед обучением.

Адаптация низкого ранга (LoRA) — это метод тонкой настройки, который значительно уменьшает количество обучаемых параметров для последующих задач за счет замораживания весов модели и добавления в модель меньшего количества новых весов. Это делает обучение с помощью LoRA намного более быстрым и более эффективным с использованием памяти, а также дает меньший вес модели (несколько сотен МБ), сохраняя при этом качество выходных данных модели.

В этом руководстве рассказывается, как использовать KerasNLP для точной настройки LoRA модели Gemma 2B с использованием набора данных Databricks Dolly 15k . Этот набор данных содержит 15 000 высококачественных пар подсказок и ответов, созданных человеком и специально разработанных для точной настройки LLM.

Настраивать

Получить доступ к Джемме

Чтобы выполнить это руководство, вам сначала необходимо выполнить инструкции по настройке на странице настройки Gemma . В инструкциях по настройке Gemma показано, как сделать следующее:

  • Получите доступ к Джемме на kaggle.com .
  • Выберите среду выполнения Colab с достаточными ресурсами для запуска модели Gemma 2B.
  • Создайте и настройте имя пользователя Kaggle и ключ API.

После завершения настройки Gemma перейдите к следующему разделу, где вы установите переменные среды для вашей среды Colab.

Выберите время выполнения

Для выполнения этого руководства вам потребуется среда выполнения Colab с достаточными ресурсами для запуска модели Gemma. В этом случае вы можете использовать графический процессор T4:

  1. В правом верхнем углу окна Colab выберите ▾ ( Дополнительные параметры подключения ).
  2. Выберите Изменить тип среды выполнения .
  3. В разделе «Аппаратный ускоритель» выберите T4 GPU .

Настройте свой ключ API

Чтобы использовать Gemma, вы должны предоставить свое имя пользователя Kaggle и ключ API Kaggle.

Чтобы сгенерировать ключ API Kaggle, перейдите на вкладку «Учетная запись» вашего профиля пользователя Kaggle и выберите « Создать новый токен» . Это приведет к загрузке файла kaggle.json , содержащего ваши учетные данные API.

В Colab выберите «Секреты» (🔑) на левой панели и добавьте свое имя пользователя Kaggle и ключ API Kaggle. Сохраните свое имя пользователя под именем KAGGLE_USERNAME и ключ API под именем KAGGLE_KEY .

Установить переменные среды

Установите переменные среды для KAGGLE_USERNAME и KAGGLE_KEY .

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Установить зависимости

Установите Keras, KerasNLP и другие зависимости.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

Выберите серверную часть

Keras — это высокоуровневый многоплатформенный API глубокого обучения, разработанный для простоты и удобства использования. Используя Keras 3, вы можете запускать рабочие процессы на одном из трех бэкэндов: TensorFlow, JAX или PyTorch.

В рамках этого руководства настройте серверную часть для JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Импортировать пакеты

Импортируйте Keras и KerasNLP.

import keras
import keras_nlp

Загрузить набор данных

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-02-21 16:01:22--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 65.8.178.118, 65.8.178.12, 65.8.178.27, ...
Connecting to huggingface.co (huggingface.co)|65.8.178.118|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1708790483&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcwODc5MDQ4M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=BwdEM1fYy7BYkObmc2q94IKmK36Yf4TPP2cKpS9rCxXZXsl65Rvo1dMcCT1rh1pWYRviT64m50aY%7EMV6yZX58OxVJhcVL7A9lsoAJIZfLea6NeZya3Vfd5h%7EhGTD68Iu%7EJl%7EQjzdaVzj70%7E52tBkmVK3N89W7GUeLZC1p4L8iADTLUEEn80fED-kkzcq4lAxN7rKxBMhqJXgmChxbUP0%7EQEa5AuqZFM7WIMCdy6J368digPnIr4ReHNm1VOEjh5qKNwYBuUXqfxU%7EfiBLFHFzDKSIqQw6Bn0B01b2E2CmwFdAd9HndByEmzfJfcs1yhMrbaxVcPCGay5VcRS3U2-5g__&Key-Pair-Id=KVTP0A1DKRTAX [following]
--2024-02-21 16:01:23--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1708790483&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcwODc5MDQ4M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=BwdEM1fYy7BYkObmc2q94IKmK36Yf4TPP2cKpS9rCxXZXsl65Rvo1dMcCT1rh1pWYRviT64m50aY%7EMV6yZX58OxVJhcVL7A9lsoAJIZfLea6NeZya3Vfd5h%7EhGTD68Iu%7EJl%7EQjzdaVzj70%7E52tBkmVK3N89W7GUeLZC1p4L8iADTLUEEn80fED-kkzcq4lAxN7rKxBMhqJXgmChxbUP0%7EQEa5AuqZFM7WIMCdy6J368digPnIr4ReHNm1VOEjh5qKNwYBuUXqfxU%7EfiBLFHFzDKSIqQw6Bn0B01b2E2CmwFdAd9HndByEmzfJfcs1yhMrbaxVcPCGay5VcRS3U2-5g__&Key-Pair-Id=KVTP0A1DKRTAX
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 108.157.162.27, 108.157.162.99, 108.157.162.58, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|108.157.162.27|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  64.0MB/s    in 0.2s    

2024-02-21 16:01:23 (64.0 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

Предварительная обработка данных. В этом руководстве используется подмножество из 1000 обучающих примеров для более быстрого выполнения блокнота. Рассмотрите возможность использования большего количества обучающих данных для более качественной точной настройки.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

Загрузить модель

KerasNLP предоставляет реализации многих популярных модельных архитектур . В этом руководстве вы создадите модель, используя GemmaCausalLM , комплексную модель Gemma для моделирования причинного языка. Модель причинного языка прогнозирует следующий токен на основе предыдущих токенов.

Создайте модель, используя метод from_preset :

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

Метод from_preset создает экземпляр модели на основе предустановленной архитектуры и весов. В приведенном выше коде строка «gemma_2b_en» определяет предустановленную архитектуру — модель Gemma с 2 миллиардами параметров.

Выводы перед точной настройкой

В этом разделе вы будете запрашивать модель с помощью различных подсказок, чтобы увидеть, как она отреагирует.

Подсказка о поездке в Европу

Запросите у модели предложения о том, чем заняться во время поездки в Европу.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
It's easy, you just need to follow these steps:

First you must book your trip with a travel agency.
Then you must choose a country and a city.
Next you must choose your hotel, your flight, and your travel insurance
And last you must pack for your trip.
 


What are the benefits of a travel agency?

Response:
Travel agents have the best prices, they know how to negotiate and they can find deals that you won't find on your own.

What are the disadvantages of a travel agency?

Response:
Travel agents are not as flexible as you would like. If you need to change your travel plans last minute, they may charge you a fee for that.
 


How do I choose a travel agency?

Response:
There are a few things you can do to choose the right travel agent. First, check to see if they are accredited by the Better Business Bureau. Second, check their website and see what kind of information they offer. Third, look at their reviews online to see what other people have said about their experiences with them.

How does a travel agency make money?

Модель отвечает общими советами о том, как спланировать поездку.

ELI5 Фотосинтез Подсказка

Предложите модели объяснить фотосинтез достаточно простыми словами, чтобы их мог понять пятилетний ребенок.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants use light energy and carbon dioxide to make sugar and oxygen. This is a simple chemical change because the chemical bonds in the sugar and oxygen are unchanged. Plants also release oxygen during photosynthesis.

Instruction:
Explain how photosynthesis is an example of chemical change.

Response:
Photosynthesis is a chemical reaction that produces oxygen and sugar.

Instruction:
Explain how plants make their own food.

Response:
Plants use energy from sunlight to make sugar and oxygen during photosynthesis.

Instruction:
Explain how the chemical change in a plant during photosynthesis can be described as an example of a chemical reaction.

Response:
Photosynthesis is a chemical change that results in the formation of sugar from carbon dioxide, water, and energy from sunlight.

Instruction:
Explain the role of chlorophyll in plant photosynthesis.

Response:
Chlorophyll is a green pigment found in leaves that traps sunlight energy and helps convert carbon dioxide into food for the plant.

Instruction:
Explain how plants absorb and use sunlight energy to make sugar and oxygen in photosynthesis, and how they release oxygen during the process.

Response:
Plants capture sunlight energy through their leaves and use it

Модельный ответ содержит слова, которые ребенку может быть нелегко понять, например хлорофилл.

LoRA Тонкая настройка

Чтобы получить более точные ответы от модели, настройте модель с помощью низкоранговой адаптации (LoRA), используя набор данных Databricks Dolly 15k.

Ранг LoRA определяет размерность обучаемых матриц, которые добавляются к исходным весам LLM. Он контролирует выразительность и точность тонкой настройки.

Более высокий ранг означает, что возможны более детальные изменения, но также означает и более обучаемые параметры. Более низкий ранг означает меньшие вычислительные затраты, но потенциально менее точную адаптацию.

В этом руководстве используется ранг LoRA, равный 4. На практике начните с относительно небольшого ранга (например, 4, 8, 16). Это вычислительно эффективно для экспериментов. Обучите свою модель с этим рангом и оцените улучшение производительности при выполнении вашей задачи. Постепенно повышайте ранг в последующих испытаниях и посмотрите, повысит ли это производительность.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Обратите внимание, что включение LoRA значительно сокращает количество обучаемых параметров (с 2,5 миллиарда до 1,3 миллиона).

# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 1524s 1s/step - loss: 0.4591 - sparse_categorical_accuracy: 0.5230
<keras.src.callbacks.history.History at 0x7ca3a01701c0>

Примечание о точной настройке смешанной точности на графических процессорах NVIDIA.

Для точной настройки рекомендуется полная точность. При точной настройке графических процессоров NVIDIA обратите внимание, что вы можете использовать смешанную точность ( keras.mixed_precision.set_global_policy('mixed_bfloat16') ) для ускорения обучения с минимальным влиянием на качество обучения. Точная настройка смешанной точности потребляет больше памяти, поэтому полезна только на больших графических процессорах.

Для вывода, половинная точность ( keras.config.set_floatx("bfloat16") ) будет работать и экономить память, тогда как смешанная точность неприменима.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Вывод после тонкой настройки

После тонкой настройки ответы следуют инструкциям, приведенным в подсказке.

Подсказка о поездке в Европу

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have the time, I would visit London, Paris, Rome, and Berlin. If you're in London, you have to visit Buckingham Palace. If you're in Paris, you have to visit Notre Dame and the Eiffel Tower. If you're in Rome, you have to visit the Coliseum. If you're in Berlin, you have to visit the Brandenburg Gate.

Модель теперь рекомендует места для посещения в Европе.

ELI5 Фотосинтез Подсказка

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is when a plant uses sunlight to make energy. The plants use carbon dioxide and water to make sugar and oxygen. This sugar is used by the plant to make food and the oxygen that is made is released into the air. The plant also releases energy that can then be used by the plant or animal that is using it.

Теперь модель объясняет фотосинтез более простыми словами.

Обратите внимание, что в демонстрационных целях в этом руководстве модель настраивается на небольшом подмножестве набора данных всего за одну эпоху и с низким значением ранга LoRA. Чтобы получить лучшие ответы от точно настроенной модели, вы можете поэкспериментировать с:

  1. Увеличение размера набора данных тонкой настройки
  2. Обучение большему количеству шагов (эпох)
  3. Установка более высокого ранга LoRA
  4. Изменение значений гиперпараметров, таких как learning_rate и weight_decay .

Резюме и следующие шаги

В этом руководстве рассматривается тонкая настройка LoRA модели Gemma с использованием KerasNLP. Далее ознакомьтесь со следующими документами: