Wnioskowanie za pomocą algorytmu RecurrentGemma za pomocą JAX i Flax

Wyświetl na ai.google.dev Uruchom w Google Colab Otwórz w Vertex AI Wyświetl źródło w GitHubie

Ten samouczek pokazuje, jak przeprowadzać podstawowe próbkowanie/wnioskowanie za pomocą modelu RecurrentGemma 2B Instruct przy użyciu biblioteki Google DeepMind recurrentgemma, która została napisana za pomocą JAX (zaawansowanej biblioteki liczbowej{/1), Flax (biblioteki sieci neuronowej opartej na JAX), OrbaxSentencePiece Choć nie jest on używany bezpośrednio w tym notatniku, wykorzystano go do stworzenia Gemma i RecurrentGemma (modelu Griffina).

Ten notatnik może działać w Google Colab z GPU T4 (kliknij Edytuj > Ustawienia notatnika > w sekcji Akcelerator sprzętowy wybierz GPU T4).

Konfiguracja

W poniższych sekcjach opisano kroki przygotowywania notatnika do użycia modelu RecurrentGemma, w tym dostęp do modelu, uzyskiwanie klucza interfejsu API i konfigurowanie środowiska wykonawczego notatnika

Konfigurowanie dostępu do Kaggle dla Gemma

Aby ukończyć ten samouczek, musisz najpierw wykonać instrukcje konfiguracji podobne do konfiguracji Gemma, z kilkoma wyjątkami:

  • Uzyskaj dostęp do RecurrentGemma (zamiast Gemma) na kaggle.com.
  • Wybierz środowisko wykonawcze Colab z wystarczającą ilością zasobów do uruchomienia modelu RecurrentGemma.
  • Wygeneruj i skonfiguruj nazwę użytkownika i klucz interfejsu API Kaggle.

Po zakończeniu konfiguracji RecurrentGemma przejdź do następnej sekcji, w której możesz ustawić zmienne środowiskowe dla środowiska Colab.

Ustawianie zmiennych środowiskowych

Ustaw zmienne środowiskowe dla interfejsów KAGGLE_USERNAME i KAGGLE_KEY. Kiedy pojawi się komunikat „Przyznać dostęp?”, Użytkownik wyraża zgodę na przyznanie tajnego dostępu.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Zainstaluj bibliotekę recurrentgemma

Ten notatnik dotyczy bezpłatnego GPU w Colab. Aby włączyć akcelerację sprzętową, kliknij Edytuj. Ustawienia notatnika > Wybierz T4 GPU > Zapisz.

Następnie musisz zainstalować bibliotekę Google DeepMind recurrentgemma ze strony github.com/google-deepmind/recurrentgemma. Jeśli pojawi się błąd dotyczący resolvera zależności pip, zwykle możesz go zignorować.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

Wczytaj i przygotuj model RecurrentGemma

  1. Wczytaj model RecurrentGemma za pomocą parametru kagglehub.model_download, który przyjmuje 3 argumenty:
  • handle: uchwyt modelu z Kaggle
  • path: (opcjonalny ciąg znaków) ścieżka lokalna
  • force_download: (opcjonalna wartość logiczna) wymusza ponowne pobranie modelu.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Sprawdź lokalizację wag modelu i tokenizatora, a następnie ustaw zmienne ścieżki. Katalog tokenizera znajduje się w katalogu głównym, z którego został pobrany model, a wagi modelu – w podkatalogu. Na przykład:
  • Plik tokenizer.model będzie w lokalizacji /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1.
  • Punkt kontrolny modelu będzie w: /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Przeprowadź próbkowanie/wnioskowanie

  1. Wczytaj punkt kontrolny modelu RecurrentGemma za pomocą metody recurrentgemma.jax.load_parameters. Argument sharding ustawiony na "single_device" wczytuje wszystkie parametry modelu na jednym urządzeniu.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Wczytaj tokenizację modelu RecurrentGemma utworzony za pomocą sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Aby automatycznie wczytywać prawidłową konfigurację z punktu kontrolnego modelu RecurrentGemma, użyj narzędzia recurrentgemma.GriffinConfig.from_flax_params_or_variables. Następnie utwórz instancję modelu Griffin za pomocą recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Utwórz obiekt sampler z recurrentgemma.jax.Sampler, dodając do punktu kontrolnego/wagi modelu RecurrentGemma i mechanizmu tokenizacji:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Wpisz prompt w języku prompt i wykonaj wnioskowanie. Możesz dostosować total_generation_steps (liczbę kroków wykonanych podczas generowania odpowiedzi – w tym przykładzie użyto 50 do zachowania pamięci hosta).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

Więcej informacji