微調教學課程

本教學課程將協助您開始使用 Gemini API 調整作業 透過 Python SDK 或 REST API 存取 curl。範例顯示如何調整基礎模型 與 Gemini API 文字生成服務搭配使用

前往 ai.google.dev 查看 試用 Colab 筆記本 在 GitHub 中查看筆記本

設定驗證方法

Gemini API 可讓您根據自己的資料調整模型。您擁有的資料和 調整後的模型採用了比 API 金鑰更嚴格的存取權控管。

執行本教學課程前,請先為 專案,然後下載 「OAuth 用戶端 ID」「client_secret.json」。

此 gcloud 指令會將 client_secret.json 檔案轉換為 並用於驗證服務

gcloud auth application-default login \
    --client-id-file client_secret.json \
    --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'

設定變數

CURL

為週期性值設定變數,用於其餘 REST API 呼叫。程式碼使用 Python os 程式庫設定環境 所有程式碼儲存格中可存取的變數。

這是 Colab 筆記本環境獨有的功能。程式碼中的 程式碼儲存格等同於在 bash 中執行下列指令 終端機。

export access_token=$(gcloud auth application-default print-access-token)
export project_id=my-project-id
export base_url=https://generativelanguage.googleapis.com
import os

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

os.environ['access_token'] = access_token
os.environ['project_id'] = "[Enter your project-id here]"
os.environ['base_url'] = "https://generativelanguage.googleapis.com"

Python

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

project = '[Enter your project-id here]'
base_url = "https://generativelanguage.googleapis.com"

匯入 requests 程式庫。

import requests
import json

列出調整後的模型

列出可用的調整後模型,驗證驗證設定。

CURL


curl -X GET ${base_url}/v1beta/tunedModels \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}"

Python

headers={
  'Authorization': 'Bearer ' + access_token,
  'Content-Type': 'application/json',
  'x-goog-user-project': project
}

result = requests.get(
  url=f'{base_url}/v1beta/tunedModels',
  headers = headers,
)
result.json()

建立調整後模型

如要建立調整後的模型,您必須將資料集傳送至 「training_data」欄位。

在這個範例中,您將調整模型,產生下一個數字 序列舉例來說,如果輸入內容是 1,則模型應輸出 2。如果 輸入內容為 one hundred,輸出內容應為 one hundred one

CURL


curl -X POST $base_url/v1beta/tunedModels \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '
      {
        "display_name": "number generator model",
        "base_model": "models/gemini-1.0-pro-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 2,
            "learning_rate": 0.001,
            "epoch_count":5,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    "text_input": "1",
                    "output": "2",
                },{
                    "text_input": "3",
                    "output": "4",
                },{
                    "text_input": "-3",
                    "output": "-2",
                },{
                    "text_input": "twenty two",
                    "output": "twenty three",
                },{
                    "text_input": "two hundred",
                    "output": "two hundred one",
                },{
                    "text_input": "ninety nine",
                    "output": "one hundred",
                },{
                    "text_input": "8",
                    "output": "9",
                },{
                    "text_input": "-98",
                    "output": "-97",
                },{
                    "text_input": "1,000",
                    "output": "1,001",
                },{
                    "text_input": "10,100,000",
                    "output": "10,100,001",
                },{
                    "text_input": "thirteen",
                    "output": "fourteen",
                },{
                    "text_input": "eighty",
                    "output": "eighty one",
                },{
                    "text_input": "one",
                    "output": "two",
                },{
                    "text_input": "three",
                    "output": "four",
                },{
                    "text_input": "seven",
                    "output": "eight",
                }
              ]
            }
          }
        }
      }' | tee tunemodel.json
{
"name": "tunedModels/number-generator-model-dzlmi0gswwqb/operations/bvl8dymw0fhw",
"metadata": {
  "@type": "type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata",
  "totalSteps": 38,
  "tunedModel": "tunedModels/number-generator-model-dzlmi0gswwqb"
}
}
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                              Dload  Upload   Total   Spent    Left  Speed
100  2280    0   296  100  1984    611   4098 --:--:-- --:--:-- --:--:--  4720

Python

operation = requests.post(
    url = f'{base_url}/v1beta/tunedModels',
    headers=headers,
    json= {
        "display_name": "number generator",
        "base_model": "models/gemini-1.0-pro-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 4,
            "learning_rate": 0.001,
            "epoch_count":5,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    'text_input': '1',
                    'output': '2',
                },{
                    'text_input': '3',
                    'output': '4',
                },{
                    'text_input': '-3',
                    'output': '-2',
                },{
                    'text_input': 'twenty two',
                    'output': 'twenty three',
                },{
                    'text_input': 'two hundred',
                    'output': 'two hundred one',
                },{
                    'text_input': 'ninety nine',
                    'output': 'one hundred',
                },{
                    'text_input': '8',
                    'output': '9',
                },{
                    'text_input': '-98',
                    'output': '-97',
                },{
                    'text_input': '1,000',
                    'output': '1,001',
                },{
                    'text_input': '10,100,000',
                    'output': '10,100,001',
                },{
                    'text_input': 'thirteen',
                    'output': 'fourteen',
                },{
                    'text_input': 'eighty',
                    'output': 'eighty one',
                },{
                    'text_input': 'one',
                    'output': 'two',
                },{
                    'text_input': 'three',
                    'output': 'four',
                },{
                    'text_input': 'seven',
                    'output': 'eight',
                }
              ]
            }
          }
        }
      }
)
operation
<Response [200]>
operation.json()
{'name': 'tunedModels/number-generator-wl1qr34x2py/operations/41vni3zk0a47',
'metadata': {'@type': 'type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata',
  'totalSteps': 19,
  'tunedModel': 'tunedModels/number-generator-wl1qr34x2py'} }

使用調整過的模型名稱設定變數,以便在其餘部分使用 呼叫。

name=operation.json()["metadata"]["tunedModel"]
name
'tunedModels/number-generator-wl1qr34x2py'

訓練週期數、批量和學習率的最佳值取決於週期數、批量和學習率 例如資料集和用途的其他限制如要進一步瞭解 這些值,請參閱 進階調整設定和 「超參數」

取得調整後的模型狀態

模型在訓練期間會設為 CREATING,但會變更為 完成後 ACTIVE

CURL

以下是一些 Python 程式碼,可用來剖析系統產生的模型名稱 回應 JSON如果您在終端機中執行這個指令碼,可以嘗試使用 bash 用來剖析回應的 JSON 剖析器。

import json

first_page = json.load(open('tunemodel.json'))
os.environ['modelname'] = first_page['metadata']['tunedModel']

print(os.environ['modelname'])
tunedModels/number-generator-model-dzlmi0gswwqb

再次使用模型名稱執行 GET 要求,取得模型中繼資料。 包含狀態欄位


curl -X GET ${base_url}/v1beta/${modelname} \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" | grep state
"state": "ACTIVE",
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                              Dload  Upload   Total   Spent    Left  Speed
100  5921    0  5921    0     0  13164      0 --:--:-- --:--:-- --:--:-- 13157

Python

tuned_model = requests.get(
    url = f'{base_url}/v1beta/{name}',
    headers=headers,
)
tuned_model.json()

下列程式碼每 5 秒檢查狀態欄位一次,直到不再出現為止 處於 CREATING 狀態時。

import time
import pprint

op_json = operation.json()
response = op_json.get('response')
error = op_json.get('error')

while response is None and error is None:
    time.sleep(5)

    operation = requests.get(
        url = f'{base_url}/v1/{op_json["name"]}',
        headers=headers,
    )

    op_json = operation.json()
    response = op_json.get('response')
    error = op_json.get('error')

    percent = op_json['metadata'].get('completedPercent')
    if percent is not None:
      print(f"{percent:.2f}% - {op_json['metadata']['snapshots'][-1]}")
      print()

if error is not None:
    raise Exception(error)
100.00% - {'step': 19, 'epoch': 5, 'meanLoss': 1.402067, 'computeTime': '2024-03-14T15:11:23.766989274Z'}

執行推論

調整工作完成後,即可用此工俱生成含有文字的文字 課程中也會快速介紹 Memorystore 這是 Google Cloud 的全代管 Redis 服務

CURL

試著輸入羅馬數字,例如 63 (LXIII):


curl -X POST $base_url/v1beta/$modelname:generateContent \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '{
        "contents": [{
        "parts": [{
          "text": "LXIII"
          }]
        }]
        }' 2> /dev/null
{
"candidates": [
  {
    "content": {
      "parts": [
        {
          "text": "LXIV"
        }
      ],
      "role": "model"
    },
    "finishReason": "STOP",
    "index": 0,
    "safetyRatings": [
      {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_HARASSMENT",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "probability": "NEGLIGIBLE"
      }
    ]
  }
],
"promptFeedback": {
  "safetyRatings": [
    {
      "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
      "probability": "NEGLIGIBLE"
    },
    {
      "category": "HARM_CATEGORY_HATE_SPEECH",
      "probability": "NEGLIGIBLE"
    },
    {
      "category": "HARM_CATEGORY_HARASSMENT",
      "probability": "NEGLIGIBLE"
    },
    {
      "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
      "probability": "NEGLIGIBLE"
    }
  ]
}
}

模型的輸出內容不一定正確。如果調整過的模型 未達所需標準,您可以嘗試提高 或是調整超參數 範例。您甚至可以根據第一個步驟建立調整後的模型 已建立。

請參閱調整指南 ,進一步瞭解如何改善成效。

Python

試著輸入日文數字,例如 6 (六):

import time

m = requests.post(
    url = f'{base_url}/v1beta/{name}:generateContent',
    headers=headers,
    json= {
        "contents": [{
            "parts": [{
                "text": "六"
            }]
          }]
    })
import pprint
pprint.pprint(m.json())
{'candidates': [{'content': {'parts': [{'text': '七'}], 'role': 'model'},
                'finishReason': 'STOP',
                'index': 0,
                'safetyRatings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                                    'probability': 'NEGLIGIBLE'},
                                  {'category': 'HARM_CATEGORY_HATE_SPEECH',
                                    'probability': 'NEGLIGIBLE'},
                                  {'category': 'HARM_CATEGORY_HARASSMENT',
                                    'probability': 'LOW'},
                                  {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                                    'probability': 'NEGLIGIBLE'}]}],
'promptFeedback': {'safetyRatings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                                      'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_HATE_SPEECH',
                                      'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_HARASSMENT',
                                      'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                                      'probability': 'NEGLIGIBLE'}]} }

模型的輸出內容不一定正確。如果調整過的模型 未達所需標準,您可以嘗試提高 或是調整超參數 範例。

結論

即使訓練資料並未包含任何有關羅馬或日文的參照 但經過微調後,這個模型已有良好的一般性。如此一來 可根據用途微調模型

後續步驟

如要瞭解如何透過 Python SDK 使用調整服務,請參閱 Gemini API,請前往這裡 Python。如要瞭解操作方式 如要在 Gemini API 中使用其他服務,請前往 REST 入門指南 教學課程