Esegui le inferenze con Gemma utilizzando Keras

Visualizza su ai.google.dev Esegui in Google Colab Apri in Vertex AI Visualizza il codice sorgente su GitHub

Questo tutorial mostra come utilizzare Gemma con KerasNLP per eseguire inferenze e generare testo. Gemma è una famiglia di modelli aperti leggeri e all'avanguardia basati sulla stessa ricerca e tecnologia utilizzate per creare i modelli Gemini. KerasNLP è una raccolta di modelli di elaborazione del linguaggio naturale (NLP) implementati in Keras ed eseguibili su JAX, PyTorch e TensorFlow.

In questo tutorial utilizzerai Gemma per generare risposte testuali a diversi prompt. Se sei un nuovo utente di Keras, potresti voler leggere la Guida introduttiva a Keras prima di iniziare, ma non è obbligatorio. Imparerai di più su Keras mentre lavori in questo tutorial.

Imposta

Configurazione di Gemma

Per completare questo tutorial, devi prima completare le istruzioni di configurazione nella pagina di configurazione di Gemma. Le istruzioni di configurazione di Gemma mostrano come fare:

  • Accedi a Gemma su kaggle.com.
  • Seleziona un runtime Colab con risorse sufficienti per eseguire il modello Gemma 2B.
  • Genera e configura un nome utente e una chiave API Kaggle.

Dopo aver completato la configurazione di Gemma, passa alla sezione successiva, in cui imposterai le variabili di ambiente per il tuo ambiente Colab.

Imposta le variabili di ambiente

Imposta le variabili di ambiente per KAGGLE_USERNAME e KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installa le dipendenze

Installare Keras e KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

Seleziona un servizio di backend

Keras è un'API di deep learning multi-framework di alto livello progettata per la semplicità e la facilità d'uso. Keras 3 ti consente di scegliere il backend: TensorFlow, JAX o PyTorch. Per questo tutorial sono validi tutti e tre.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Importa pacchetti

Importare Keras e KerasNLP.

import keras
import keras_nlp

Creare un modello

KerasNLP fornisce implementazioni di molte architetture di modelli popolari. In questo tutorial, creerai un modello utilizzando GemmaCausalLM, un modello Gemma end-to-end per la creazione di modelli linguistici causali. Un modello linguistico causale prevede il token successivo in base a quelli precedenti.

Crea il modello utilizzando il metodo from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

La funzione GemmaCausalLM.from_preset() crea un'istanza del modello da un'architettura e pesi predefiniti. Nel codice precedente, la stringa "gemma_2b_en" specifica il preset del modello Gemma 2B con 2 miliardi di parametri. Sono disponibili anche modelli Gemma con parametri 7B, 9B e 27B. Puoi trovare le stringhe di codice per i modelli Gemma negli elenchi Varianti del modello su kaggle.com.

Usa summary per avere maggiori informazioni sul modello:

gemma_lm.summary()

Come puoi vedere dal riepilogo, il modello ha 2,5 miliardi di parametri addestrabili.

Genera testo

Ora è il momento di generare del testo. Il modello ha un metodo generate che genera testo in base a un prompt. L'argomento facoltativo max_length specifica la lunghezza massima della sequenza generata.

Prova con il prompt "What is the meaning of life?".

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Prova a chiamare di nuovo generate con un prompt diverso.

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

Se sono in esecuzione su backend JAX o TensorFlow, noterai che la seconda chiamata a generate viene restituita quasi all'istante. Questo perché ogni chiamata a generate per una determinata dimensione del batch e max_length è compilata con XLA. La prima esecuzione è costosa, ma le esecuzioni successive sono molto più veloci.

Puoi anche fornire prompt in batch utilizzando un elenco come input:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

(Facoltativo) Prova con un'anteprima diversa

Puoi controllare la strategia di generazione per GemmaCausalLM impostando l'argomento sampler su compile(). Per impostazione predefinita, verrà utilizzato il campionamento di "greedy".

Come esperimento, prova a impostare una strategia "top_k":

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

Mentre l'algoritmo greedy predefinito sceglie sempre il token con la probabilità più alta, l'algoritmo top-K sceglie in modo casuale il token successivo tra i token con la probabilità di top-K.

Non devi specificare un campionatore e puoi ignorare l'ultimo snippet di codice se non è utile per il tuo caso d'uso. Per saperne di più sui Samplers disponibili, visita la pagina Samplers.

Passaggi successivi

In questo tutorial hai imparato a generare testo utilizzando KerasNLP e Gemma. Ecco alcuni suggerimenti su cosa imparare in seguito: