REST API: Tuning quickstart


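In this notebook, you'll learn how to get started with the PaLM API tuning service by calling the PaLM REST API with curl commands or the Python requests library. Here, you'll learn how to tune the text model behind the PaLM API's text generation service.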

Setup

Authenticate

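The PaLM API lets you tune models on your own data. Because this involves your data and your tuned models, it needs stricter access controls than API keys can provide.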

Before you can run this tutorial, you'll need to set up OAuth for your project.

To run this notebook in Colab, start by uploading your client_secret*.json file using the "File > Upload" option.


cp client_secret*.json client_secret.json
ls
client_secret.json
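The cp step above can also be done from Python. A minimal sketch, assuming the uploaded file matches client_secret*.json; copy_client_secret is a hypothetical helper name, not part of any Google library. Unlike the shell glob, it refuses to guess when several files match.

```python
import glob
import os
import shutil


def copy_client_secret(pattern="client_secret*.json", dest="client_secret.json"):
    """Copy the single uploaded OAuth client file to `dest` (hypothetical helper).

    Mirrors `cp client_secret*.json client_secret.json`, but raises instead of
    guessing when zero or several candidate files match.
    """
    dest_abs = os.path.abspath(dest)
    # `client_secret.json` itself matches the pattern, so leave it out.
    matches = [p for p in glob.glob(pattern) if os.path.abspath(p) != dest_abs]
    if not matches and os.path.exists(dest):
        return dest  # Already copied earlier; nothing to do.
    if len(matches) != 1:
        raise FileNotFoundError(
            f"expected exactly one file matching {pattern!r}, found {matches}"
        )
    shutil.copyfile(matches[0], dest)
    return matches[0]
```

Calling copy_client_secret() with no arguments works on the current directory, just like the cp command.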

This gcloud command turns the client_secret.json file into credentials that can be used to authenticate with the service.

import os
if 'COLAB_RELEASE_TAG' in os.environ:
  # Use `--no-browser` in colab
  !gcloud auth application-default login --no-browser --client-id-file client_secret.json --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'
else:
  !gcloud auth application-default login --client-id-file client_secret.json --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'

Call the REST API with CURL

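This section gives example curl statements for calling the REST API. You will learn how to create a tuning job, check its status and, once it is complete, make an inference call.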

Set variables

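Set variables for recurring values to use throughout the rest of the REST API calls. The code uses the Python os library to set environment variables, which are accessible in all the code cells.

This is specific to the Colab notebook environment. In a bash terminal you would run the following commands instead; the Python code cell after them does the same thing in Colab.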

export access_token=$(gcloud auth application-default print-access-token)
export project_id=my-project-id
export base_url=https://generativelanguage.googleapis.com
import os

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

os.environ['access_token'] = access_token
os.environ['project_id'] = "project-id"
os.environ['base_url'] = "https://generativelanguage.googleapis.com"
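Every call below repeats the same base URL and three headers. A minimal sketch that assembles them from the environment variables just set; auth_headers and endpoint are hypothetical helper names, and v1beta3 is simply the API version used throughout this notebook.

```python
import os


def auth_headers(env=os.environ):
    """Headers sent with every REST call in this notebook (hypothetical helper)."""
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {env['access_token']}",
        "x-goog-user-project": env["project_id"],
    }


def endpoint(path, env=os.environ, version="v1beta3"):
    """Join base_url, the API version and a resource path like 'tunedModels'."""
    return f"{env['base_url'].rstrip('/')}/{version}/{path.lstrip('/')}"
```

For example, endpoint("tunedModels") returns the list URL used in the next curl call.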

List tuned models

Verify your authentication setup by listing the currently available tuned models.


curl -X GET ${base_url}/v1beta3/tunedModels \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" | grep name
"name": "tunedModels/testnumbergenerator-fvitocr834l6",
      "name": "tunedModels/my-display-name-81-9wpmc1m920vq",
      "displayName": "my display name 81",
      "name": "tunedModels/number-generator-model-kctlevca1g3q",
      "name": "tunedModels/my-display-name-81-r9wcuda14lyy",
      "displayName": "my display name 81",
      "name": "tunedModels/number-generator-model-w1eabln5adwp",
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 17583    0 17583    0     0  51600      0 --:--:-- --:--:-- --:--:-- 51563
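grep name is a quick filter, but it matches lines by text rather than by JSON field: the "displayName": "my display name 81" lines appear only because their value contains the word name. A JSON-aware sketch, assuming you save the response body first (for example with curl's -o models.json; the filename is an assumption); tuned_model_names is a hypothetical helper.

```python
import json


def tuned_model_names(response_text):
    """Return (name, displayName) pairs from a tunedModels list response body."""
    body = json.loads(response_text)
    return [
        (model["name"], model.get("displayName", ""))
        for model in body.get("tunedModels", [])
    ]
```

Usage: with open('models.json') as f: print(tuned_model_names(f.read())).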

Create tuned model

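To create a tuned model, you need to pass your dataset to the model in the training_data field.

For this example, you will tune a model to generate the next number in the sequence. For example, if the input is 1, the model should output 2. If the input is one hundred, the output should be one hundred one.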


curl -X POST ${base_url}/v1beta3/tunedModels \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '
      {
        "display_name": "number generator model",
        "base_model": "models/text-bison-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 2,
            "learning_rate": 0.001,
            "epoch_count":3
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    "text_input": "1",
                    "output": "2"
                },{
                    "text_input": "3",
                    "output": "4"
                },{
                    "text_input": "-3",
                    "output": "-2"
                },{
                    "text_input": "twenty two",
                    "output": "twenty three"
                },{
                    "text_input": "two hundred",
                    "output": "two hundred one"
                },{
                    "text_input": "ninety nine",
                    "output": "one hundred"
                },{
                    "text_input": "8",
                    "output": "9"
                },{
                    "text_input": "-98",
                    "output": "-97"
                },{
                    "text_input": "1,000",
                    "output": "1,001"
                },{
                    "text_input": "10,100,000",
                    "output": "10,100,001"
                },{
                    "text_input": "thirteen",
                    "output": "fourteen"
                },{
                    "text_input": "eighty",
                    "output": "eighty one"
                },{
                    "text_input": "one",
                    "output": "two"
                },{
                    "text_input": "three",
                    "output": "four"
                },{
                    "text_input": "seven",
                    "output": "eight"
                }
              ]
            }
          }
        }
      }' | tee tunemodel.json
{
  "name": "tunedModels/number-generator-model-q2d0uism5ivd/operations/xvyx09sjxlmh",
  "metadata": {
    "@type": "type.googleapis.com/google.ai.generativelanguage.v1beta3.CreateTunedModelMetadata",
    "totalSteps": 23,
    "tunedModel": "tunedModels/number-generator-model-q2d0uism5ivd"
  }
}
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2277    0   297  100  1980    146    975  0:00:02  0:00:02 --:--:--  1121
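Hand-editing this JSON is error-prone; strict JSON parsers reject trailing commas, for instance. A sketch that generates the same request body from (text_input, output) pairs with json.dumps; tuning_request_body is a hypothetical helper, and its defaults simply copy the values used in the request above.

```python
import json


def tuning_request_body(pairs, display_name, base_model="models/text-bison-001",
                        batch_size=2, learning_rate=0.001, epoch_count=3):
    """Build the tunedModels create body from (text_input, output) pairs.

    json.dumps always emits strictly valid JSON (no trailing commas).
    """
    return json.dumps({
        "display_name": display_name,
        "base_model": base_model,
        "tuning_task": {
            "hyperparameters": {
                "batch_size": batch_size,
                "learning_rate": learning_rate,
                "epoch_count": epoch_count,
            },
            "training_data": {
                "examples": {
                    "examples": [
                        {"text_input": text_input, "output": output}
                        for text_input, output in pairs
                    ]
                }
            },
        },
    })
```

The resulting string can be written to a file and passed to curl with -d @filename instead of the inline JSON.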

Get tuned model state

The model's state is set to CREATING during training and changes to ACTIVE once training is complete.

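Here is a bit of Python code to parse the generated model name out of the response JSON. If you're running this in a terminal, you can try a command-line JSON parser such as jq to parse the response.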

import json

with open('tunemodel.json') as f:
    first_page = json.load(f)
os.environ['modelname'] = first_page['metadata']['tunedModel']

print(os.environ['modelname'])
tunedModels/number-generator-model-q2d0uism5ivd

Make another GET request with the model name to get the model metadata, which includes the state field.


curl -X GET ${base_url}/v1beta3/${modelname} \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" | grep state
"state": "CREATING",
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   494    0   494    0     0    760      0 --:--:-- --:--:-- --:--:--   760
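grep state matches any line containing that text. A minimal sketch that reads the state field from the saved response body instead, assuming you add -o model.json to the curl command (the filename and the helper names model_state and is_ready are hypothetical). CREATING means training is still running; ACTIVE means the model is ready.

```python
import json


def model_state(response_text):
    """Extract the `state` field from a tunedModels GET response body."""
    return json.loads(response_text).get("state")


def is_ready(response_text):
    """True once tuning has finished (state moves from CREATING to ACTIVE)."""
    return model_state(response_text) == "ACTIVE"
```

Usage: with open('model.json') as f: print(is_ready(f.read())).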

Run inference

Once your tuning job is finished, you can use the tuned model to generate text with the text service.


curl -X POST ${base_url}/v1beta3/${modelname}:generateText \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '{
        "prompt": {
              "text": "4"
              },
        "temperature": 1.0,
        "candidate_count": 2}' | grep output
"output": "3 2 1",
      "output": "3 2",
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1569    0  1447  100   122    183     15  0:00:08  0:00:07  0:00:01   310
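The generateText body above has a prompt plus optional sampling fields. A sketch that builds it with json.dumps and leaves out temperature and candidate_count when they are not given; generate_text_body is a hypothetical helper and covers only the fields shown in this notebook.

```python
import json


def generate_text_body(prompt, temperature=None, candidate_count=None):
    """Build a generateText request body; optional fields are omitted when None."""
    body = {"prompt": {"text": prompt}}
    if temperature is not None:
        body["temperature"] = temperature
    if candidate_count is not None:
        body["candidate_count"] = candidate_count
    return json.dumps(body)
```

generate_text_body("4", temperature=1.0, candidate_count=2) reproduces the body sent above.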

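The model's output may or may not be correct. If the tuned model isn't performing up to your required standards, you can try adding more high-quality examples, tweaking the hyperparameters, or adding a preamble to your examples. You can even create another tuned model based on the first one you created.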

See the tuning guide for more guidance on improving performance.
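One of the suggestions above is adding a preamble to your examples. A minimal sketch of what that could look like on (text_input, output) pairs, assuming a fixed prefix string; add_preamble, with_preamble and the "input: " prefix are hypothetical. Presumably the same prefix should also be added to prompts at inference time so they match the tuning data.

```python
def add_preamble(pairs, preamble):
    """Return new (text_input, output) pairs with `preamble` prefixed to each input."""
    return [(preamble + text_input, output) for text_input, output in pairs]


def with_preamble(prompt, preamble):
    """Apply the same prefix to an inference prompt."""
    return preamble + prompt
```

The original pairs are left untouched, so you can compare a tuned model trained with and without the preamble.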

Call the REST API with Python requests

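You can call the REST API with any library that lets you send HTTP requests. The next set of examples uses the Python requests library and demonstrates some of the more advanced features.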
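To illustrate that any HTTP client works, here is a sketch that prepares the tuned model list call with only the standard library's urllib.request. build_list_request is a hypothetical helper; it builds the request without sending it.

```python
import urllib.request


def build_list_request(base_url, access_token, project):
    """Prepare (but do not send) the tunedModels list call using only the stdlib."""
    return urllib.request.Request(
        url=f"{base_url}/v1beta3/tunedModels",
        headers={
            "Authorization": "Bearer " + access_token,
            "Content-Type": "application/json",
            "x-goog-user-project": project,
        },
        method="GET",
    )
```

To send it: with urllib.request.urlopen(req) as resp: models = json.load(resp).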

Set variables

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

project = 'project-id'
base_url = "https://generativelanguage.googleapis.com"

Import the requests library.

import requests
import json

List tuned models

Verify your authentication setup by listing the currently available tuned models.

headers={
  'Authorization': 'Bearer ' + access_token,
  'Content-Type': 'application/json',
  'x-goog-user-project': project
}

result = requests.get(
  url=f'{base_url}/v1beta3/tunedModels',
  headers = headers,
)
result.json()
{'tunedModels': [{'name': 'tunedModels/testnumbergenerator-fvitocr834l6',
   'baseModel': 'models/text-bison-001',
   'displayName': 'test_number_generator',
   'description': '{"description":"generates the  next number in the sequence given the input text","exampleInput":"input: 1","exampleOutput":"output: 2","datasourceUrl":"https://drive.google.com/open?id=11Pdm6GNom4vlBMUHwO6yFjGQT3t1yi44WVShXMFnkVA&authuser=0&resourcekey=0-2d17tccbdBoThXMkNDvtag","showedTuningComplete":false}',
   'state': 'ACTIVE',
   'createTime': '2023-09-18T11:06:39.092786Z',
   'updateTime': '2023-09-18T11:07:24.198359Z',
   'tuningTask': {'startTime': '2023-09-18T11:06:39.461814784Z',
    'completeTime': '2023-09-18T11:07:24.198359Z',
    'snapshots': [{'step': 1,
      'meanLoss': 16.613504,
      'computeTime': '2023-09-18T11:06:44.532937624Z'},
     {'step': 2,
      'epoch': 1,
      'meanLoss': 20.299532,
      'computeTime': '2023-09-18T11:06:47.825134421Z'},
     {'step': 3,
      'epoch': 1,
      'meanLoss': 8.169708,
      'computeTime': '2023-09-18T11:06:50.580344344Z'},
     {'step': 4,
      'epoch': 2,
      'meanLoss': 3.7588992,
      'computeTime': '2023-09-18T11:06:53.219133748Z'},
     {'step': 5,
      'epoch': 3,
      'meanLoss': 2.0643115,
      'computeTime': '2023-09-18T11:06:55.828458606Z'},
     {'step': 6,
      'epoch': 3,
      'meanLoss': 1.9765375,
      'computeTime': '2023-09-18T11:06:58.426053772Z'},
     {'step': 7,
      'epoch': 4,
      'meanLoss': 0.9276156,
      'computeTime': '2023-09-18T11:07:01.231832398Z'},
     {'step': 8,
      'epoch': 5,
      'meanLoss': 1.8424839,
      'computeTime': '2023-09-18T11:07:03.822710074Z'},
     {'step': 9,
      'epoch': 5,
      'meanLoss': 1.1747926,
      'computeTime': '2023-09-18T11:07:06.441685551Z'},
     {'step': 10,
      'epoch': 6,
      'meanLoss': 0.3079359,
      'computeTime': '2023-09-18T11:07:08.793491157Z'},
     {'step': 11,
      'epoch': 7,
      'meanLoss': 0.543368,
      'computeTime': '2023-09-18T11:07:11.393264892Z'},
     {'step': 12,
      'epoch': 7,
      'meanLoss': 0.35068464,
      'computeTime': '2023-09-18T11:07:13.808021238Z'},
     {'step': 13,
      'epoch': 8,
      'meanLoss': 0.026032856,
      'computeTime': '2023-09-18T11:07:16.295972078Z'},
     {'step': 14,
      'epoch': 8,
      'meanLoss': 0.108341046,
      'computeTime': '2023-09-18T11:07:18.941247488Z'},
     {'step': 15,
      'epoch': 9,
      'meanLoss': 0.016470395,
      'computeTime': '2023-09-18T11:07:21.607654306Z'},
     {'step': 16,
      'epoch': 10,
      'meanLoss': 0.063049875,
      'computeTime': '2023-09-18T11:07:24.077271307Z'}],
    'hyperparameters': {'epochCount': 10,
     'batchSize': 16,
     'learningRate': 0.02} },
   'temperature': 0.7,
   'topP': 0.95,
   'topK': 40},
  {'name': 'tunedModels/my-display-name-81-9wpmc1m920vq',
   'baseModel': 'models/text-bison-tuning-test',
   'displayName': 'my display name 81',
   'state': 'ACTIVE',
   'createTime': '2023-09-18T22:02:08.690991Z',
   'updateTime': '2023-09-18T22:02:28.806318Z',
   'tuningTask': {'startTime': '2023-09-18T22:02:09.161100369Z',
    'completeTime': '2023-09-18T22:02:28.806318Z',
    'snapshots': [{'step': 1,
      'meanLoss': 7.2774773,
      'computeTime': '2023-09-18T22:02:12.453056368Z'},
     {'step': 2,
      'meanLoss': 6.1902447,
      'computeTime': '2023-09-18T22:02:13.789508217Z'},
     {'step': 3,
      'meanLoss': 5.5545835,
      'computeTime': '2023-09-18T22:02:15.136220505Z'},
     {'step': 4,
      'epoch': 1,
      'meanLoss': 7.9237704,
      'computeTime': '2023-09-18T22:02:16.474358517Z'},
     {'step': 5,
      'epoch': 1,
      'meanLoss': 7.6770706,
      'computeTime': '2023-09-18T22:02:17.758261108Z'},
     {'step': 6,
      'epoch': 1,
      'meanLoss': 7.378622,
      'computeTime': '2023-09-18T22:02:19.114072224Z'},
     {'step': 7,
      'epoch': 1,
      'meanLoss': 4.485537,
      'computeTime': '2023-09-18T22:02:20.927434115Z'},
     {'step': 8,
      'epoch': 2,
      'meanLoss': 6.815181,
      'computeTime': '2023-09-18T22:02:22.267906011Z'},
     {'step': 9,
      'epoch': 2,
      'meanLoss': 6.411363,
      'computeTime': '2023-09-18T22:02:24.078114085Z'},
     {'step': 10,
      'epoch': 2,
      'meanLoss': 8.585093,
      'computeTime': '2023-09-18T22:02:25.441598938Z'},
     {'step': 11,
      'epoch': 2,
      'meanLoss': 4.901249,
      'computeTime': '2023-09-18T22:02:27.108985392Z'},
     {'step': 12,
      'epoch': 3,
      'meanLoss': 7.073003,
      'computeTime': '2023-09-18T22:02:28.441662034Z'}],
    'hyperparameters': {'epochCount': 3,
     'batchSize': 4,
     'learningRate': 0.001} },
   'temperature': 0.7,
   'topP': 0.95,
   'topK': 40},
  {'name': 'tunedModels/number-generator-model-kctlevca1g3q',
   'baseModel': 'models/text-bison-tuning-test',
   'displayName': 'number generator model',
   'state': 'ACTIVE',
   'createTime': '2023-09-18T23:43:21.461545Z',
   'updateTime': '2023-09-18T23:43:49.205493Z',
   'tuningTask': {'startTime': '2023-09-18T23:43:21.542403958Z',
    'completeTime': '2023-09-18T23:43:49.205493Z',
    'snapshots': [{'step': 1,
      'meanLoss': 7.342065,
      'computeTime': '2023-09-18T23:43:23.356271969Z'},
     {'step': 2,
      'meanLoss': 7.255807,
      'computeTime': '2023-09-18T23:43:24.620248223Z'},
     {'step': 3,
      'meanLoss': 5.4591417,
      'computeTime': '2023-09-18T23:43:25.854505395Z'},
     {'step': 4,
      'meanLoss': 6.968665,
      'computeTime': '2023-09-18T23:43:27.138260198Z'},
     {'step': 5,
      'meanLoss': 4.578809,
      'computeTime': '2023-09-18T23:43:28.404943274Z'},
     {'step': 6,
      'meanLoss': 6.4862137,
      'computeTime': '2023-09-18T23:43:29.631624883Z'},
     {'step': 7,
      'meanLoss': 9.781939,
      'computeTime': '2023-09-18T23:43:30.801341449Z'},
     {'step': 8,
      'epoch': 1,
      'meanLoss': 5.990006,
      'computeTime': '2023-09-18T23:43:31.854703315Z'},
     {'step': 9,
      'epoch': 1,
      'meanLoss': 8.846312,
      'computeTime': '2023-09-18T23:43:33.075785103Z'},
     {'step': 10,
      'epoch': 1,
      'meanLoss': 6.1585655,
      'computeTime': '2023-09-18T23:43:34.310432174Z'},
     {'step': 11,
      'epoch': 1,
      'meanLoss': 4.7877502,
      'computeTime': '2023-09-18T23:43:35.381582526Z'},
     {'step': 12,
      'epoch': 1,
      'meanLoss': 9.660514,
      'computeTime': '2023-09-18T23:43:36.445446408Z'},
     {'step': 13,
      'epoch': 1,
      'meanLoss': 5.6482882,
      'computeTime': '2023-09-18T23:43:37.603237821Z'},
     {'step': 14,
      'epoch': 1,
      'meanLoss': 3.162092,
      'computeTime': '2023-09-18T23:43:38.671463397Z'},
     {'step': 15,
      'epoch': 2,
      'meanLoss': 6.322996,
      'computeTime': '2023-09-18T23:43:39.769742201Z'},
     {'step': 16,
      'epoch': 2,
      'meanLoss': 6.781,
      'computeTime': '2023-09-18T23:43:40.985967994Z'},
     {'step': 17,
      'epoch': 2,
      'meanLoss': 5.136773,
      'computeTime': '2023-09-18T23:43:42.235469710Z'},
     {'step': 18,
      'epoch': 2,
      'meanLoss': 7.2091155,
      'computeTime': '2023-09-18T23:43:43.415178581Z'},
     {'step': 19,
      'epoch': 2,
      'meanLoss': 7.7508755,
      'computeTime': '2023-09-18T23:43:44.775221774Z'},
     {'step': 20,
      'epoch': 2,
      'meanLoss': 8.144815,
      'computeTime': '2023-09-18T23:43:45.788824334Z'},
     {'step': 21,
      'epoch': 2,
      'meanLoss': 5.485137,
      'computeTime': '2023-09-18T23:43:46.812663998Z'},
     {'step': 22,
      'epoch': 2,
      'meanLoss': 3.709197,
      'computeTime': '2023-09-18T23:43:47.971764087Z'},
     {'step': 23,
      'epoch': 3,
      'meanLoss': 6.0069466,
      'computeTime': '2023-09-18T23:43:49.004191079Z'}],
    'hyperparameters': {'epochCount': 3,
     'batchSize': 2,
     'learningRate': 0.001} },
   'temperature': 0.7,
   'topP': 0.95,
   'topK': 40},
  {'name': 'tunedModels/my-display-name-81-r9wcuda14lyy',
   'baseModel': 'models/text-bison-tuning-test',
   'displayName': 'my display name 81',
   'state': 'ACTIVE',
   'createTime': '2023-09-18T23:52:06.980185Z',
   'updateTime': '2023-09-18T23:52:26.679601Z',
   'tuningTask': {'startTime': '2023-09-18T23:52:07.616953503Z',
    'completeTime': '2023-09-18T23:52:26.679601Z',
    'snapshots': [{'step': 1,
      'meanLoss': 7.2774773,
      'computeTime': '2023-09-18T23:52:10.278936662Z'},
     {'step': 2,
      'meanLoss': 6.2793097,
      'computeTime': '2023-09-18T23:52:11.630844790Z'},
     {'step': 3,
      'meanLoss': 5.540499,
      'computeTime': '2023-09-18T23:52:13.027840389Z'},
     {'step': 4,
      'epoch': 1,
      'meanLoss': 7.977523,
      'computeTime': '2023-09-18T23:52:14.368199020Z'},
     {'step': 5,
      'epoch': 1,
      'meanLoss': 7.6197805,
      'computeTime': '2023-09-18T23:52:15.872428752Z'},
     {'step': 6,
      'epoch': 1,
      'meanLoss': 7.3851357,
      'computeTime': '2023-09-18T23:52:17.213094182Z'},
     {'step': 7,
      'epoch': 1,
      'meanLoss': 4.5342345,
      'computeTime': '2023-09-18T23:52:19.090698421Z'},
     {'step': 8,
      'epoch': 2,
      'meanLoss': 6.8603754,
      'computeTime': '2023-09-18T23:52:20.494844731Z'},
     {'step': 9,
      'epoch': 2,
      'meanLoss': 6.418575,
      'computeTime': '2023-09-18T23:52:21.815997555Z'},
     {'step': 10,
      'epoch': 2,
      'meanLoss': 8.659064,
      'computeTime': '2023-09-18T23:52:23.524287192Z'},
     {'step': 11,
      'epoch': 2,
      'meanLoss': 4.856765,
      'computeTime': '2023-09-18T23:52:24.864661291Z'},
     {'step': 12,
      'epoch': 3,
      'meanLoss': 7.1078596,
      'computeTime': '2023-09-18T23:52:26.225055381Z'}],
    'hyperparameters': {'epochCount': 3,
     'batchSize': 4,
     'learningRate': 0.001} },
   'temperature': 0.7,
   'topP': 0.95,
   'topK': 40},
  {'name': 'tunedModels/number-generator-model-w1eabln5adwp',
   'baseModel': 'models/text-bison-tuning-test',
   'displayName': 'number generator model',
   'state': 'ACTIVE',
   'createTime': '2023-09-19T19:29:08.622497Z',
   'updateTime': '2023-09-19T19:29:46.063853Z',
   'tuningTask': {'startTime': '2023-09-19T19:29:08.806930486Z',
    'completeTime': '2023-09-19T19:29:46.063853Z',
    'snapshots': [{'step': 1,
      'meanLoss': 7.342065,
      'computeTime': '2023-09-19T19:29:13.023811994Z'},
     {'step': 2,
      'meanLoss': 7.1960244,
      'computeTime': '2023-09-19T19:29:14.844046282Z'},
     {'step': 3,
      'meanLoss': 5.480289,
      'computeTime': '2023-09-19T19:29:16.596884354Z'},
     {'step': 4,
      'meanLoss': 6.851822,
      'computeTime': '2023-09-19T19:29:17.741735378Z'},
     {'step': 5,
      'meanLoss': 4.5535283,
      'computeTime': '2023-09-19T19:29:18.914760812Z'},
     {'step': 6,
      'meanLoss': 6.449012,
      'computeTime': '2023-09-19T19:29:20.053316042Z'},
     {'step': 7,
      'meanLoss': 9.842458,
      'computeTime': '2023-09-19T19:29:21.371286675Z'},
     {'step': 8,
      'epoch': 1,
      'meanLoss': 5.9831877,
      'computeTime': '2023-09-19T19:29:22.915277044Z'},
     {'step': 9,
      'epoch': 1,
      'meanLoss': 8.936815,
      'computeTime': '2023-09-19T19:29:24.666461680Z'},
     {'step': 10,
      'epoch': 1,
      'meanLoss': 6.14651,
      'computeTime': '2023-09-19T19:29:26.793310451Z'},
     {'step': 11,
      'epoch': 1,
      'meanLoss': 4.853589,
      'computeTime': '2023-09-19T19:29:28.328297535Z'},
     {'step': 12,
      'epoch': 1,
      'meanLoss': 9.6831045,
      'computeTime': '2023-09-19T19:29:29.501236840Z'},
     {'step': 13,
      'epoch': 1,
      'meanLoss': 5.706586,
      'computeTime': '2023-09-19T19:29:30.612807978Z'},
     {'step': 14,
      'epoch': 1,
      'meanLoss': 3.276942,
      'computeTime': '2023-09-19T19:29:31.928747103Z'},
     {'step': 15,
      'epoch': 2,
      'meanLoss': 6.1736736,
      'computeTime': '2023-09-19T19:29:33.588699180Z'},
     {'step': 16,
      'epoch': 2,
      'meanLoss': 6.857398,
      'computeTime': '2023-09-19T19:29:35.239083809Z'},
     {'step': 17,
      'epoch': 2,
      'meanLoss': 5.098094,
      'computeTime': '2023-09-19T19:29:37.000705047Z'},
     {'step': 18,
      'epoch': 2,
      'meanLoss': 7.27724,
      'computeTime': '2023-09-19T19:29:38.532313231Z'},
     {'step': 19,
      'epoch': 2,
      'meanLoss': 7.6310735,
      'computeTime': '2023-09-19T19:29:39.696034301Z'},
     {'step': 20,
      'epoch': 2,
      'meanLoss': 8.152623,
      'computeTime': '2023-09-19T19:29:40.803342042Z'},
     {'step': 21,
      'epoch': 2,
      'meanLoss': 5.451577,
      'computeTime': '2023-09-19T19:29:42.445788199Z'},
     {'step': 22,
      'epoch': 2,
      'meanLoss': 3.7990716,
      'computeTime': '2023-09-19T19:29:43.866737307Z'},
     {'step': 23,
      'epoch': 3,
      'meanLoss': 6.120624,
      'computeTime': '2023-09-19T19:29:45.599248553Z'}],
    'hyperparameters': {'epochCount': 3,
     'batchSize': 2,
     'learningRate': 0.001} },
   'temperature': 0.7,
   'topP': 0.95,
   'topK': 40}]}
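The full listing is long. A sketch that condenses result.json() to one row per model with its state, last recorded step and last meanLoss; summarize_tuned_models is a hypothetical helper, and models that have not recorded any snapshots yet report None.

```python
def summarize_tuned_models(list_response):
    """Yield (name, state, last_step, last_mean_loss) for each tuned model."""
    for model in list_response.get("tunedModels", []):
        snapshots = model.get("tuningTask", {}).get("snapshots", [])
        last = snapshots[-1] if snapshots else {}
        yield (
            model["name"],
            model.get("state"),
            last.get("step"),
            last.get("meanLoss"),
        )
```

Usage: for row in summarize_tuned_models(result.json()): print(row).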

Create tuned model

As in the curl example, you pass in the dataset through the training_data field.

operation = requests.post(
    url = f'{base_url}/v1beta3/tunedModels',
    headers=headers,
    json= {
        "display_name": "number generator",
        "base_model": "models/text-bison-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 4,
            "learning_rate": 0.001,
            "epoch_count":3,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    'text_input': '1',
                    'output': '2',
                },{
                    'text_input': '3',
                    'output': '4',
                },{
                    'text_input': '-3',
                    'output': '-2',
                },{
                    'text_input': 'twenty two',
                    'output': 'twenty three',
                },{
                    'text_input': 'two hundred',
                    'output': 'two hundred one',
                },{
                    'text_input': 'ninety nine',
                    'output': 'one hundred',
                },{
                    'text_input': '8',
                    'output': '9',
                },{
                    'text_input': '-98',
                    'output': '-97',
                },{
                    'text_input': '1,000',
                    'output': '1,001',
                },{
                    'text_input': '10,100,000',
                    'output': '10,100,001',
                },{
                    'text_input': 'thirteen',
                    'output': 'fourteen',
                },{
                    'text_input': 'eighty',
                    'output': 'eighty one',
                },{
                    'text_input': 'one',
                    'output': 'two',
                },{
                    'text_input': 'three',
                    'output': 'four',
                },{
                    'text_input': 'seven',
                    'output': 'eight',
                }
              ]
            }
          }
        }
      }
)
operation
<Response [200]>
operation.json()
{'name': 'tunedModels/number-generator-ncqqnysl74dt/operations/qqlbwzfyzn0k',
 'metadata': {'@type': 'type.googleapis.com/google.ai.generativelanguage.v1beta3.CreateTunedModelMetadata',
  'totalSteps': 12,
  'tunedModel': 'tunedModels/number-generator-ncqqnysl74dt'} }
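The totalSteps values in this metadata (12, with batch_size 4) and in the earlier curl response (23, with batch_size 2) are both consistent with ceil(number of examples × epoch_count / batch_size) for the 15 training examples. A sketch of that estimate; the formula is inferred from these outputs, not taken from API documentation, and expected_total_steps is a hypothetical name.

```python
def expected_total_steps(num_examples, batch_size, epoch_count):
    """Estimate tuning steps as ceil(num_examples * epoch_count / batch_size).

    Inferred from this notebook's outputs; not a documented API guarantee.
    """
    # Integer ceiling division avoids floating-point rounding.
    return (num_examples * epoch_count + batch_size - 1) // batch_size
```

For the request above, expected_total_steps(15, 4, 3) gives 12.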

Set a variable with the name of the tuned model to use for the rest of the calls.

name=operation.json()["metadata"]["tunedModel"]
name
'tunedModels/number-generator-ncqqnysl74dt'
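The tuned model name is also the prefix of the operation name returned above (tunedModels/MODEL/operations/ID). A sketch that splits the two apart, which can serve as a consistency check against metadata.tunedModel; split_operation_name is a hypothetical helper that assumes the name format shown in this notebook's outputs.

```python
def split_operation_name(operation_name):
    """Split 'tunedModels/MODEL/operations/ID' into (tuned model name, operation id)."""
    model_name, sep, op_id = operation_name.partition("/operations/")
    if not sep or not op_id or not model_name.startswith("tunedModels/"):
        raise ValueError(f"unexpected operation name: {operation_name!r}")
    return model_name, op_id
```

Usage: split_operation_name(operation.json()["name"])[0] == name should be True.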

Get tuned model state

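To check the progress of your tuning job, look at the state field. CREATING means the tuning job is still in progress, and ACTIVE means training is complete and the tuned model is ready to use.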

tuned_model = requests.get(
    url = f'{base_url}/v1beta3/{name}',
    headers=headers,
)
tuned_model.json()
{'name': 'tunedModels/number-generator-ncqqnysl74dt',
 'baseModel': 'models/text-bison-001',
 'displayName': 'number generator',
 'state': 'CREATING',
 'createTime': '2023-09-19T19:56:25.999303Z',
 'updateTime': '2023-09-19T19:56:25.999303Z',
 'tuningTask': {'startTime': '2023-09-19T19:56:26.297862545Z',
  'hyperparameters': {'epochCount': 3, 'batchSize': 4, 'learningRate': 0.001} },
 'temperature': 0.7,
 'topP': 0.95,
 'topK': 40}

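The code below polls the tuning operation every 31 seconds until the operation returns either a response or an error, printing the completion percentage and the latest training snapshot along the way. The sample output was captured from a longer run (up to epoch 50) than the 3-epoch job created above.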

import time
import pprint

op_json = operation.json()
response = op_json.get('response')
error = op_json.get('error')

while response is None and error is None:
    time.sleep(31)

    operation = requests.get(
        url = f'{base_url}/v1/{op_json["name"]}',
        headers=headers,
    )

    op_json = operation.json()
    response = op_json.get('response')
    error = op_json.get('error')

    percent = op_json['metadata'].get('completedPercent')
    if percent is not None:
      print(f"{percent:.2f}% - {op_json['metadata']['snapshots'][-1]}")
      print()

if error is not None:
    raise Exception(error)
21.28% - {'step': 40, 'epoch': 10, 'meanLoss': 2.4871845, 'computeTime': '2023-09-20T00:23:55.255785843Z'}

21.28% - {'step': 40, 'epoch': 10, 'meanLoss': 2.4871845, 'computeTime': '2023-09-20T00:23:55.255785843Z'}

43.09% - {'step': 81, 'epoch': 21, 'meanLoss': 0.032220088, 'computeTime': '2023-09-20T00:24:56.302837803Z'}

43.09% - {'step': 81, 'epoch': 21, 'meanLoss': 0.032220088, 'computeTime': '2023-09-20T00:24:56.302837803Z'}

63.83% - {'step': 120, 'epoch': 32, 'meanLoss': 0.0030430648, 'computeTime': '2023-09-20T00:25:57.228615435Z'}

63.83% - {'step': 120, 'epoch': 32, 'meanLoss': 0.0030430648, 'computeTime': '2023-09-20T00:25:57.228615435Z'}

85.11% - {'step': 160, 'epoch': 42, 'meanLoss': -1.1145603e-06, 'computeTime': '2023-09-20T00:26:57.819011896Z'}

100.00% - {'step': 188, 'epoch': 50, 'meanLoss': 0.00040101097, 'computeTime': '2023-09-20T00:27:40.024132813Z'}
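The loop above keeps polling for as long as the operation is unfinished. A sketch of the same logic with an optional timeout, plus injectable fetch and sleep functions so it can be exercised without network access; wait_for_operation is a hypothetical helper.

```python
import time


def wait_for_operation(fetch, interval=31.0, timeout=None, sleep=time.sleep,
                       on_progress=print):
    """Poll a long-running operation until it has a `response` or an `error`.

    `fetch` is any zero-argument callable returning the operation JSON as a dict.
    Returns the finished operation; raises RuntimeError if it reports an error
    and TimeoutError if `timeout` seconds pass first.
    """
    start = time.monotonic()
    op = fetch()
    while op.get("response") is None and op.get("error") is None:
        if timeout is not None and time.monotonic() - start > timeout:
            raise TimeoutError("operation did not finish in time")
        sleep(interval)
        op = fetch()
        percent = op.get("metadata", {}).get("completedPercent")
        if percent is not None:
            on_progress(f"{percent:.2f}%")
    if op.get("error") is not None:
        raise RuntimeError(op["error"])
    return op
```

With the same URL the loop above uses: wait_for_operation(lambda: requests.get(f'{base_url}/v1/{op_json["name"]}', headers=headers).json(), timeout=3600).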

Run inference

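Once the tuning job is finished, you can use the tuned model to generate text in the same way you would use the base text model.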

m = requests.post(
    url = f'{base_url}/v1beta3/{name}:generateText',
    headers=headers,
    json= {
         "prompt": {
              "text": "9"
              },
    })
print(m.json()['candidates'][0]['output'])
9
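Indexing m.json()['candidates'][0] raises an error if a response comes back without candidates. A slightly more defensive sketch that collects every candidate's output and returns an empty list instead; candidate_outputs is a hypothetical helper.

```python
def candidate_outputs(response_json):
    """Return every candidate's output text from a parsed generateText response.

    Returns an empty list instead of raising when no candidates came back.
    """
    return [c.get("output", "") for c in response_json.get("candidates", [])]
```

Usage: outputs = candidate_outputs(m.json()); print(outputs[0] if outputs else 'no candidates returned').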

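The model's output may or may not be correct. If the tuned model isn't performing up to your required standards, you can try adding more high-quality examples, tweaking the hyperparameters, or adding a preamble to your examples.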

Next steps