Inferenza con RecurrentGemma utilizzando JAX e Flax

Visualizza su ai.google.dev Esegui in Google Colab Apri in Vertex AI Visualizza il codice sorgente su GitHub

Questo tutorial dimostra come eseguire campionamenti/inferenze di base con il modello 2B Instruct di RecurrentGemma utilizzando la libreria recurrentgemma di Google DeepMind scritta con JAX (una libreria di calcolo numerico ad alte prestazioni), Flax (la libreria di rete neurale basata su JAX), Orbax (una libreria basata su JAX per l'addestramento e utilità {11token1izer}per l'addestramento).SentencePiece Sebbene Flax non sia utilizzato direttamente in questo blocco note, Flax è stato utilizzato per creare Gemma e RecurrentGemma (il modello Griffin).

Questo blocco note può essere eseguito su Google Colab con la GPU T4 (vai a Modifica > Impostazioni blocco note > nella sezione Acceleratore hardware seleziona GPU T4).

Configurazione

Le sezioni seguenti spiegano i passaggi per preparare un blocco note all'utilizzo di un modello RecurrentGemma, inclusi l'accesso al modello, l'ottenimento di una chiave API e la configurazione del runtime del blocco note

Configura l'accesso a Kaggle per Gemma

Per completare questo tutorial, devi prima seguire le istruzioni di configurazione simili alla configurazione di Gemma con alcune eccezioni:

  • Ottieni l'accesso a RecurrentGemma (invece di Gemma) su kaggle.com.
  • Seleziona un runtime Colab con risorse sufficienti per eseguire il modello RecurrentGemma.
  • Genera e configura un nome utente e una chiave API Kaggle.

Dopo aver completato la configurazione di RecurrentGemma, passa alla sezione successiva, in cui imposterai le variabili di ambiente per il tuo ambiente Colab.

Imposta le variabili di ambiente

Imposta le variabili di ambiente per KAGGLE_USERNAME e KAGGLE_KEY. Quando viene visualizzata la richiesta "Vuoi concedere l'accesso?" messaggi, accetti di fornire l'accesso al secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installa la libreria recurrentgemma

Questo blocco note è incentrato sull'utilizzo di una GPU Colab senza costi. Per attivare l'accelerazione hardware, fai clic su Modifica > Impostazioni blocco note > Seleziona GPU T4 > Salva.

Successivamente, devi installare la libreria Google DeepMind recurrentgemma da github.com/google-deepmind/recurrentgemma. Se ricevi un errore relativo al " resolver di dipendenze di pip", in genere puoi ignorarlo.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

Carica e prepara il modello RecurrentGemma

  1. Carica il modello RecurrentGemma con kagglehub.model_download, che accetta tre argomenti:
  • handle: l'handle del modello di Kaggle
  • path: (stringa facoltativa) il percorso locale
  • force_download: (booleano facoltativo) forza a scaricare di nuovo il modello
di Gemini Advanced.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Controlla la posizione dei pesi del modello e del tokenizzatore, quindi imposta le variabili di percorso. La directory del tokenizzatore si troverà nella directory principale in cui hai scaricato il modello, mentre i pesi del modello saranno in una sottodirectory. Ad esempio:
  • Il file tokenizer.model sarà in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1).
  • il checkpoint del modello sarà in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Eseguire campionamento/inferenza

  1. Carica il checkpoint del modello RecurrentGemma con il metodo recurrentgemma.jax.load_parameters. L'argomento sharding impostato su "single_device" carica tutti i parametri del modello su un singolo dispositivo.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Carica il tokenizzatore del modello RecurrentGemma, creato utilizzando sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Per caricare automaticamente la configurazione corretta dal checkpoint del modello RecurrentGemma, utilizza recurrentgemma.GriffinConfig.from_flax_params_or_variables. Quindi, crea un'istanza del modello Griffin con recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Crea un sampler con recurrentgemma.jax.Sampler sopra il checkpoint/le ponderazioni del modello RecurrentGemma e il tokenizzatore:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Scrivi un prompt in prompt ed esegui l'inferenza. Puoi modificare total_generation_steps (il numero di passaggi eseguiti durante la generazione di una risposta; questo esempio utilizza 50 per preservare la memoria dell'host).
di Gemini Advanced.
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

Scopri di più