Wyświetl na ai.google.dev | Uruchom w Google Colab | Otwórz w Vertex AI | Wyświetl źródło w GitHubie |
Ten samouczek pokazuje, jak przeprowadzać podstawowe próbkowanie/wnioskowanie za pomocą modelu RecurrentGemma 2B Instruct przy użyciu biblioteki Google DeepMind recurrentgemma
, która została napisana za pomocą JAX (zaawansowanej biblioteki liczbowej{/1), Flax (biblioteki sieci neuronowej opartej na JAX), OrbaxSentencePiece Choć nie jest on używany bezpośrednio w tym notatniku, wykorzystano go do stworzenia Gemma i RecurrentGemma (modelu Griffina).
Ten notatnik może działać w Google Colab z GPU T4 (kliknij Edytuj > Ustawienia notatnika > w sekcji Akcelerator sprzętowy wybierz GPU T4).
Konfiguracja
W poniższych sekcjach opisano kroki przygotowywania notatnika do użycia modelu RecurrentGemma, w tym dostęp do modelu, uzyskiwanie klucza interfejsu API i konfigurowanie środowiska wykonawczego notatnika
Konfigurowanie dostępu do Kaggle dla Gemma
Aby ukończyć ten samouczek, musisz najpierw wykonać instrukcje konfiguracji podobne do konfiguracji Gemma, z kilkoma wyjątkami:
- Uzyskaj dostęp do RecurrentGemma (zamiast Gemma) na kaggle.com.
- Wybierz środowisko wykonawcze Colab z wystarczającą ilością zasobów do uruchomienia modelu RecurrentGemma.
- Wygeneruj i skonfiguruj nazwę użytkownika i klucz interfejsu API Kaggle.
Po zakończeniu konfiguracji RecurrentGemma przejdź do następnej sekcji, w której możesz ustawić zmienne środowiskowe dla środowiska Colab.
Ustawianie zmiennych środowiskowych
Ustaw zmienne środowiskowe dla interfejsów KAGGLE_USERNAME
i KAGGLE_KEY
. Kiedy pojawi się komunikat „Przyznać dostęp?”, Użytkownik wyraża zgodę na przyznanie tajnego dostępu.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Zainstaluj bibliotekę recurrentgemma
Ten notatnik dotyczy bezpłatnego GPU w Colab. Aby włączyć akcelerację sprzętową, kliknij Edytuj. Ustawienia notatnika > Wybierz T4 GPU > Zapisz.
Następnie musisz zainstalować bibliotekę Google DeepMind recurrentgemma
ze strony github.com/google-deepmind/recurrentgemma
. Jeśli pojawi się błąd dotyczący resolvera zależności pip, zwykle możesz go zignorować.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
Wczytaj i przygotuj model RecurrentGemma
- Wczytaj model RecurrentGemma za pomocą parametru
kagglehub.model_download
, który przyjmuje 3 argumenty:
handle
: uchwyt modelu z Kagglepath
: (opcjonalny ciąg znaków) ścieżka lokalnaforce_download
: (opcjonalna wartość logiczna) wymusza ponowne pobranie modelu.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Sprawdź lokalizację wag modelu i tokenizatora, a następnie ustaw zmienne ścieżki. Katalog tokenizera znajduje się w katalogu głównym, z którego został pobrany model, a wagi modelu – w podkatalogu. Na przykład:
- Plik
tokenizer.model
będzie w lokalizacji/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
. - Punkt kontrolny modelu będzie w:
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Przeprowadź próbkowanie/wnioskowanie
- Wczytaj punkt kontrolny modelu RecurrentGemma za pomocą metody
recurrentgemma.jax.load_parameters
. Argumentsharding
ustawiony na"single_device"
wczytuje wszystkie parametry modelu na jednym urządzeniu.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Wczytaj tokenizację modelu RecurrentGemma utworzony za pomocą
sentencepiece.SentencePieceProcessor
:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Aby automatycznie wczytywać prawidłową konfigurację z punktu kontrolnego modelu RecurrentGemma, użyj narzędzia
recurrentgemma.GriffinConfig.from_flax_params_or_variables
. Następnie utwórz instancję modelu Griffin za pomocąrecurrentgemma.jax.Griffin
.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Utwórz obiekt
sampler
zrecurrentgemma.jax.Sampler
, dodając do punktu kontrolnego/wagi modelu RecurrentGemma i mechanizmu tokenizacji:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Wpisz prompt w języku
prompt
i wykonaj wnioskowanie. Możesz dostosowaćtotal_generation_steps
(liczbę kroków wykonanych podczas generowania odpowiedzi – w tym przykładzie użyto50
do zachowania pamięci hosta).
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
Więcej informacji
- Więcej informacji o bibliotece
recurrentgemma
Google DeepMind znajdziesz na GitHubie, która zawiera ciągi dokumentów z metodami i modułami użytymi w tym samouczku, między innymirecurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
irecurrentgemma.jax.Sampler
. - Te biblioteki mają własne witryny z dokumentacją: core JAX, Flax i Orbax.
- Dokumentację tokenizacji i detokenizatora usługi
sentencepiece
znajdziesz w repozytorium Google na GitHubiesentencepiece
. - Dokumentację usługi
kagglehub
znajdziesz w witrynieREADME.md
w repozytorium GitHubkagglehub
firmy Kaggle. - Dowiedz się, jak używać modeli Gemma w Vertex AI Google Cloud.
- Obejrzyj film RecurrentGemma: Transfer Transformers raportu Google DeepMind dotyczącego Efficient Open Language Models.
- Przeczytaj artykuł Griffin: miksowanie powtarzanych powtarzanych utworów z Dokument „Local Attention for Efficient Language Models” przygotowany przez GoogleDeepMind, aby dowiedzieć się więcej o architekturze modelu wykorzystywanej przez firmę RecurrentGemma.