Wyświetl na stronie ai.google.dev | Uruchom w Google Colab | Otwórz w Vertex AI | Wyświetl źródło w GitHubie |
Omówienie
Gemma to rodzina lekkich, nowoczesnych modeli otwartych opartych na tych samych badaniach i technologii, które posłużyły do utworzenia modeli Gemini.
Wykazano, że duże modele językowe (LLM), takie jak Gemma, są skuteczne w różnych zadaniach NLP. Najpierw duży model językowy jest wstępnie trenowany z wykorzystaniem dużego zbioru tekstów w trybie samokontroli. Dzięki trenowaniu wstępnemu LLM mogą zdobywać wiedzę ogólnego przeznaczenia, na przykład związane z zależnościami statystycznymi między słowami. Następnie można dostroić LLM na podstawie danych z konkretnej domeny w celu wykonania zadań niższego szczebla (takich jak analiza nastawienia).
Duże modele językowe są bardzo duże (parametry rzędu miliardów). Pełne dostrajanie (które aktualizuje wszystkie parametry modelu) nie jest wymagane w przypadku większości aplikacji, ponieważ typowe zbiory danych do dostrajania są stosunkowo znacznie mniejsze niż zbiory danych używane przed trenowaniem.
Adaptacja niskiego rzędu (LoRA) to technika dostrojenia, która znacznie zmniejsza liczbę parametrów do trenowania w przypadku zadań dalszych, przez zamrożenie wag modelu i wstawienie do niego mniejszej liczby nowych wag. Dzięki temu trenowanie z użyciem LoRA jest znacznie szybsze i bardziej oszczędza pamięć, a modele mają mniejsze waga (kilkaset MB), a wszystko to przy zachowaniu jakości danych wyjściowych modelu.
W tym samouczku dowiesz się, jak za pomocą KerasNLP przeprowadzić dostrojenie modelu Gemma 2B na potrzeby LoRA przy użyciu zbioru danych Databricks Dolly 15k. Ten zbiór danych zawiera 15 000 wysokiej jakości par promptów / odpowiedzi wygenerowanych przez człowieka i zaprojektowanych specjalnie do dostrajania LLM.
Konfiguracja
Uzyskiwanie dostępu do Gemma
Aby wykonać ten samouczek, musisz najpierw wykonać instrukcje konfiguracji dostępne na stronie Konfiguracja Gemma. Instrukcje konfiguracji Gemma pokazują, jak:
- Uzyskaj dostęp do Gemma na stronie kaggle.com.
- Wybierz środowisko Colab z wystarczającymi zasobami do uruchomienia modelu Gemma 2B.
- Wygeneruj i skonfiguruj nazwę użytkownika i klucz interfejsu API Kaggle.
Po zakończeniu konfiguracji Gemma przejdź do następnej sekcji, w której możesz ustawić zmienne środowiskowe dla środowiska Colab.
Wybierz środowisko wykonawcze
Aby wykonać ten samouczek, musisz mieć środowisko Colab z wystarczającymi zasobami do uruchomienia modelu Gemma. W takim przypadku możesz użyć karty graficznej T4:
- W prawym górnym rogu okna Colab kliknij ▾ (Dodatkowe opcje połączenia).
- Wybierz Zmień typ środowiska wykonawczego.
- W sekcji Akcelerator sprzętowy wybierz GPU T4.
Konfigurowanie klucza interfejsu API
Aby korzystać z Gemma, musisz podać nazwę użytkownika i klucz API Kaggle.
Aby wygenerować klucz interfejsu API Kaggle, otwórz kartę Account (Konto) w profilu użytkownika Kaggle i wybierz Create New Token (Utwórz nowy token). Spowoduje to pobranie pliku kaggle.json
zawierającego Twoje dane logowania do interfejsu API.
W Colab w panelu po lewej stronie wybierz Obiekty tajne (🔑) i dodaj nazwę użytkownika i klucz interfejsu API Kaggle. Zapisz nazwę użytkownika pod nazwą KAGGLE_USERNAME
, a klucz API pod nazwą KAGGLE_KEY
.
Ustawianie zmiennych środowiskowych
Ustaw zmienne środowiskowe KAGGLE_USERNAME
i KAGGLE_KEY
.
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Instalowanie zależności
Zainstaluj Keras, KerasNLP i inne zależności.
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"
Wybierz backend
Keras to interfejs API głębokiego uczenia się na wysokim poziomie, który obsługuje wiele frameworków i został zaprojektowany z myślą o prostocie i łatwości użycia. Keras 3 pozwala uruchamiać przepływy pracy w jednym z 3 backendów: TensorFlow, JAX lub PyTorch.
W tym samouczku skonfiguruj backend dla JAX.
os.environ["KERAS_BACKEND"] = "jax" # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
Importowanie pakietów
Importuj Keras i KerasNLP.
import keras
import keras_nlp
Wczytywanie zbioru danych
wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39-- https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ... Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected. HTTP request sent, awaiting response... 302 Found Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following] --2024-07-31 01:56:39-- https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ... Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 13085339 (12M) [text/plain] Saving to: ‘databricks-dolly-15k.jsonl’ databricks-dolly-15 100%[===================>] 12.48M 73.7MB/s in 0.2s 2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]
Przeprowadź wstępną obróbkę danych. W tym samouczku wykorzystano podzbiór 1000 przykładów treningowych do szybszego wykonywania notatnika. Aby uzyskać wyższą jakość, rozważ użycie większej ilości danych treningowych.
import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
for line in file:
features = json.loads(line)
# Filter out examples with context, to keep it simple.
if features["context"]:
continue
# Format the entire example as a single string.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
data.append(template.format(**features))
# Only use 1000 training examples, to keep it fast.
data = data[:1000]
Wczytaj model
KerasNLP udostępnia implementacje wielu popularnych architektur modeli. W tym samouczku utworzysz model przy użyciu GemmaCausalLM
– kompleksowego modelu Gemma do modelowania przyczynowo-językowego. Model językowy oparty na przyczynowości przewiduje następny token na podstawie poprzednich tokenów.
Utwórz model za pomocą metody from_preset
:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()
Metoda from_preset
tworzy instancję modelu na podstawie gotowej architektury i wag. W powyższym kodzie ciąg „gemma2_2b_pl” określa wstępnie zdefiniowaną architekturę – model Gemma z 2 miliardami parametrów.
Wnioskowanie przed dostrajaniem
W tej sekcji prześlesz do modelu różne prompty, aby sprawdzić, jak na nie reaguje.
Monit o podróż po Europie
Wysłać zapytanie do modelu, aby uzyskać sugestie dotyczące tego, co robić podczas podróży do Europy.
prompt = template.format(
instruction="What should I do on a trip to Europe?",
response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction: What should I do on a trip to Europe? Response: If you have any special needs, you should contact the embassy of the country that you are visiting. You should contact the embassy of the country that I will be visiting. What are my responsibilities when I go on a trip? Response: If you are going to Europe, you should make sure to bring all of your documents. If you are going to Europe, make sure that you have all of your documents. When do you travel abroad? Response: The most common reason to travel abroad is to go to school or work. The most common reason to travel abroad is to work. How can I get a visa to Europe? Response: If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy. If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy. When should I go to Europe? Response: You should go to Europe when the weather is nice. You should go to Europe when the weather is bad. How can I make a reservation for a trip?
W odpowiedzi model podaje ogólne wskazówki dotyczące planowania podróży.
ELI5 Photosynthesis Prompt
Poproś model, aby wyjaśnił fotosyntezę w sposób zrozumiały dla 5-letniego dziecka.
prompt = template.format(
instruction="Explain the process of photosynthesis in a way that a child could understand.",
response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction: Explain the process of photosynthesis in a way that a child could understand. Response: Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis. Instruction: What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration? Response: The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide. Instruction: Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration. Response: Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen. Instruction: How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell? Response: Photosynthesis occurs in the cells of a plant. The purpose of
Odpowiedź modelu zawiera słowa, które mogą być trudne do zrozumienia dla dziecka, takie jak chlorofil.
Dostrajanie LoRA
Aby uzyskać lepsze odpowiedzi od modelu, dostosuj go za pomocą metody Low Rank Adaptation (LoRA) na podstawie zbioru danych Dolly 15k z Databricks.
Rząd LoRA określa wymiar macierzy do trenowania, które są dodawane do oryginalnych wag LLM. Para ta kontroluje wyrazistość i dokładność dostosowania.
Wyższy poziom oznacza, że możliwe są bardziej szczegółowe zmiany, ale też więcej parametrów z możliwością trenowania. Niższy ranking oznacza mniejsze obciążenie obliczeniowe, ale potencjalnie mniej precyzyjną adaptację.
W tym samouczku stosowana jest ranga 4 LoRA. W praktyce zacznij od stosunkowo niskiego poziomu (np. 4, 8, 16). Jest to wydajne pod względem obliczeniowym rozwiązanie na potrzeby eksperymentowania. Wytrenowuj model z użyciem tego rankingu i oceń, czy poprawia się jego skuteczność w przypadku Twojego zadania. Stopniowo podnoś pozycję w rankingu w kolejnych testach i sprawdź, czy to da jeszcze lepsze wyniki.
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
Pamiętaj, że włączenie LoRA znacznie zmniejsza liczbę parametrów, które można trenować (z 2,6 mld do 2,9 mln).
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251 <keras.src.callbacks.history.History at 0x799d04393c40>
Uwaga na temat mieszanej precyzji dostrajania procesorów graficznych NVIDIA
Do dokładnego dostosowania zalecamy ustawienie pełnej dokładności. Podczas dopracowywania na kartach graficznych NVIDIA możesz używać mieszanej precyzji (keras.mixed_precision.set_global_policy('mixed_bfloat16')
), aby przyspieszyć szkolenie przy minimalnym wpływie na jakość. Dokładne dostrojenie z użyciem mieszanej precyzji zużywa więcej pamięci, dlatego jest przydatne tylko na większych GPU.
W przypadku wnioskowania wystarczy precyzja połowiczna (keras.config.set_floatx("bfloat16")
), która zaoszczędzi pamięć, a precyzja mieszana nie będzie odpowiednia.
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
Wykonywanie wnioskowania po dostrajaniu
Po dostrojeniu odpowiedzi będą zgodne z instrukcjami podanymi w prompcie.
Monit o podróż po Europie
prompt = template.format(
instruction="What should I do on a trip to Europe?",
response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction: What should I do on a trip to Europe? Response: When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.
Model poleca teraz miejsca w Europie, które warto odwiedzić.
ELI5 Photosynthesis Prompt
prompt = template.format(
instruction="Explain the process of photosynthesis in a way that a child could understand.",
response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction: Explain the process of photosynthesis in a way that a child could understand. Response: The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.
Model wyjaśnia teraz fotosyntezę w prostszy sposób.
Pamiętaj, że do celów demonstracyjnych ten samouczek dostraja model na małym podzbiorze zbioru danych pod kątem tylko 1 epoki i z niską wartością pozycji LoRA. Aby uzyskać lepsze odpowiedzi od dopracowanego modelu, możesz eksperymentować z:
- Zwiększanie rozmiaru zbioru danych dostrajania
- trenowanie przez więcej kroków (epok);
- Ustawianie wyższej pozycji w LoRA
- Modyfikacja wartości hiperparametrów, np.
learning_rate
iweight_decay
.
Podsumowanie i dalsze kroki
W tym samouczku omawialiśmy dostrajanie LoRA w modelu Gemma za pomocą KerasNLP. Zapoznaj się z tymi dokumentami:
- Dowiedz się, jak generować tekst za pomocą modelu Gemma.
- Dowiedz się, jak przeprowadzić rozproszony dobór i wykonywanie wnioskowania na modelu Gemma.
- Dowiedz się, jak korzystać z otwartych modeli Gemma w Vertex AI.
- Dowiedz się, jak dostrajać Gemma za pomocą KerasNLP i wdrażać model w Vertex AI.