KerasNLP を使用して Gemma を使ってみる

ai.google.dev で表示 Google Colab で実行 Vertex AI で開く GitHub のソースを表示

このチュートリアルでは、KerasNLP を使用して Gemma の使用を開始する方法について説明します。Gemma は、Gemini モデルの作成に使用されたものと同じ研究とテクノロジーに基づいて構築された、軽量で最先端のオープンモデルのファミリーです。KerasNLP は、Keras で実装された自然言語処理(NLP)モデルのコレクションであり、JAX、PyTorch、TensorFlow で実行できます。

このチュートリアルでは、Gemma を使用して複数のプロンプトに対するテキスト レスポンスを生成します。Keras を初めて使用する場合は、その前に Keras スタートガイドをお読みになることをおすすめしますが、必須ではありません。このチュートリアルでは、Keras についてさらに学びましょう。

セットアップ

Gemma の設定

このチュートリアルを完了するには、まず Gemma の設定にある設定手順を完了する必要があります。Gemma の設定手順では、次の方法について説明します。

  • kaggle.com で Gemma にアクセスしてください。
  • 実行するのに十分なリソースがある Colab ランタイムを選択します Gemma 2B モデルを使用します。
  • Kaggle のユーザー名と API キーを生成して構成します。

Gemma の設定が完了したら、次のセクションに進み、Colab 環境の環境変数を設定します。

環境変数を設定する

KAGGLE_USERNAMEKAGGLE_KEY の環境変数を設定します。

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

依存関係のインストール

Keras と KerasNLP をインストールする。

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

バックエンドを選択

Keras は、シンプルさと使いやすさを考慮して設計された高レベルのマルチフレームワーク ディープ ラーニング API です。Keras 3 では、バックエンド(TensorFlow、JAX、PyTorch)を選択できます。このチュートリアルでは、3 つすべてを使用できます。

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

パッケージをインポートする

Keras と KerasNLP をインポートする。

import keras
import keras_nlp

モデルを作成する

KerasNLP には、多くの一般的なモデル アーキテクチャの実装が用意されています。このチュートリアルでは、因果言語モデリング用のエンドツーエンドの Gemma モデルである GemmaCausalLM を使用してモデルを作成します。因果言語モデルは、以前のトークンに基づいて次のトークンを予測します。

from_preset メソッドを使用してモデルを作成します。

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

GemmaCausalLM.from_preset() 関数は、プリセットのアーキテクチャと重みからモデルをインスタンス化します。上記のコードでは、文字列 "gemma2_2b_en" で、20 億のパラメータがある Gemma 2 2B モデルのプリセットを指定しています。7B、9B、27B パラメータを使用した Gemma モデルも使用できます。Gemma モデルのコード文字列は、Kaggle の [Model Variation] リストで確認できます。

summary を使用して、モデルに関する詳細情報を取得します。

gemma_lm.summary()

要約からわかるように、このモデルには 26 億個のトレーニング可能なパラメータがあります。

テキストを生成

次に、テキストを生成します。このモデルには、プロンプトに基づいてテキストを生成する generate メソッドがあります。オプションの max_length 引数は、生成されるシーケンスの最大長を指定します。

プロンプト "what is keras in 3 bullet points?" で試してみましょう。

gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
'what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n'

別のプロンプトで generate をもう一度呼び出してみてください。

gemma_lm.generate("The universe is", max_length=64)
'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now'

JAX または TensorFlow バックエンドで実行している場合、2 回目の generate 呼び出しがほぼ瞬時に返されます。これは、特定のバッチサイズと max_length に対する generate の呼び出しが XLA でコンパイルされるためです。初回の実行はコストがかかりますが、その後の実行ははるかに高速です。

リストを入力として使用し、バッチ プロンプトを指定することもできます。

gemma_lm.generate(
    ["what is keras in 3 bullet points?",
     "The universe is"],
    max_length=64)
['what is keras in 3 bullet points?\n\n[Answer 1]\n\nKeras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, Theano, or PlaidML. It is designed to be user-friendly and easy to extend.\n\n',
 'The universe is a vast and mysterious place, filled with countless stars, planets, and galaxies. But what if there was a way to see the universe in a whole new way? What if we could see the universe as it was when it was first created? What if we could see the universe as it is now']

省略可: 別のサンプラーを試す

compile()sampler 引数を設定することで、GemmaCausalLM の生成方法を制御できます。デフォルトでは、"greedy" サンプリングが使用されます。

テストとして、次のように "top_k" 戦略を設定してみましょう。

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
'The universe is a big place, and there are so many things we do not know or understand about it.\n\nBut we can learn a lot about our world by studying what is known to us.\n\nFor example, if you look at the moon, it has many features that can be seen from the surface.'

デフォルトの貪欲アルゴリズムは常に確率が最も高いトークンを選択しますが、トップ K アルゴリズムは上位 K 個の確率を持つトークンから次のトークンをランダムに選択します。

サンプラーを指定する必要はありません。ユースケースで役に立たない場合は、最後のコード スニペットを無視してかまいません。利用可能なサンプラーの詳細については、サンプラーをご覧ください。

次のステップ

このチュートリアルでは、KerasNLP と Gemma を使用してテキストを生成する方法を学習しました。次に学ぶべきことをいくつか紹介します。