KerasNLP を使用して Gemma を使ってみる

このチュートリアルでは、KerasNLP を使用して Gemma の使用を開始する方法について説明します。Gemma は、Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。KerasNLP は、Keras で実装され、JAX、PyTorch、TensorFlow で実行可能な自然言語処理(NLP)モデルのコレクションです。

このチュートリアルでは、Gemma を使用して、いくつかのプロンプトに対するテキスト レスポンスを生成します。Keras を初めて使用する場合は、始める前に Keras スタートガイドをお読みになることをおすすめしますが、必須ではありません。Keras について詳しくは、このチュートリアルで説明します。

セットアップ

Gemma の設定

このチュートリアルを完了するには、まず Gemma の設定の手順を完了する必要があります。Gemma のセットアップ手順では、次の方法について説明しています。

  • kaggle.com で Gemma にアクセスします。
  • Gemma 2B モデルを実行するのに十分なリソースがある Colab ランタイムを選択します。
  • Kaggle のユーザー名と API キーを生成して構成します。

Gemma の設定が完了したら、次のセクションに進みます。ここでは、Colab 環境の環境変数を設定します。

環境変数を設定する

KAGGLE_USERNAMEKAGGLE_KEY の環境変数を設定します。

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

依存関係のインストール

Keras と KerasNLP をインストールします。

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

バックエンドを選択

Keras は、シンプルで使いやすいように設計された高レベルのマルチフレームワーク ディープラーニング API です。Keras 3 では、TensorFlow、JAX、PyTorch のバックエンドを選択できます。このチュートリアルでは、3 つすべて使用できます。

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

パッケージをインポートする

Keras と KerasNLP をインポートします。

import keras
import keras_nlp

モデルを作成する

KerasNLP には、一般的な多くのモデル アーキテクチャの実装が用意されています。このチュートリアルでは、因果言語モデル用のエンドツーエンドの Gemma モデルである GemmaCausalLM を使用してモデルを作成します。因果言語モデルは、以前のトークンに基づいて次のトークンを予測します。

from_preset メソッドを使用してモデルを作成します。

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

GemmaCausalLM.from_preset() 関数は、事前設定されたアーキテクチャと重みからモデルをインスタンス化します。上記のコードでは、文字列 "gemma2_2b_en" で、20 億個のパラメータを持つ Gemma 2 2B モデルのプリセットを指定しています。7B、9B、27B のパラメータを使用する Gemma モデルも利用できます。Gemma モデルのコード文字列は、Kaggleモデル バリエーションのリストで確認できます。

summary を使用して、モデルの詳細情報を取得します。

gemma_lm.summary()

概要からわかるように、このモデルには 26 億個のトレーニング可能なパラメータがあります。

テキストを生成

次に、テキストを生成します。このモデルには、プロンプトに基づいてテキストを生成する generate メソッドがあります。省略可能な max_length 引数には、生成されるシーケンスの最大長を指定します。

プロンプト "what is keras in 3 bullet points?" で試してみましょう。

gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'

別のプロンプトで generate にもう一度電話をかけてみましょう。

gemma_lm.generate("The universe is", max_length=64)
'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'

JAX または TensorFlow バックエンドで実行している場合、2 番目の generate 呼び出しはほぼ瞬時に返されます。これは、特定のバッチサイズと max_lengthgenerate への呼び出しが XLA でコンパイルされるためです。最初の実行は費用がかかりますが、その後の実行は大幅に高速化されます。

リストを入力として使用して、バッチ処理されたプロンプトを指定することもできます。

gemma_lm.generate(
    ["what is keras in 3 bullet points?",
     "The universe is"],
    max_length=64)
['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n',
 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']

省略可: 別のサンプラーを試す

compile()sampler 引数を設定することで、GemmaCausalLM の生成戦略を制御できます。デフォルトでは、"greedy" サンプリングが使用されます。

テストとして、"top_k" 戦略を設定してみてください。

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
'The universe is a big place, and there are so many things we do not know or understand about it.\n\nBut we can learn a lot about our world by studying what is known to us.\n\nFor example, if you look at the moon, it has many features that can be seen from the surface.'

デフォルトのグリーディ アルゴリズムは常に確率が最も高いトークンを選択しますが、Top-K アルゴリズムは、上位 K 個の確率のトークンから次のトークンをランダムに選択します。

サンプラーを指定する必要はありません。ユースケースに役立たない場合は、最後のコード スニペットを無視できます。使用可能なサンプラーの詳細については、サンプラーをご覧ください。

次のステップ

このチュートリアルでは、KerasNLP と Gemma を使用してテキストを生成する方法について学習しました。次に学習することをいくつかご紹介します。