Keras で LIT を使用して Gemma モデルを分析する

生成 AI で表示 Google Colab で実行 GitHub でソースを表示 Codelabs で学ぶ

はじめに

生成 AI プロダクトは比較的新しく、アプリケーションの動作は以前の形式のソフトウェアよりも大きく異なります。このため、使用されている ML モデルを精査し、モデルの動作例を調査して、想定外の結果を調査することが重要です。

Learning Interpretability Tool(LIT、ウェブサイトGitHub)は、ML モデルをデバッグおよび分析して、ML モデルが動作する理由と動作を理解するためのプラットフォームです。

この Codelab では、LIT を使用して Google の Gemma モデルをさらに活用する方法を学びます。この Codelab では、解釈可能性の手法であるシーケンスの顕著性を使用して、さまざまなプロンプト エンジニアリング アプローチを分析する方法について説明します。

学習目標

  1. シーケンスの顕著性とモデル分析におけるその用途を理解する
  2. Gemma 用に LIT を設定し、プロンプト出力とシーケンス サリエンスを計算。
  3. LM Salience モジュールでシーケンス サリエンスを使用して、プロンプト設計がモデル出力に与える影響を理解する。
  4. LIT における仮説のプロンプト改善をテストして、その効果を確認する。

注: この Codelab では、Gemma の KerasNLP 実装と、バックエンドに TensorFlow v2 を使用します。GPU カーネルを使用して実装することを強くおすすめします。

モデル分析におけるシーケンスの顕著性とその使用

Gemma などのテキストからテキストへの生成モデルは、トークン化されたテキスト形式の入力シーケンスを受け取り、その入力の一般的な後続または補完である新しいトークンを生成します。この生成は一度に 1 つずつ行われ、新しく生成された各トークンを入力と前の世代に(ループで)追加し、モデルが停止条件に達するまでこのトークンを生成します。たとえば、モデルがシーケンス終了(EOS)トークンを生成した場合や、事前定義された最大長に達した場合などです。

顕著性メソッドは、Explainable AI(XAI)手法の一種で、出力のさまざまな部分について、モデルにとって入力のどの部分が重要であるかを知ることができます。LIT は、さまざまな分類タスクの顕著性メソッドをサポートしています。これは、一連の入力トークンが予測ラベルに与える影響を説明します。シーケンス サリエンスは、これらの手法をテキストからテキストへの生成モデルに一般化し、生成されたトークンに対する前のトークンの影響を説明します。

ここでは、シーケンスのサリエンスに Grad L2 Norm メソッドを使用します。これは、モデルの勾配を分析し、先行する各トークンが出力に及ぼす影響の大きさを提供します。この方法はシンプルで効率的であり、分類やその他の設定で適切に機能することが示されています。顕著性スコアが大きいほど、影響が高くなります。この方法は、解釈可能性に関する研究コミュニティで広く知られ、広く使用されているため、LIT 内で使用されています。

勾配ベースのより高度なサリエンス手法には、Grad ⋅ Input統合勾配があります。また、LIMESHAP など、アブレーション ベースの方法も利用できます。これらは、より堅牢ですが、計算コストは大幅に高くなります。さまざまな salience の方法の詳細な比較については、こちらの記事をご覧ください。

顕著性の科学的手法については、顕著性に関するインタラクティブな探索可能な入門編をご覧ください。

インポート、環境、その他の設定コード

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 0.21.0 requires scikit-learn>=1.2.2, but you have scikit-learn 1.0.2 which is incompatible.
google-colab 1.0.0 requires ipython==7.34.0, but you have ipython 8.14.0 which is incompatible.

これらは無視してかまいません。

LIT と Keras NLP をインストールする

この Codelab では、最新バージョンの keras(3)keras-nlp(0.8.0)と lit-nlp(1.1)と、ベースモデルをダウンロードするための Kaggle アカウントが必要です。

pip install -q -U lit-nlp
pip uninstall -y umap-learn
pip install -q -U keras-nlp
pip install -q -U keras

Kaggle へのアクセス

Kaggle にログインするには、kaggle.json 認証情報ファイルを ~/.kaggle/kaggle.json に保存するか、Colab 環境で次のコマンドを実行します。詳細については、kagglehub パッケージのドキュメントをご覧ください。

import kagglehub

kagglehub.login()

Gemma の使用許諾契約にも必ず同意してください。

Gemma 用に LIT を設定する

LIT モデルのセットアップ

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import keras_nlp

# Run at half precision.
keras.config.set_floatx("bfloat16")
model_name = 'gemma_instruct_2b_en'
gemma_model = keras_nlp.models.GemmaCausalLM.from_preset(model_name)

次のコードは、Gemma モデルで顕著性をサポートするために LIT ラッパーを初期化します。LIT フレームワークではこれらをモデルと呼びますが、この場合は、上記で読み込んだのと同じ基盤となる gemma_model の異なるエンドポイントにすぎません。これにより、LIT は世代、トークン化、顕著性をオンデマンドで計算できます。

from lit_nlp.examples.models import instrumented_keras_lms

batch_size = 1
max_sequence_length = 512
init_models = instrumented_keras_lms.initialize_model_group_for_salience
models = init_models(model_name, gemma_model,
                     batch_size=batch_size,
                     max_length=max_sequence_length)

LIT データセットの設定

Gemma は、テキスト入力を受け取ってテキスト出力を生成する、テキストからテキストへの生成モデルです。LIT のモデルは、データセットが生成をサポートするために次のフィールドを提供することを前提としています。

  • prompt: KerasGenerationModel への入力。
  • target: オプションのターゲット シーケンス(「正解」(ゴールド)の回答やモデルから事前に生成されたレスポンスなど)。

LIT には、いくつかの異なるソースの例を含む小さな sample_prompts セットが含まれています。次に例を示します。

  • [GSM8K][GSM8K]: 少数ショットの例で小学校の数学の問題を解決します。
  • [Gigaword Benchmark][gigaword]: 短い記事のコレクションに対する見出しの生成。
  • [Constitutional Prompting][constitutional-prompting]: ガイドライン/境界のあるオブジェクトの使用方法に関する新しいアイデアの生成

また、独自のデータも簡単に読み込むことができます。prompt フィールドと必要に応じて target フィールドを含むレコードを含む .jsonl ファイル([example][jsonl-example])として読み込むことも、LIT の Dataset API を使用して任意の形式で読み込むこともできます。

以下のセルを実行して、サンプル プロンプトを読み込みます。

from lit_nlp.examples.datasets import lm as lm_data

datasets = {
  'sample_prompts': lm_data.PromptExamples(
      lm_data.PromptExamples.SAMPLE_DATA_PATH
  ),
}

LIT UI のセットアップ

LIT はインタラクティブなモデル理解ツールで、人間参加型の評価とモデルの動作の調査を可能にします。LIT UI を使用すると、次のことができるため、この操作が容易になります。

  • データセットとモデルの出力をライブで
  • salience メソッドを実行し、モデルの動作に影響を与える入力トークンを理解する
  • 仮説を検証するために反事実的条件を作成する

LIT では、これらすべてを同じインターフェース内で行うことができるため、異なるツールを切り替える手間が省けます。これは、この Codelab で後ほど詳しく説明するプロンプト エンジニアリングなどのタスクで特に役立ちます。

この UI レイアウトは、他のあらゆる生成言語モデルにも使用できます。ここに記載されている機能以外の機能にご関心がある場合は、こちらで完全なリストをご覧ください。

from lit_nlp.api import layout
modules = layout.LitModuleName

LM_SALIENCE_LAYOUT = layout.LitCanonicalLayout(
    left={
        'Data Table': [modules.DataTableModule],
        'Datapoint Editor': [modules.DatapointEditorModule],
    },
    upper={  # if 'lower' not specified, this fills the right side
        'Salience': [modules.LMSalienceModule],
    },
    layoutSettings=layout.LayoutSettings(leftWidth=40),
    description='Custom layout for language model salience.',
)

このセルで LIT サーバーが初期化されます。サンプル プロンプトに対してモデルも実行され、結果がキャッシュに保存されるため、数秒かかる場合があります。

from lit_nlp import notebook as lit_notebook

lit_widget = lit_notebook.LitWidget(
    models=models,
    datasets=datasets,
    layouts={'default': LM_SALIENCE_LAYOUT},
    default_layout='default',
)

これで UI を表示できるようになりました。

lit_widget.render(height=800)
<IPython.core.display.Javascript object>

LIT を新しいタブでページ全体として開くこともできます。このセルを実行して、表示されたリンクをクリックします。

lit_widget.render(open_in_new_tab=True)
<IPython.core.display.Javascript object>

LIT における Gemma の少数ショット プロンプトの分析

現在、プロンプトは科学であると同時に芸術でもありますが、LIT を使用すると、Gemma などの大規模言語モデルのプロンプトを経験的に改善できます。ここでは、LIT を使用して Gemma の動作を調査し、潜在的な問題を予測して安全性を向上させる方法の例を示します。

複雑なプロンプトのエラーを特定する

高品質の LLM ベースのプロトタイプとアプリケーションで最も重要なプロンプト手法として、少数ショット プロンプト(プロンプトで期待される動作の例を含む)と思考の連鎖(LLM の最終出力の前に行われる説明や推論の形態を含む)の 2 つがあります。しかし、多くの場合、効果的なプロンプトの作成は依然として困難です。

好みに基づいて食べ物が好きかどうか評価してもらう例を考えてみましょう。思考の連鎖のプロンプト テンプレートの最初のプロトタイプは次のようになります。

def analyze_menu_item_template(food_likes, food_dislikes, menu_item):
  return f"""Analyze a menu item in a restaurant.

## For example:

Taste-likes: I've a sweet-tooth
Taste-dislikes: Don't like onions or garlic
Suggestion: Onion soup
Analysis: it has cooked onions in it, which you don't like.
Recommendation: You have to try it.

Taste-likes: I've a sweet-tooth
Taste-dislikes: Don't like onions or garlic
Suggestion: Baguette maison au levain
Analysis: Home-made leaven bread in france is usually great
Recommendation: Likely good.

Taste-likes: I've a sweet-tooth
Taste-dislikes: Don't like onions or garlic
Suggestion: Macaron in france
Analysis: Sweet with many kinds of flavours
Recommendation: You have to try it.

## Now analyze one more example:

Taste-likes: {food_likes}
Taste-dislikes: {food_dislikes}
Suggestion: {menu_item}
Analysis:"""

このプロンプトに問題が見つかりましたか?LIT では、LM 顕著性モジュールを使用してプロンプトを調べることができます。

シーケンス サリエンスをデバッグに使用する

このモジュールでは、モデルが回答を生成するときに処理するプロンプトの一部に焦点を当てます。顕著性は可能な限り小さいレベル(つまり、入力トークンごとに)で計算されますが、LIT はトークンの顕著性をより解釈しやすい大きなスパン(行、文、単語など)に集約できます。顕著性の詳細と、それを使用して意図しないバイアスを特定する方法については、Saliency Explorable をご覧ください。

まず、プロンプト テンプレート変数の新たな入力例をプロンプトに入力します。

food_likes = """Cheese"""
food_dislikes = """Can't eat eggs"""
menu_item = """Quiche Lorraine"""

prompt = analyze_menu_item_template(food_likes, food_dislikes, menu_item)
print(prompt)

fewshot_mistake_example = {'prompt': prompt}  # you'll use this below
Analyze a menu item in a restaurant.

## For example:

Taste-likes: I've a sweet-tooth
Taste-dislikes: Don't like onions or garlic
Suggestion: Onion soup
Analysis: it has cooked onions in it, which you don't like.
Recommendation: You have to try it.

Taste-likes: I've a sweet-tooth
Taste-dislikes: Don't like onions or garlic
Suggestion: Baguette maison au levain
Analysis: Home-made leaven bread in france is usually great
Recommendation: Likely good.

Taste-likes: I've a sweet-tooth
Taste-dislikes: Don't like onions or garlic
Suggestion: Macaron in france
Analysis: Sweet with many kinds of flavours
Recommendation: You have to try it.

## Now analyze one more example:

Taste-likes: Cheese
Taste-dislikes: Can't eat eggs
Suggestion: Quiche Lorraine
Analysis:

上のセルまたは別のタブで LIT UI を開いている場合は、LIT のデータポイント エディタを使用してこのプロンプトを追加できます。

1_Datapoint_editor.png

もう 1 つの方法は、目的のプロンプトでウィジェットを直接再レンダリングすることです。

lit_widget.render(data=[fewshot_mistake_example])
<IPython.core.display.Javascript object>

驚くべきモデルの完成度に注目してください。

Taste-likes: Cheese
Taste-dislikes: Can't eat eggs
Suggestion: Quiche Lorraine
Analysis: A savoury tart with cheese and eggs
Recommendation: You might not like it, but it's worth trying.

食べてはいけないと明確に言ったものを食べることをモデルが示唆するのはなぜですか?

シーケンスの顕著性は、少数ショットの例にある根本的な問題を明らかにするのに役立ちます。最初の例では、分析セクション it has cooked onions in it, which you don't like の Chain-of-Thought 推論が、最終的な推奨事項 You have to try it と一致していません。

LM 顕著性モジュールで [文] を選択し、推奨事項の行を選択します。UI は次のようになります。

3_few_shots_mistake..png

では、最初の例の「推奨事項」を Avoid に修正して、もう一度試してみましょう。LIT では、サンプル プロンプトにこのサンプルがプリロードされているため、次の小さなユーティリティ関数を使用して取得できます。

def get_fewshot_example(source: str) -> str:
  for example in datasets['sample_prompts'].examples:
    if example['source'] == source:
      return example['prompt']
  raise ValueError(f'Source "{source}" not found in the dataset.')
lit_widget.render(data=[{'prompt': get_fewshot_example('fewshot-fixed')}])
<IPython.core.display.Javascript object>

モデル完成は次のようになります。

Taste-likes: Cheese
Taste-dislikes: Can't eat eggs
Suggestion: Quiche Lorraine
Analysis: This dish contains eggs and cheese, which you don't like.
Recommendation: Avoid.

このことから学べる重要な教訓は、早期のプロトタイピングにより、事前に想定していなかったリスクを明らかにできるということです。言語モデルはエラーが発生しやすい性質のため、エラーをプロアクティブに設計する必要があります。これの詳細については、AI での設計に関する People + AI Guidebooks をご覧ください。

修正した少数ショット プロンプトの方が優れていますが、まだ正しいとは言えません。ユーザーに卵を避けるよう正しく伝えていますが、推論は正しくありません。実際、ユーザーは卵を食べられないと述べているのに、卵が好きではないということです。次のセクションでは、この問題を解決する方法を説明します。

仮説を検証してモデルの動作を改善する

LIT を使用すると、同じインターフェース内でプロンプトの変更をテストできます。このインスタンスでは、モデルの動作を改善するために、構成の追加をテストします。構成とは、モデルの生成の指針となる原則を含む設計プロンプトのことです。最近の手法では、憲法原則のインタラクティブな導出も可能になっています。

この考え方を活用して、プロンプトをさらに改善しましょう。プロンプトの上部に、生成の原則のセクションを追加します。プロンプトは次のように始まります。

Analyze a menu item in a restaurant.

* The analysis should be brief and to the point.
* The analysis and recommendation should both be clear about the suitability for someone with a specified dietary restriction.

## For example:

Taste-likes: I've a sweet-tooth
Taste-dislikes: Don't like onions or garlic
Suggestion: Onion soup
Analysis: it has cooked onions in it, which you don't like.
Recommendation: Avoid.

...

lit_widget.render(data=[{'prompt': get_fewshot_example('fewshot-constitution')}])
<IPython.core.display.Javascript object>

この更新により、例を再度実行すると、出力が大きく変わることがわかります。

Taste-likes: Cheese
Taste-dislikes: Can't eat eggs
Suggestion: Quiche Lorraine
Analysis: This dish containts eggs, which you can't eat.
Recommendation: Not suitable for you.

その後、プロンプトの顕著性を再調査して、この変化が発生している理由を把握できます。

3_few_shot_constitution.png

推奨の方がはるかに安全です。さらに、「あなたには適していません」は、食事制限による適合性を明確に示すという原則と、分析(いわゆる思考の連鎖)の影響を受けます。これにより、出力が正しい理由で実行されているという確信が高まります。

非技術チームをモデルのプローブと探索に参加させる

解釈可能性とは、XAI、ポリシー、法務などの専門知識に及ぶ、チームでの取り組みです。

従来、開発の初期段階でモデルを操作するには、高い技術的専門知識が必要でした。そのため、一部の共同編集者がモデルにアクセスして詳細を確認するのが難しくなっていました。従来、これらのチームが初期のプロトタイピング フェーズに参加するためのツールは存在しませんでした。

LIT を通じて、このパラダイムを変えられることを期待しています。この Codelab で学習したように、LIT の視覚的媒体と、顕著性を調べて例を探索するインタラクティブな機能は、さまざまな関係者が調査結果を共有し、伝達するのに役立ちます。これにより、より多様なチームメイトを参加させて、モデルの探索、プローブ、デバッグを行うことができます。これらの技術的な手法を知ることで、モデルの仕組みをより深く理解できます。さらに、初期モデルのテストにおけるより多様な専門知識は、改善可能な望ましくない結果を明らかにするのに役立ちます。

内容のまとめ

まとめると次のようになります。

  • LIT UI はインタラクティブなモデル実行のインターフェースを提供し、ユーザーが直接出力を生成して「もしも」シナリオをテストできる。これは、プロンプトのさまざまなバリエーションをテストする場合に特に便利です。
  • LM Salience モジュールは顕著性を視覚的に表現し、制御可能なデータ粒度を提供するため、モデル中心の構造(トークンなど)ではなく、人間中心の構造(文や単語など)についてコミュニケーションできます。

モデル評価で問題のある例が見つかった場合は、デバッグのために LIT に取り込みます。まず、モデリング タスクに論理的に関連していると思われる、理にかなった最大のコンテンツ ユニットを分析します。可視化を使用して、モデルがプロンプト コンテンツに正しく、または誤って対応する場所を確認します。次に、コンテンツの小さなユニットにドリルダウンして、修正可能な動作を特定するために、発生している誤った動作を詳しく記述します。

最後に、Lit は常に改善されています。こちらで機能の詳細やご提案をお聞かせください。

付録: LIT によるシーケンスの顕著性の計算方法

LIT は、複数ステップのプロセスでシーケンスのサリエンスを計算します。

  1. 入力文字列(プロンプトと、モデルの世代または「ゴールド」ターゲット シーケンス)を指定すると、モデル入力用にトークン化します。
  2. 入力トークンを 1 位置左にロールして、「ターゲット」シーケンスを計算します。
  3. のエンベディングを抽出し、生成シーケンスと「ターゲット」シーケンスの間のトークンごとの損失を計算します。
  4. 損失をマスクして、説明が必要なトークンを切り分けます。
  5. tf.GradientTape.gradient() 関数を使用して、マスクされた損失に関する入力エンベディングの勾配を計算します。
  6. 勾配を処理して、入力トークンごとに 1 つのスコアを付与します。たとえば、各位置での勾配の L2 ノルムを取得します。

付録: プログラムによる顕著性の計算

LIT ツールが内部で実行するのと同じ手順を使用して、Python から直接顕著性スコアを計算できます。これは次の 3 つのステップで行います。

  1. サンプルを準備してモデル トークナイザを実行する
  2. 説明する(予測された)トークンを選択するマスクを準備する
  3. salience ラッパーを呼び出します。

LIT の入力例を作成する

{'prompt': 'Keras is a',
 'target': ' deep learning library for Python that provides a wide range of tools and functionalities for building, training, and evaluating deep learning models.\n\n**'}

呼び出し規則に関する注意事項: トークナイザと salience ラッパーはどちらも LIT の Model API を使用します。ここで .predict() 関数は例(dict)のリストを受け取り、レスポンス(dict)のジェネレータを返します。これは、大規模なデータセットや低速のモデルを扱う場合には柔軟性に優れていますが、1 つの例に対する予測だけが必要な場合は、list(model.predict([example])[0] のようなコードでラップする必要があります。

説明のターゲットを選択できるようにトークンを取得する

array(['<bos>', 'K', 'eras', '▁is', '▁a', '▁deep', '▁learning',
       '▁library', '▁for', '▁Python', '▁that', '▁provides', '▁a', '▁wide',
       '▁range', '▁of', '▁tools', '▁and', '▁functionalities', '▁for',
       '▁building', ',', '▁training', ',', '▁and', '▁evaluating', '▁deep',
       '▁learning', '▁models', '.', '\n\n', '**'], dtype='<U16')

顕著性を計算するには、説明する(予測された)トークンを指定するターゲット マスクを作成する必要があります。ターゲット マスクはトークンと同じ長さの配列で、説明するトークンの位置が 1 になります。▁training▁evaluating をターゲットとして使用しましょう。

ターゲット マスクを準備する

{'prompt': 'Keras is a',
 'target': ' deep learning library for Python that provides a wide range of tools and functionalities for building, training, and evaluating deep learning models.\n\n**',
 'target_mask': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       dtype=float32)}

顕著性モデルを呼び出す

{'grad_l2': array([45.75, 36.75, 61, 5.40625, 4.09375, 5.625, 6.46875, 7.3125, 3.375,
        5.03125, 3.23438, 4.5625, 2.375, 3.40625, 2.75, 1.97656, 3.95312,
        3.42188, 14.125, 4.53125, 11.375, 12.625, 18.5, 4.5625, 6.5, 0, 0,
        0, 0, 0, 0, 0], dtype=bfloat16),
 'grad_dot_input': array([-4.03125, 3.04688, -7.03125, -0.800781, 0.769531, -0.679688,
        -0.304688, 2.04688, 0.275391, -1.25781, -0.376953, -0.0664062,
        -0.0405273, -0.357422, 0.355469, -0.145508, -0.333984, 0.0181885,
        -5.0625, 0.235352, -0.470703, 2.25, 3.90625, -0.199219, 0.929688,
        0, 0, 0, 0, 0, 0, 0], dtype=bfloat16),
 'tokens': array(['<bos>', 'K', 'eras', '▁is', '▁a', '▁deep', '▁learning',
        '▁library', '▁for', '▁Python', '▁that', '▁provides', '▁a', '▁wide',
        '▁range', '▁of', '▁tools', '▁and', '▁functionalities', '▁for',
        '▁building', ',', '▁training', ',', '▁and', '▁evaluating', '▁deep',
        '▁learning', '▁models', '.', '\n\n', '**'], dtype='<U16')}

このように、grad_l2 フィールドと grad_dot_input フィールドのスコアは tokens に揃えられ、LIT UI に表示されるスコアと同じです。

最後のいくつかのスコアは 0 です。このモデルは左から右に記述する言語モデルであるため、ターゲット スパンの右側にあるトークンは予測に影響しません。