مدل های جما را در Keras با استفاده از LoRA تنظیم کنید

مشاهده در ai.google.dev در Google Colab اجرا شود در Vertex AI باز کنید مشاهده منبع در GitHub

بررسی اجمالی

Gemma خانواده ای از مدل های سبک وزن و مدرن است که از همان تحقیقات و فناوری استفاده شده برای ایجاد مدل های Gemini ساخته شده است.

نشان داده شده است که مدل های زبان بزرگ (LLM) مانند Gemma در انواع وظایف NLP موثر هستند. یک LLM ابتدا بر روی مجموعه بزرگی از متن به صورت خود نظارتی از قبل آموزش داده می شود. پیش‌آموزش به LLMها کمک می‌کند تا دانش عمومی، مانند روابط آماری بین کلمات را بیاموزند. سپس یک LLM را می توان با داده های دامنه خاص برای انجام وظایف پایین دستی (مانند تجزیه و تحلیل احساسات) تنظیم کرد.

LLMها از نظر اندازه بسیار بزرگ هستند (پارامترها در حد میلیاردها). تنظیم دقیق کامل (که تمام پارامترهای مدل را به روز می کند) برای اکثر برنامه ها مورد نیاز نیست زیرا مجموعه داده های تنظیم دقیق معمولی نسبتاً کوچکتر از مجموعه داده های قبل از آموزش هستند.

انطباق با رتبه پایین (LoRA) یک تکنیک تنظیم دقیق است که تعداد پارامترهای قابل آموزش برای کارهای پایین دستی را با انجماد وزن های مدل و درج تعداد کمتری وزنه های جدید در مدل به میزان زیادی کاهش می دهد. این باعث می شود که آموزش با LoRA بسیار سریعتر و حافظه کارآمدتر باشد، و وزن مدل کوچکتر (چند صد مگابایت) تولید می شود، همه اینها با حفظ کیفیت خروجی های مدل.

این آموزش شما را با استفاده از KerasNLP برای انجام تنظیم دقیق LoRA بر روی یک مدل Gemma 2B با استفاده از مجموعه داده Databricks Dolly 15k راهنمایی می کند. این مجموعه داده شامل 15000 جفت اعلان / پاسخ با کیفیت بالا است که به طور خاص برای تنظیم دقیق LLM طراحی شده است.

برپایی

به Gemma دسترسی پیدا کنید

برای تکمیل این آموزش، ابتدا باید دستورالعمل‌های راه‌اندازی را در Gemma setup تکمیل کنید. دستورالعمل های راه اندازی Gemma به شما نشان می دهد که چگونه کارهای زیر را انجام دهید:

  • در kaggle.com به Gemma دسترسی پیدا کنید.
  • یک زمان اجرا Colab با منابع کافی برای اجرای مدل Gemma 2B انتخاب کنید.
  • نام کاربری و کلید API Kaggle را ایجاد و پیکربندی کنید.

پس از تکمیل تنظیمات Gemma، به بخش بعدی بروید، جایی که متغیرهای محیطی را برای محیط Colab خود تنظیم خواهید کرد.

زمان اجرا را انتخاب کنید

برای تکمیل این آموزش، باید یک زمان اجرا Colab با منابع کافی برای اجرای مدل Gemma داشته باشید. در این مورد، می توانید از یک GPU T4 استفاده کنید:

  1. در سمت راست بالای پنجره Colab، ▾ ( گزینه های اتصال اضافی ) را انتخاب کنید.
  2. تغییر نوع زمان اجرا را انتخاب کنید.
  3. در بخش شتاب دهنده سخت افزار ، GPU T4 را انتخاب کنید.

کلید API خود را پیکربندی کنید

برای استفاده از Gemma، باید نام کاربری Kaggle و یک کلید Kaggle API ارائه دهید.

برای ایجاد یک کلید Kaggle API، به تب Account پروفایل کاربری Kaggle خود بروید و Create New Token را انتخاب کنید. با این کار دانلود فایل kaggle.json حاوی اطلاعات کاربری API شما راه اندازی می شود.

در Colab، Secrets (🔑) را در قسمت سمت چپ انتخاب کنید و نام کاربری Kaggle و کلید Kaggle API را اضافه کنید. نام کاربری خود را با نام KAGGLE_USERNAME و کلید API خود را با نام KAGGLE_KEY ذخیره کنید.

تنظیم متغیرهای محیطی

متغیرهای محیطی را برای KAGGLE_USERNAME و KAGGLE_KEY تنظیم کنید.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

وابستگی ها را نصب کنید

Keras، KerasNLP و سایر وابستگی ها را نصب کنید.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

یک باطن انتخاب کنید

Keras یک API یادگیری عمیق چند چارچوبی و سطح بالا است که برای سادگی و سهولت استفاده طراحی شده است. با استفاده از Keras 3، می‌توانید گردش‌های کاری را روی یکی از سه Backend اجرا کنید: TensorFlow، JAX یا PyTorch.

برای این آموزش، Backend را برای JAX پیکربندی کنید.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

بسته های وارداتی

Keras و KerasNLP را وارد کنید.

import keras
import keras_nlp

بارگذاری مجموعه داده

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-02-21 16:01:22--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 65.8.178.118, 65.8.178.12, 65.8.178.27, ...
Connecting to huggingface.co (huggingface.co)|65.8.178.118|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1708790483&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcwODc5MDQ4M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=BwdEM1fYy7BYkObmc2q94IKmK36Yf4TPP2cKpS9rCxXZXsl65Rvo1dMcCT1rh1pWYRviT64m50aY%7EMV6yZX58OxVJhcVL7A9lsoAJIZfLea6NeZya3Vfd5h%7EhGTD68Iu%7EJl%7EQjzdaVzj70%7E52tBkmVK3N89W7GUeLZC1p4L8iADTLUEEn80fED-kkzcq4lAxN7rKxBMhqJXgmChxbUP0%7EQEa5AuqZFM7WIMCdy6J368digPnIr4ReHNm1VOEjh5qKNwYBuUXqfxU%7EfiBLFHFzDKSIqQw6Bn0B01b2E2CmwFdAd9HndByEmzfJfcs1yhMrbaxVcPCGay5VcRS3U2-5g__&Key-Pair-Id=KVTP0A1DKRTAX [following]
--2024-02-21 16:01:23--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1708790483&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcwODc5MDQ4M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=BwdEM1fYy7BYkObmc2q94IKmK36Yf4TPP2cKpS9rCxXZXsl65Rvo1dMcCT1rh1pWYRviT64m50aY%7EMV6yZX58OxVJhcVL7A9lsoAJIZfLea6NeZya3Vfd5h%7EhGTD68Iu%7EJl%7EQjzdaVzj70%7E52tBkmVK3N89W7GUeLZC1p4L8iADTLUEEn80fED-kkzcq4lAxN7rKxBMhqJXgmChxbUP0%7EQEa5AuqZFM7WIMCdy6J368digPnIr4ReHNm1VOEjh5qKNwYBuUXqfxU%7EfiBLFHFzDKSIqQw6Bn0B01b2E2CmwFdAd9HndByEmzfJfcs1yhMrbaxVcPCGay5VcRS3U2-5g__&Key-Pair-Id=KVTP0A1DKRTAX
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 108.157.162.27, 108.157.162.99, 108.157.162.58, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|108.157.162.27|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  64.0MB/s    in 0.2s    

2024-02-21 16:01:23 (64.0 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

داده ها را از قبل پردازش کنید. این آموزش از زیر مجموعه ای از 1000 مثال آموزشی برای اجرای سریعتر نوت بوک استفاده می کند. استفاده از داده های آموزشی بیشتر را برای تنظیم دقیق با کیفیت بالاتر در نظر بگیرید.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

مدل بارگذاری

KerasNLP پیاده سازی بسیاری از معماری های مدل محبوب را ارائه می دهد. در این آموزش، یک مدل با استفاده از GemmaCausalLM ، یک مدل Gemma سرتاسر برای مدل‌سازی زبان علی ایجاد می‌کنید. یک مدل زبان علی، نشانه بعدی را بر اساس نشانه های قبلی پیش بینی می کند.

مدل را با استفاده از متد from_preset ایجاد کنید:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

متد from_preset مدل را از یک معماری و وزن از پیش تعیین شده نمونه سازی می کند. در کد بالا، رشته "gemma_2b_en" معماری از پیش تعیین شده را مشخص می کند - یک مدل Gemma با 2 میلیارد پارامتر.

استنتاج قبل از تنظیم دقیق

در این بخش، مدل را با اعلان های مختلف پرس و جو می کنید تا ببینید چگونه پاسخ می دهد.

درخواست سفر اروپا

برای پیشنهادات در مورد اقداماتی که در سفر به اروپا باید انجام دهید، مدل را جویا شوید.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
It's easy, you just need to follow these steps:

First you must book your trip with a travel agency.
Then you must choose a country and a city.
Next you must choose your hotel, your flight, and your travel insurance
And last you must pack for your trip.
 


What are the benefits of a travel agency?

Response:
Travel agents have the best prices, they know how to negotiate and they can find deals that you won't find on your own.

What are the disadvantages of a travel agency?

Response:
Travel agents are not as flexible as you would like. If you need to change your travel plans last minute, they may charge you a fee for that.
 


How do I choose a travel agency?

Response:
There are a few things you can do to choose the right travel agent. First, check to see if they are accredited by the Better Business Bureau. Second, check their website and see what kind of information they offer. Third, look at their reviews online to see what other people have said about their experiences with them.

How does a travel agency make money?

این مدل با نکات کلی در مورد نحوه برنامه ریزی یک سفر پاسخ می دهد.

درخواست فتوسنتز ELI5

از مدل بخواهید که فتوسنتز را با عباراتی به اندازه کافی ساده برای درک یک کودک 5 ساله توضیح دهد.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants use light energy and carbon dioxide to make sugar and oxygen. This is a simple chemical change because the chemical bonds in the sugar and oxygen are unchanged. Plants also release oxygen during photosynthesis.

Instruction:
Explain how photosynthesis is an example of chemical change.

Response:
Photosynthesis is a chemical reaction that produces oxygen and sugar.

Instruction:
Explain how plants make their own food.

Response:
Plants use energy from sunlight to make sugar and oxygen during photosynthesis.

Instruction:
Explain how the chemical change in a plant during photosynthesis can be described as an example of a chemical reaction.

Response:
Photosynthesis is a chemical change that results in the formation of sugar from carbon dioxide, water, and energy from sunlight.

Instruction:
Explain the role of chlorophyll in plant photosynthesis.

Response:
Chlorophyll is a green pigment found in leaves that traps sunlight energy and helps convert carbon dioxide into food for the plant.

Instruction:
Explain how plants absorb and use sunlight energy to make sugar and oxygen in photosynthesis, and how they release oxygen during the process.

Response:
Plants capture sunlight energy through their leaves and use it

پاسخ مدل حاوی کلماتی است که ممکن است برای کودک آسان نباشد مانند کلروفیل.

تنظیم دقیق LoRA

برای دریافت پاسخ‌های بهتر از مدل، مدل را با انطباق رتبه پایین (LoRA) با استفاده از مجموعه داده Databricks Dolly 15k تنظیم دقیق کنید.

رتبه LoRA ابعاد ماتریس های قابل آموزش را تعیین می کند که به وزن های اصلی LLM اضافه می شوند. بیان و دقت تنظیمات تنظیم دقیق را کنترل می کند.

رتبه بالاتر به این معنی است که تغییرات با جزئیات بیشتر امکان پذیر است، اما همچنین به معنی پارامترهای قابل آموزش بیشتر است. رتبه پایین تر به معنای سربار محاسباتی کمتر، اما به طور بالقوه انطباق دقیق کمتر است.

این آموزش از رتبه LoRA 4 استفاده می کند. در عمل، با یک رتبه نسبتاً کوچک (مانند 4، 8، 16) شروع کنید. این از نظر محاسباتی برای آزمایش کارآمد است. مدل خود را با این رتبه آموزش دهید و بهبود عملکرد را در کار خود ارزیابی کنید. به تدریج رتبه را در آزمایش‌های بعدی افزایش دهید و ببینید که آیا این کار عملکرد را بیشتر می‌کند یا خیر.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

توجه داشته باشید که فعال کردن LoRA تعداد پارامترهای قابل آموزش را به میزان قابل توجهی کاهش می دهد (از 2.5 میلیارد به 1.3 میلیون).

# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 1524s 1s/step - loss: 0.4591 - sparse_categorical_accuracy: 0.5230
<keras.src.callbacks.history.History at 0x7ca3a01701c0>

نکته ای در مورد تنظیم دقیق ترکیبی در پردازنده های گرافیکی NVIDIA

دقت کامل برای تنظیم دقیق توصیه می شود. هنگام تنظیم دقیق پردازنده‌های گرافیکی NVIDIA، توجه داشته باشید که می‌توانید از دقت ترکیبی ( keras.mixed_precision.set_global_policy('mixed_bfloat16') ) برای سرعت بخشیدن به آموزش با حداقل تأثیر بر کیفیت آموزش استفاده کنید. تنظیم دقیق ترکیبی حافظه بیشتری مصرف می کند، بنابراین فقط در پردازنده های گرافیکی بزرگتر مفید است.

برای استنباط، نیم دقت ( keras.config.set_floatx("bfloat16") ) کار می کند و حافظه را ذخیره می کند در حالی که دقت ترکیبی قابل اعمال نیست.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

استنتاج پس از تنظیم دقیق

پس از تنظیم دقیق، پاسخ ها از دستورالعمل ارائه شده در اعلان پیروی می کنند.

درخواست سفر اروپا

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have the time, I would visit London, Paris, Rome, and Berlin. If you're in London, you have to visit Buckingham Palace. If you're in Paris, you have to visit Notre Dame and the Eiffel Tower. If you're in Rome, you have to visit the Coliseum. If you're in Berlin, you have to visit the Brandenburg Gate.

این مدل اکنون مکان هایی را برای بازدید در اروپا توصیه می کند.

درخواست فتوسنتز ELI5

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is when a plant uses sunlight to make energy. The plants use carbon dioxide and water to make sugar and oxygen. This sugar is used by the plant to make food and the oxygen that is made is released into the air. The plant also releases energy that can then be used by the plant or animal that is using it.

این مدل اکنون فتوسنتز را به زبان ساده‌تر توضیح می‌دهد.

توجه داشته باشید که برای اهداف نمایشی، این آموزش مدل را در زیر مجموعه کوچکی از مجموعه داده فقط برای یک دوره و با مقدار رتبه LoRA پایین تنظیم می‌کند. برای دریافت پاسخ‌های بهتر از مدل تنظیم‌شده، می‌توانید موارد زیر را آزمایش کنید:

  1. افزایش اندازه مجموعه داده تنظیم دقیق
  2. آموزش مراحل بیشتر (دوران)
  3. تنظیم یک رتبه LoRA بالاتر
  4. اصلاح مقادیر فراپارامتر مانند learning_rate و weight_decay .

خلاصه و مراحل بعدی

این آموزش تنظیم دقیق LoRA را بر روی یک مدل Gemma با استفاده از KerasNLP پوشش می دهد. در ادامه اسناد زیر را بررسی کنید: