REST API: モデルのチューニング

ai.google.dev で表示 Google Colab で実行 GitHub でソースを表示

このノートブックでは、curl コマンドまたは Python リクエスト API を使用して Gemini API チューニング サービスを開始する方法を学習します。ここでは、Gemini API のテキスト生成サービスの背後にあるテキストモデルをチューニングする方法について説明します。

設定

認証

Gemini API を使用すると、独自のデータでモデルをチューニングできます。利用者自身のデータと調整済みモデルであるため、API キーが提供できるものよりも厳密なアクセス制御が必要になります。

このチュートリアルを実行する前に、プロジェクトに OAuth を設定する必要があります。

Colab でセットアップする最も簡単な方法は、client_secret.json ファイルのコンテンツを Colab の「シークレット マネージャー」(左側のパネルの鍵アイコンの下)にシークレット名 CLIENT_SECRET でコピーすることです。

この gcloud コマンドは、client_secret.json ファイルを、サービスでの認証に使用できる認証情報に変換します。

try:
  from google.colab import userdata
  import pathlib
  pathlib.Path('client_secret.json').write_text(userdata.get('CLIENT_SECRET'))

  # Use `--no-browser` in colab
  !gcloud auth application-default login --no-browser --client-id-file client_secret.json --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'
except ImportError:
  !gcloud auth application-default login --client-id-file client_secret.json --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'
You are authorizing client libraries without access to a web browser. Please run the following command on a machine with a web browser and copy its output back here. Make sure the installed gcloud version is 372.0.0 or newer.

gcloud auth application-default login --remote-bootstrap="https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=87071151422-n1a3cb6c7fvkfg4gmhdtmn5ulol2l4be.apps.googleusercontent.com&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fgenerative-language.tuning&state=QIyNibWSaTIsozjmvZEkVBo6EcoW0G&access_type=offline&code_challenge=76c1ZiGvKN8cvlYfj3BmbCwE4e7tvrlwaX3REUX25gY&code_challenge_method=S256&token_usage=remote"


Enter the output of the above command: https://localhost:8085/?state=QIyNibWSaTIsozjmvZEkVBo6EcoW0G&code=4/0AeaYSHBKrY911S466QjKQIFODoOPXlO1mWyTYYdrbELIDV6Hw2DKRAyro62BugroSvIWsA&scope=https://www.googleapis.com/auth/cloud-platform%20https://www.googleapis.com/auth/generative-language.tuning

Credentials saved to file: [/content/.config/application_default_credentials.json]

These credentials will be used by any library that requests Application Default Credentials (ADC).

CURL を使用して REST API を呼び出す

このセクションでは、REST API を呼び出す curl ステートメントの例を示します。チューニング ジョブの作成方法、ステータスの確認方法、完了後に推論呼び出しを行う方法について学習します。

変数を設定する

残りの REST API 呼び出しに使用する、繰り返しの値の変数を設定します。このコードでは、Python os ライブラリを使用して、すべてのコードセルからアクセスできる環境変数を設定しています。

これは Colab ノートブック環境に固有のものです。次のコードセルのコードは、bash ターミナルで次のコマンドを実行するのと同等です。

export access_token=$(gcloud auth application-default print-access-token)
export project_id=my-project-id
export base_url=https://generativelanguage.googleapis.com
import os

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

os.environ['access_token'] = access_token
os.environ['project_id'] = "[Enter your project-id here]"
os.environ['base_url'] = "https://generativelanguage.googleapis.com"

チューニング済みモデルを一覧表示する

現在利用可能なチューニング済みモデルを一覧表示して、認証の設定を確認します。


curl -X GET ${base_url}/v1beta/tunedModels \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}"

チューニング済みモデルを作成

チューニング済みモデルを作成するには、training_data フィールドでデータセットをモデルに渡す必要があります。

この例では、シーケンスの次の数値を生成するようにモデルをチューニングします。たとえば、入力が 1 の場合、モデルは 2 を出力する必要があります。入力が one hundred の場合、出力は one hundred one になります。


curl -X POST $base_url/v1beta/tunedModels \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '
      {
        "display_name": "number generator model",
        "base_model": "models/gemini-1.0-pro-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 2,
            "learning_rate": 0.001,
            "epoch_count":5,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    "text_input": "1",
                    "output": "2",
                },{
                    "text_input": "3",
                    "output": "4",
                },{
                    "text_input": "-3",
                    "output": "-2",
                },{
                    "text_input": "twenty two",
                    "output": "twenty three",
                },{
                    "text_input": "two hundred",
                    "output": "two hundred one",
                },{
                    "text_input": "ninety nine",
                    "output": "one hundred",
                },{
                    "text_input": "8",
                    "output": "9",
                },{
                    "text_input": "-98",
                    "output": "-97",
                },{
                    "text_input": "1,000",
                    "output": "1,001",
                },{
                    "text_input": "10,100,000",
                    "output": "10,100,001",
                },{
                    "text_input": "thirteen",
                    "output": "fourteen",
                },{
                    "text_input": "eighty",
                    "output": "eighty one",
                },{
                    "text_input": "one",
                    "output": "two",
                },{
                    "text_input": "three",
                    "output": "four",
                },{
                    "text_input": "seven",
                    "output": "eight",
                }
              ]
            }
          }
        }
      }' | tee tunemodel.json
{
  "name": "tunedModels/number-generator-model-dzlmi0gswwqb/operations/bvl8dymw0fhw",
  "metadata": {
    "@type": "type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata",
    "totalSteps": 38,
    "tunedModel": "tunedModels/number-generator-model-dzlmi0gswwqb"
  }
}
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2280    0   296  100  1984    611   4098 --:--:-- --:--:-- --:--:--  4720

チューニング済みモデルの状態を取得する

トレーニング中、モデルの状態は CREATING に設定され、完了すると ACTIVE に変わります。

以下に、レスポンスの JSON から生成されたモデル名を解析するための Python コードを示します。これをターミナルで実行する場合は、bash JSON パーサーを使用してレスポンスを解析できます。

import json

first_page = json.load(open('tunemodel.json'))
os.environ['modelname'] = first_page['metadata']['tunedModel']

print(os.environ['modelname'])
tunedModels/number-generator-model-dzlmi0gswwqb

モデル名を指定して別の GET リクエストを実行し、state フィールドを含むモデル メタデータを取得します。


curl -X GET ${base_url}/v1beta/${modelname} \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" | grep state
"state": "ACTIVE",
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5921    0  5921    0     0  13164      0 --:--:-- --:--:-- --:--:-- 13157

推論を実行する

チューニング ジョブが完了したら、そのジョブを使用してテキスト サービスでテキストを生成できます。ローマ数字(「63」(LXIII)など)を入力してみてください


curl -X POST $base_url/v1beta/$modelname:generateContent \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '{
        "contents": [{
        "parts": [{
          "text": "LXIII"
          }]
        }]
        }' 2> /dev/null
{
  "candidates": [
    {
      "content": {
        "parts": [
          {
            "text": "LXIV"
          }
        ],
        "role": "model"
      },
      "finishReason": "STOP",
      "index": 0,
      "safetyRatings": [
        {
          "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
          "probability": "NEGLIGIBLE"
        },
        {
          "category": "HARM_CATEGORY_HATE_SPEECH",
          "probability": "NEGLIGIBLE"
        },
        {
          "category": "HARM_CATEGORY_HARASSMENT",
          "probability": "NEGLIGIBLE"
        },
        {
          "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
          "probability": "NEGLIGIBLE"
        }
      ]
    }
  ],
  "promptFeedback": {
    "safetyRatings": [
      {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_HARASSMENT",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "probability": "NEGLIGIBLE"
      }
    ]
  }
}

モデルからの出力は正しい場合もあれば、そうでない場合もあります。チューニング済みモデルが必要な基準を満たさない場合は、高品質のサンプルをさらに追加したり、ハイパーパラメータを微調整したり、サンプルにプリアンブルを追加したりしてみてください。最初に作成したモデルに基づいて、別のチューニング済みモデルを作成することもできます。

パフォーマンスの改善について詳しくは、チューニング ガイドをご覧ください。

Python リクエストで REST API を呼び出す

HTTP リクエストを送信できる任意のライブラリを使用して、REST API を呼び出すことができます。次の一連の例では、Python リクエスト ライブラリを使用し、より高度な機能の一部を示します。

変数を設定する

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

project = '[Enter your project-id here]'
base_url = "https://generativelanguage.googleapis.com"

requests ライブラリをインポートします。

import requests
import json

チューニング済みモデルを一覧表示する

現在利用可能なチューニング済みモデルを一覧表示して、認証の設定を確認します。

headers={
  'Authorization': 'Bearer ' + access_token,
  'Content-Type': 'application/json',
  'x-goog-user-project': project
}

result = requests.get(
  url=f'{base_url}/v1beta/tunedModels',
  headers = headers,
)
result.json()

チューニング済みモデルを作成

Curl の例と同様に、training_data フィールドを介してデータセットを渡します。

operation = requests.post(
    url = f'{base_url}/v1beta/tunedModels',
    headers=headers,
    json= {
        "display_name": "number generator",
        "base_model": "models/gemini-1.0-pro-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 4,
            "learning_rate": 0.001,
            "epoch_count":5,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    'text_input': '1',
                    'output': '2',
                },{
                    'text_input': '3',
                    'output': '4',
                },{
                    'text_input': '-3',
                    'output': '-2',
                },{
                    'text_input': 'twenty two',
                    'output': 'twenty three',
                },{
                    'text_input': 'two hundred',
                    'output': 'two hundred one',
                },{
                    'text_input': 'ninety nine',
                    'output': 'one hundred',
                },{
                    'text_input': '8',
                    'output': '9',
                },{
                    'text_input': '-98',
                    'output': '-97',
                },{
                    'text_input': '1,000',
                    'output': '1,001',
                },{
                    'text_input': '10,100,000',
                    'output': '10,100,001',
                },{
                    'text_input': 'thirteen',
                    'output': 'fourteen',
                },{
                    'text_input': 'eighty',
                    'output': 'eighty one',
                },{
                    'text_input': 'one',
                    'output': 'two',
                },{
                    'text_input': 'three',
                    'output': 'four',
                },{
                    'text_input': 'seven',
                    'output': 'eight',
                }
              ]
            }
          }
        }
      }
)
operation
<Response [200]>
operation.json()
{'name': 'tunedModels/number-generator-wl1qr34x2py/operations/41vni3zk0a47',
 'metadata': {'@type': 'type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata',
  'totalSteps': 19,
  'tunedModel': 'tunedModels/number-generator-wl1qr34x2py'} }

残りの呼び出しで使用するチューニング済みモデルの名前を含む変数を設定します。

name=operation.json()["metadata"]["tunedModel"]
name
'tunedModels/number-generator-wl1qr34x2py'

チューニング済みモデルの状態を取得する

チューニング ジョブの進行状況は、状態フィールドで確認できます。CREATING は、チューニング ジョブがまだ進行中であることを意味します。ACTIVE は、トレーニングが完了し、チューニング済みモデルが使用可能であることを意味します。

tuned_model = requests.get(
    url = f'{base_url}/v1beta/{name}',
    headers=headers,
)
tuned_model.json()

以下のコードは、CREATING 状態ではなくなるまで 5 秒ごとに状態フィールドをチェックします。

import time
import pprint

op_json = operation.json()
response = op_json.get('response')
error = op_json.get('error')

while response is None and error is None:
    time.sleep(5)

    operation = requests.get(
        url = f'{base_url}/v1/{op_json["name"]}',
        headers=headers,
    )

    op_json = operation.json()
    response = op_json.get('response')
    error = op_json.get('error')

    percent = op_json['metadata'].get('completedPercent')
    if percent is not None:
      print(f"{percent:.2f}% - {op_json['metadata']['snapshots'][-1]}")
      print()

if error is not None:
    raise Exception(error)
100.00% - {'step': 19, 'epoch': 5, 'meanLoss': 1.402067, 'computeTime': '2024-03-14T15:11:23.766989274Z'}

推論を実行する

チューニング ジョブが完了したら、そのジョブを使用して、ベース テキストモデルを使用するのと同じ方法でテキストを生成できます。「6(六)」など、日本語の数字を入力してみてください。

import time

m = requests.post(
    url = f'{base_url}/v1beta/{name}:generateContent',
    headers=headers,
    json= {
         "contents": [{
             "parts": [{
                 "text": "六"
             }]
          }]
    })
import pprint
pprint.pprint(m.json())
{'candidates': [{'content': {'parts': [{'text': '七'}], 'role': 'model'},
                 'finishReason': 'STOP',
                 'index': 0,
                 'safetyRatings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                                    'probability': 'NEGLIGIBLE'},
                                   {'category': 'HARM_CATEGORY_HATE_SPEECH',
                                    'probability': 'NEGLIGIBLE'},
                                   {'category': 'HARM_CATEGORY_HARASSMENT',
                                    'probability': 'LOW'},
                                   {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                                    'probability': 'NEGLIGIBLE'}]}],
 'promptFeedback': {'safetyRatings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                                       'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_HATE_SPEECH',
                                       'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_HARASSMENT',
                                       'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                                       'probability': 'NEGLIGIBLE'}]} }

モデルからの出力は正しい場合もあれば、そうでない場合もあります。チューニング済みモデルが必要な基準を満たさない場合は、高品質のサンプルをさらに追加したり、ハイパーパラメータを微調整したり、サンプルにプリアンブルを追加したりしてみてください。

おわりに

トレーニング データにはローマ数字や日本語の数字がまったく含まれていませんでしたが、モデルはファインチューニング後にかなり一般化できました。これにより、ユースケースに合わせてモデルを微調整できます。

次のステップ

Gemini API 用の Python SDK を利用してチューニング サービスを使用する方法については、Python によるチューニング クイックスタートをご覧ください。Gemini API で他のサービスを使用する方法については、Python クイックスタートをご覧ください。