Inferenza con Gemma utilizzando JAX e Lino

Visualizza su ai.google.dev Esegui in Google Colab Apri in Vertex AI Visualizza il codice sorgente su GitHub

Panoramica

Gemma è una famiglia di modelli linguistici di grandi dimensioni aperti, leggeri e all'avanguardia, basati sulla ricerca e sulla tecnologia di Google DeepMind Gemini. Questo tutorial mostra come eseguire il campionamento/l'inferenza di base con il modello Gemma 2B Instruct utilizzando la libreria gemma di Google DeepMind scritta con JAX (una libreria di calcolo numerico ad alte prestazioni), Flax (la libreria di rete neurale basata su JAX), Orbax (una libreria basata su JAX per utilità di addestramento come il checkpointing) e SentencePiece Sebbene Lino non sia utilizzato direttamente in questo blocco note, è stato utilizzato per creare Gemma.

Questo blocco note può essere eseguito su Google Colab con GPU T4 senza costi (vai a Modifica > Impostazioni blocco note > nella sezione Acceleratore hardware seleziona GPU T4).

Configurazione

1. Configurare l'accesso di Kaggle per Gemma

Per completare questo tutorial, devi prima seguire le istruzioni di configurazione di Gemma che mostrano come effettuare le seguenti operazioni:

  • Accedi a Gemma su kaggle.com.
  • Seleziona un runtime Colab con risorse sufficienti per eseguire il modello Gemma.
  • Genera e configura un nome utente e una chiave API di Kaggle.

Dopo aver completato la configurazione di Gemma, passa alla sezione successiva, dove potrai impostare le variabili di ambiente per l'ambiente Colab.

2. Imposta le variabili di ambiente

Imposta le variabili di ambiente per KAGGLE_USERNAME e KAGGLE_KEY. Quando viene visualizzato il messaggio "Concedi l'accesso?", accetta di fornire l'accesso segreto.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Installa la libreria gemma

Questo blocco note è incentrato sull'utilizzo di una GPU Colab senza costi. Per attivare l'accelerazione hardware, fai clic su Modifica > Impostazioni blocco note > Seleziona GPU T4 > Salva.

Dopodiché devi installare la libreria Google DeepMind gemma da github.com/google-deepmind/gemma. Se viene visualizzato un errore relativo al " resolver di dipendenze di pip", generalmente puoi ignorarlo.

pip install -q git+https://github.com/google-deepmind/gemma.git

Carica e prepara il modello Gemma

  1. Carica il modello Gemma con kagglehub.model_download, che accetta tre argomenti:
  • handle: l'handle del modello di Kaggle
  • path: (stringa facoltativa) il percorso locale
  • force_download: (booleano facoltativo) costringe a scaricare di nuovo il modello
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. Controlla la posizione delle ponderazioni del modello e del tokenizzatore, quindi imposta le variabili del percorso. La directory del tokenizzatore si trova nella directory principale in cui hai scaricato il modello, mentre i pesi del modello si trovano in una sottodirectory. Ad esempio:
  • Il file tokenizer.model si troverà in /LOCAL/PATH/TO/gemma/flax/2b-it/2).
  • Il checkpoint del modello si trova in /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

Esecuzione di campionamento/inferenza

  1. Carica e formatta il checkpoint del modello Gemma con il metodo gemma.params.load_and_format_params:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. Carica il tokenizzatore Gemma, creato utilizzando sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Per caricare automaticamente la configurazione corretta dal checkpoint del modello Gemma, utilizza gemma.transformer.TransformerConfig. L'argomento cache_size è il numero di passaggi temporali nella cache Transformer di Gemma. Successivamente, crea l'istanza del modello Gemma come transformer con gemma.transformer.Transformer (che eredita da flax.linen.Module).
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. Crea un sampler con gemma.sampler.Sampler sopra il checkpoint/i pesi del modello Gemma e il tokenizzatore:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. Scrivi un prompt in input_batch ed esegui l'inferenza. Puoi modificare total_generation_steps (il numero di passaggi eseguiti durante la generazione di una risposta; in questo esempio viene utilizzato 100 per preservare la memoria host).
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (Facoltativo) Esegui questa cella per liberare memoria se hai completato il blocco note e vuoi provare un altro prompt. Successivamente, puoi creare di nuovo un'istanza di sampler nel passaggio 3 e personalizzare ed eseguire la richiesta nel passaggio 4.
del sampler

Scopri di più