Visualizza su ai.google.dev | Esegui in Google Colab | Apri in Vertex AI | Visualizza il codice sorgente su GitHub |
Questo tutorial dimostra come eseguire campionamenti/inferenze di base con il modello 2B Instruct di RecurrentGemma utilizzando la libreria recurrentgemma
di Google DeepMind scritta con JAX (una libreria di calcolo numerico ad alte prestazioni), Flax (la libreria di rete neurale basata su JAX), Orbax (una libreria basata su JAX per l'addestramento e utilità {11token1izer}per l'addestramento).SentencePiece Sebbene Flax non sia utilizzato direttamente in questo blocco note, Flax è stato utilizzato per creare Gemma e RecurrentGemma (il modello Griffin).
Questo blocco note può essere eseguito su Google Colab con la GPU T4 (vai a Modifica > Impostazioni blocco note > nella sezione Acceleratore hardware seleziona GPU T4).
Configurazione
Le sezioni seguenti spiegano i passaggi per preparare un blocco note all'utilizzo di un modello RecurrentGemma, inclusi l'accesso al modello, l'ottenimento di una chiave API e la configurazione del runtime del blocco note
Configura l'accesso a Kaggle per Gemma
Per completare questo tutorial, devi prima seguire le istruzioni di configurazione simili alla configurazione di Gemma con alcune eccezioni:
- Ottieni l'accesso a RecurrentGemma (invece di Gemma) su kaggle.com.
- Seleziona un runtime Colab con risorse sufficienti per eseguire il modello RecurrentGemma.
- Genera e configura un nome utente e una chiave API Kaggle.
Dopo aver completato la configurazione di RecurrentGemma, passa alla sezione successiva, in cui imposterai le variabili di ambiente per il tuo ambiente Colab.
Imposta le variabili di ambiente
Imposta le variabili di ambiente per KAGGLE_USERNAME
e KAGGLE_KEY
. Quando viene visualizzata la richiesta "Vuoi concedere l'accesso?" messaggi, accetti di fornire l'accesso al secret.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Installa la libreria recurrentgemma
Questo blocco note è incentrato sull'utilizzo di una GPU Colab senza costi. Per attivare l'accelerazione hardware, fai clic su Modifica > Impostazioni blocco note > Seleziona GPU T4 > Salva.
Successivamente, devi installare la libreria Google DeepMind recurrentgemma
da github.com/google-deepmind/recurrentgemma
. Se ricevi un errore relativo al " resolver di dipendenze di pip", in genere puoi ignorarlo.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
Carica e prepara il modello RecurrentGemma
- Carica il modello RecurrentGemma con
kagglehub.model_download
, che accetta tre argomenti:
handle
: l'handle del modello di Kagglepath
: (stringa facoltativa) il percorso localeforce_download
: (booleano facoltativo) forza a scaricare di nuovo il modello
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Controlla la posizione dei pesi del modello e del tokenizzatore, quindi imposta le variabili di percorso. La directory del tokenizzatore si troverà nella directory principale in cui hai scaricato il modello, mentre i pesi del modello saranno in una sottodirectory. Ad esempio:
- Il file
tokenizer.model
sarà in/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
). - il checkpoint del modello sarà in
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Eseguire campionamento/inferenza
- Carica il checkpoint del modello RecurrentGemma con il metodo
recurrentgemma.jax.load_parameters
. L'argomentosharding
impostato su"single_device"
carica tutti i parametri del modello su un singolo dispositivo.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Carica il tokenizzatore del modello RecurrentGemma, creato utilizzando
sentencepiece.SentencePieceProcessor
:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Per caricare automaticamente la configurazione corretta dal checkpoint del modello RecurrentGemma, utilizza
recurrentgemma.GriffinConfig.from_flax_params_or_variables
. Quindi, crea un'istanza del modello Griffin conrecurrentgemma.jax.Griffin
.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Crea un
sampler
conrecurrentgemma.jax.Sampler
sopra il checkpoint/le ponderazioni del modello RecurrentGemma e il tokenizzatore:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Scrivi un prompt in
prompt
ed esegui l'inferenza. Puoi modificaretotal_generation_steps
(il numero di passaggi eseguiti durante la generazione di una risposta; questo esempio utilizza50
per preservare la memoria dell'host).
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
Scopri di più
- Puoi scoprire di più sulla libreria
recurrentgemma
di Google DeepMind su GitHub, che contiene docstring dei metodi e dei moduli che hai utilizzato in questo tutorial, ad esempiorecurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
erecurrentgemma.jax.Sampler
. - Le seguenti librerie dispongono di siti di documentazione proprietari: JAX di base, Flax e Orbax.
- Per la documentazione relativa al tokenizzatore/detokenizzatore
sentencepiece
, consulta il repository GitHubsentencepiece
di Google. - Per la documentazione relativa a
kagglehub
, dai un'occhiata aREADME.md
nel repository GitHubkagglehub
di Kaggle. - Scopri come utilizzare i modelli Gemma con Vertex AI di Google Cloud.
- Dai un'occhiata a RecurrentGemma: Moving Past Transformers for Efficient Open Language Models di Google DeepMind.
- Leggi il documento Griffin: mixare le ricorrenza lineari recintate con Documento Local Attention for Efficient Language Models di GoogleDeepMind per saperne di più sull'architettura del modello utilizzata da RecurrentGemma.