Executar inferências com o Gemma usando o Keras

Ver em ai.google.dev Executar no Google Colab Abrir na Vertex AI Veja o código-fonte no GitHub

Neste tutorial, mostramos como usar o Gemma com o KerasNLP (link em inglês) para executar inferências e gerar texto. O Gemma é uma família de modelos abertos, leves e modernos, criados com a mesma pesquisa e tecnologia usada para criar os modelos do Gemini. O KerasNLP é uma coleção de modelos de processamento de linguagem natural (PLN) implementados no Keras e executáveis no JAX, no PyTorch e no TensorFlow.

Neste tutorial, você vai usar o Gemma para gerar respostas de texto a vários comandos. Se você é iniciante no Keras, talvez queira ler Introdução ao Keras antes de começar, mas não é necessário. Neste tutorial, você vai saber mais sobre a Keras.

Configuração

Configuração do Gemma

Para concluir este tutorial, primeiro você precisa concluir as instruções de configuração na configuração do Gemma. As instruções de configuração do Gemma mostram como fazer o seguinte:

  • Acesse o Gemma em kaggle.com.
  • Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo Gemma 2B.
  • Gere e configure um nome de usuário e uma chave de API do Kaggle.

Depois de concluir a configuração do Gemma, vá para a próxima seção, em que você definirá variáveis de ambiente para o ambiente do Colab.

Defina as variáveis de ambiente

Defina as variáveis de ambiente para KAGGLE_USERNAME e KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Instalar dependências

Instalar o Keras e o KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

Selecione um back-end

Keras é uma API de aprendizado profundo de alto nível e com vários frameworks projetadas para simplicidade e facilidade de uso. A Keras 3 permite que você escolha o back-end: TensorFlow, JAX ou PyTorch. Os três funcionam para este tutorial.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Importar pacotes

Importe o Keras e o KerasNLP.

import keras
import keras_nlp

crie um modelo

O KerasNLP oferece implementações de várias arquiteturas de modelo conhecidas. Neste tutorial, você criará um modelo usando GemmaCausalLM, um modelo Gemma completo para modelagem de linguagem causal. Um modelo de linguagem causal prevê o próximo token com base em tokens anteriores.

Crie o modelo usando o método from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

A função GemmaCausalLM.from_preset() instancia o modelo usando uma arquitetura e pesos predefinidos. No código acima, a string "gemma_2b_en" especifica a predefinição do modelo Gemma 2B com dois bilhões de parâmetros. Os modelos Gemma com parâmetros 7B, 9B e 27B também estão disponíveis. Você pode encontrar as strings de código para modelos do Gemma nas listagens de Variações de modelos em kaggle.com (links em inglês).

Use summary para ter mais informações sobre o modelo:

gemma_lm.summary()

O modelo tem 2, 5 bilhões de parâmetros treináveis.

Gere o texto

Agora é hora de gerar texto. O modelo tem um método generate que gera texto com base em um comando. O argumento max_length opcional especifica o comprimento máximo da sequência gerada.

Teste com o comando "What is the meaning of life?".

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Tente chamar generate de novo com uma solicitação diferente.

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

Se você estiver executando em back-ends do JAX ou do TensorFlow, vai perceber que a segunda chamada generate retorna quase instantaneamente. Isso ocorre porque cada chamada para generate para um determinado tamanho de lote e max_length é compilado com XLA. A primeira execução é cara, mas as execuções subsequentes são muito mais rápidas.

Também é possível enviar comandos em lote usando uma lista como entrada:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

Opcional: usar outra amostra

Você pode controlar a estratégia de geração de GemmaCausalLM definindo o argumento sampler no compile(). Por padrão, a amostragem "greedy" será usada.

Como um experimento, tente definir uma estratégia "top_k":

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

Embora o algoritmo ganancioso padrão sempre escolha o token com a maior probabilidade, o algoritmo top-K escolhe aleatoriamente o próximo token dos tokens de probabilidade Top-K.

Não é necessário especificar um sampler. Você pode ignorar o último snippet de código se ele não for útil para seu caso de uso. Para saber mais sobre os Samplers disponíveis, consulte Samplers.

A seguir

Neste tutorial, você aprendeu a gerar texto usando o KerasNLP e o Gemma. Aqui estão algumas sugestões sobre o que aprender a seguir: