Przeprowadź wnioskowanie z użyciem Gemmy przy użyciu Keras

Wyświetl na ai.google.dev Uruchom w Google Colab Otwórz w Vertex AI Wyświetl źródło w GitHubie

Z tego samouczka dowiesz się, jak używać Gemma z KerasNLP do uruchamiania wnioskowania i generowania tekstu. Gemma to rodzina lekkich, nowoczesnych modeli otwartych opartych na tych samych badaniach i technologii, które posłużyły do utworzenia modeli Gemini. KerasNLP to zbiór modeli przetwarzania języka naturalnego (NLP) zaimplementowanych w Keras i możliwych do uruchomienia w językach JAX, PyTorch i TensorFlow.

W tym samouczku przy użyciu aplikacji Gemma wygenerujesz odpowiedzi tekstowe na kilka promptów. Jeśli dopiero zaczynasz korzystać z Kera, możesz zapoznać się z artykułem Pierwsze kroki z Keras, ale nie musisz tego robić. Więcej informacji o Keras znajdziesz w tym samouczku.

Konfiguracja

Konfiguracja Gemma

Aby ukończyć ten samouczek, musisz najpierw wykonać instrukcje konfiguracji opisane na stronie konfiguracji Gemma. Z instrukcji konfiguracji Gemma dowiesz się, jak:

  • Uzyskaj dostęp do Gemmy na kaggle.com.
  • Wybierz środowisko wykonawcze Colab z wystarczającymi zasobami do uruchomienia modelu Gemma 2B.
  • Wygeneruj i skonfiguruj nazwę użytkownika i klucz interfejsu API Kaggle.

Po zakończeniu konfiguracji Gemma przejdź do następnej sekcji, w której możesz ustawić zmienne środowiskowe dla środowiska Colab.

Ustawianie zmiennych środowiskowych

Ustaw zmienne środowiskowe dla interfejsów KAGGLE_USERNAME i KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Instalowanie zależności

Zainstaluj Keras i KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

Wybierz backend

Keras to wysokopoziomowy, wieloramowy interfejs API deep learning, który został zaprojektowany z myślą o łatwości obsługi. Keras 3 pozwala wybrać backend: TensorFlow, JAX lub PyTorch. W tym samouczku będą działać wszystkie 3 metody.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Importuj pakiety

Importuj Keras i KerasNLP.

import keras
import keras_nlp

Tworzenie modelu

KerasNLP udostępnia implementacje wielu popularnych architektur modeli. W tym samouczku utworzysz model przy użyciu GemmaCausalLM – kompleksowego modelu Gemma do modelowania przyczynowo-językowego. Przypadkowy model językowy przewiduje kolejny token na podstawie poprzednich tokenów.

Utwórz model za pomocą metody from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

Funkcja GemmaCausalLM.from_preset() tworzy instancję modelu na podstawie gotowej architektury i wag. W powyższym kodzie ciąg "gemma_2b_en" określa gotowe ustawienie modelu Gemma 2B z 2 miliardami parametrów. Dostępne są również modele Gemma z parametrami 7B, 9B i 27B. Ciągi kodu dla modeli Gemma znajdziesz na stronie Odmiana modelu na stronie kaggle.com.

Użyj pola summary, aby uzyskać więcej informacji o modelu:

gemma_lm.summary()

Jak widać z podsumowania, model ma 2,5 mld parametrów z możliwością trenowania.

Generowanie tekstu

Teraz czas wygenerować tekst. Model zawiera metodę generate, która generuje tekst na podstawie promptu. Opcjonalny argument max_length określa maksymalną długość wygenerowanej sekwencji.

Wypróbuj ją, używając prompta "What is the meaning of life?".

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Spróbuj zadzwonić jeszcze raz pod numer generate, podając inny prompt.

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

Jeśli korzystasz z backendów JAX lub TensorFlow, możesz zauważyć, że drugie wywołanie generate zwraca niemal natychmiast. Dzieje się tak, ponieważ każde wywołanie funkcji generate dla danego rozmiaru wsadu i max_length jest skompilowane z XLA. Pierwsze uruchomienie jest drogie, ale kolejne są znacznie szybsze.

Możesz też podać prompty zbiorcze, używając listy jako danych wejściowych:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

Opcjonalnie: użyj innego fragmentu kodu

Możesz kontrolować strategię generowania pliku GemmaCausalLM, ustawiając argument sampler w funkcji compile(). Domyślnie używane jest próbkowanie "greedy".

W ramach eksperymentu spróbuj ustawić strategię "top_k":

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

Domyślny algorytm zachłanności zawsze wybiera token o największym prawdopodobieństwie, a algorytm Top-K losowo wybiera token z tokenów o najwyższym prawdopodobieństwie K.

Nie musisz określać przykładowego fragmentu kodu. Możesz też zignorować ostatni fragment kodu, jeśli nie jest on przydatny w Twoim przypadku. Jeśli chcesz dowiedzieć się więcej o dostępnych fragmentach, przeczytaj artykuł Sample.

Co dalej

W tym samouczku omówiliśmy, jak wygenerować tekst przy użyciu KerasNLP i Gemma. Oto kilka sugestii, o których warto się dowiedzieć: