Get started with Gemma using KerasNLP

This tutorial shows you how to get started with Gemma using KerasNLP. Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models. KerasNLP is a collection of natural language processing (NLP) models implemented in Keras and runnable on JAX, PyTorch, and TensorFlow.

In this tutorial, you'll use Gemma to generate text responses to several prompts. If you're new to Keras, you might want to read Getting started with Keras before you begin, but you don't have to. You'll learn more about Keras as you work through this tutorial.

Setup

Gemma setup

Before you start this tutorial, complete the setup instructions at Gemma setup. The Gemma setup instructions show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the Gemma 2B model.
  • Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

Set environment variables

Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
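
Outside Colab, `userdata.get` is not available, so the two variables have to come from your own environment. The following is a minimal sketch, assuming you have already exported KAGGLE_USERNAME and KAGGLE_KEY in your shell. The helper names `require_env` and `kaggle_credentials` are hypothetical and are not part of Kaggle or KerasNLP.

```python
import os


def require_env(name):
    """Return the value of an environment variable, or raise a clear error."""
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(
            f"{name} is not set. Export it in your shell before starting Python."
        )
    return value


def kaggle_credentials():
    """Return (username, key) read from the environment."""
    return require_env("KAGGLE_USERNAME"), require_env("KAGGLE_KEY")


# Usage, once both variables are exported:
# username, key = kaggle_credentials()
```

Failing early with a named variable tends to be easier to debug than a download error later in the tutorial.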

Install dependencies

Install Keras and KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Keras 3 lets you choose the backend: TensorFlow, JAX, or PyTorch. All three will work for this tutorial.

import os

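# Choose the backend before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
# On the JAX backend, let XLA preallocate up to 90% of GPU memory.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"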
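
Keras reads KERAS_BACKEND when it is first imported, so the variable has to be set before `import keras` runs. The sketch below wraps that rule in a small helper. The name `select_backend` is hypothetical and is not a Keras API.

```python
import os
import sys

SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name):
    """Set KERAS_BACKEND; return True if Keras has not been imported yet."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    os.environ["KERAS_BACKEND"] = name
    # Keras reads the variable at import time, so a later change is ignored.
    return "keras" not in sys.modules


takes_effect = select_backend("jax")
```

If `takes_effect` is False, Keras was already imported in this session, and you would need to restart the runtime for the new backend to apply.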

Import packages

Import Keras and KerasNLP.

import keras
import keras_nlp

Create a model

KerasNLP provides implementations of many popular model architectures. In this tutorial, you'll create a model using GemmaCausalLM, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.
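
To make "predicts the next token based on previous tokens" concrete, here is a toy sketch of the generation loop. It is not Gemma: the lookup table `NEXT_TOKEN` and the function `toy_generate` are hypothetical stand-ins, and the toy looks only at the last token, whereas a real causal language model conditions on all previous tokens. The sketch also assumes that max_length counts the prompt tokens.

```python
# Toy next-token table standing in for a real model (hypothetical, not Gemma).
NEXT_TOKEN = {
    "the": "universe",
    "universe": "is",
    "is": "vast",
    "vast": "<end>",
}


def toy_generate(prompt_tokens, max_length):
    """Append one predicted token at a time until max_length or <end>."""
    tokens = list(prompt_tokens)
    if not tokens:
        raise ValueError("prompt_tokens must contain at least one token")
    while len(tokens) < max_length:
        next_token = NEXT_TOKEN.get(tokens[-1], "<end>")
        if next_token == "<end>":
            break
        tokens.append(next_token)
    return tokens


print(toy_generate(["the"], max_length=3))  # ['the', 'universe', 'is']
```

Each new token is fed back in as input for the next step, which is what makes the process autoregressive.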

Create the model using the from_preset method:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

The GemmaCausalLM.from_preset() method instantiates the model from a preset architecture and weights. In the code above, the string "gemma2_2b_en" specifies the preset for the Gemma 2 2B model, which has 2 billion parameters. Gemma models with 7B, 9B, and 27B parameters are also available. You can find the code strings for Gemma models in their Model Variation listings on Kaggle.

Use summary to get more info about the model:

gemma_lm.summary()

As you can see from the summary, the model has 2.6 billion trainable parameters.

Generate text

Now it's time to generate some text! The model has a generate method that generates text based on a prompt. The optional max_length argument specifies the maximum length of the generated sequence.

Try it out with the prompt "what is keras in 3 bullet points?".

gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'

Try calling generate again with a different prompt.

gemma_lm.generate("The universe is", max_length=64)
'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'

If you're running on JAX or TensorFlow backends, you'll notice that the second generate call returns nearly instantly. This is because each call to generate for a given batch size and max_length is compiled with XLA. The first run is expensive, but subsequent runs are much faster.
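
The caching behavior described above can be pictured with a plain-Python analogy. This is only an illustration of a cache keyed on batch size and max_length, not how Keras or XLA implement compilation. The class `CompileCache` is hypothetical.

```python
class CompileCache:
    """Toy cache: one 'compiled program' per (batch_size, max_length) pair."""

    def __init__(self):
        self.compiled = {}
        self.compile_count = 0

    def get(self, batch_size, max_length):
        key = (batch_size, max_length)
        if key not in self.compiled:
            # Stands in for the expensive one-time compilation step.
            self.compile_count += 1
            self.compiled[key] = (
                f"program(batch_size={batch_size}, max_length={max_length})"
            )
        return self.compiled[key]


cache = CompileCache()
cache.get(1, 64)  # First call: compiles.
cache.get(1, 64)  # Same batch size and max_length: reuses the program.
cache.get(2, 64)  # New batch size: compiles again.
print(cache.compile_count)  # 2
```

By the same reasoning, a call with a different batch size or max_length than any earlier call would be expected to pay the compilation cost again.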

You can also provide batched prompts using a list as input:

gemma_lm.generate(
    ["what is keras in 3 bullet points?",
     "The universe is"],
    max_length=64)
['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n',
 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']

Optional: Try a different sampler

You can control the generation strategy for GemmaCausalLM by setting the sampler argument on compile(). By default, "greedy" sampling will be used.

As an experiment, try setting a "top_k" strategy:

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
'The universe is a big place, and there are so many things we do not know or understand about it.\n\nBut we can learn a lot about our world by studying what is known to us.\n\nFor example, if you look at the moon, it has many features that can be seen from the surface.'

The default greedy algorithm always picks the token with the highest probability, whereas the top-K algorithm randomly samples the next token from the K most probable tokens.
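
The difference between the two strategies can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the KerasNLP sampler code: the function names are hypothetical, the logits are made up, and the top-K step is assumed to sample in proportion to the renormalized probabilities of the K best tokens.

```python
import math
import random


def softmax(logits):
    """Convert logits to probabilities, shifting by the max for stability."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_pick(logits):
    """Return the index of the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])


def top_k_pick(logits, k, rng):
    """Sample one index from the k highest-scoring tokens."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    weights = softmax([logits[i] for i in top])
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 0.5, 1.5, -1.0]  # Made-up scores for a 4-token vocabulary.
rng = random.Random(0)

print(greedy_pick(logits))  # 0
samples = {top_k_pick(logits, k=2, rng=rng) for _ in range(200)}
print(samples)  # Only indices 0 and 2 can appear.
```

Greedy decoding is deterministic, so repeated calls give the same text. Top-K sampling can produce different text on each call.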

You don't have to specify a sampler, and you can ignore the last code snippet if it's not helpful to your use case. If you'd like to learn more about the available samplers, see Samplers.

What's next

In this tutorial, you learned how to generate text using KerasNLP and Gemma. Here are a few suggestions for what to learn next: