REST API: Modellabstimmung

Auf ai.google.dev ansehen In Google Colab ausführen Quelle auf GitHub ansehen

In diesem Notebook erfahren Sie, wie Sie den Gemini API-Abstimmungsdienst mithilfe von curl-Befehlen oder der Python Request API zum Aufrufen der Gemini API verwenden. Hier erfahren Sie, wie Sie das Textmodell hinter dem Textgenerierungsdienst der Gemini API abstimmen.

Einrichtung

Authentifizieren

Mit der Gemini API können Sie Modelle anhand Ihrer eigenen Daten abstimmen. Da es sich um Ihre Daten und Ihre abgestimmten Modelle handelt, ist eine strengere Zugriffssteuerung erforderlich als API-Schlüssel.

Bevor Sie diese Anleitung ausführen können, müssen Sie OAuth für Ihr Projekt einrichten.

Am einfachsten lässt sich die Einrichtung in Colab einrichten, indem Sie den Inhalt Ihrer client_secret.json-Datei in das Feld „Secrets Manager“ von Colab mit dem Secret-Namen CLIENT_SECRET kopieren (unter dem Schlüsselsymbol im linken Bereich).

Dieser gcloud-Befehl wandelt die Datei client_secret.json in Anmeldedaten um, die zur Authentifizierung beim Dienst verwendet werden können.

try:
  from google.colab import userdata
  import pathlib
  pathlib.Path('client_secret.json').write_text(userdata.get('CLIENT_SECRET'))

  # Use `--no-browser` in colab
  !gcloud auth application-default login --no-browser --client-id-file client_secret.json --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'
except ImportError:
  !gcloud auth application-default login --client-id-file client_secret.json --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'
You are authorizing client libraries without access to a web browser. Please run the following command on a machine with a web browser and copy its output back here. Make sure the installed gcloud version is 372.0.0 or newer.

gcloud auth application-default login --remote-bootstrap="https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=87071151422-n1a3cb6c7fvkfg4gmhdtmn5ulol2l4be.apps.googleusercontent.com&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fgenerative-language.tuning&state=QIyNibWSaTIsozjmvZEkVBo6EcoW0G&access_type=offline&code_challenge=76c1ZiGvKN8cvlYfj3BmbCwE4e7tvrlwaX3REUX25gY&code_challenge_method=S256&token_usage=remote"


Enter the output of the above command: https://localhost:8085/?state=QIyNibWSaTIsozjmvZEkVBo6EcoW0G&code=4/0AeaYSHBKrY911S466QjKQIFODoOPXlO1mWyTYYdrbELIDV6Hw2DKRAyro62BugroSvIWsA&scope=https://www.googleapis.com/auth/cloud-platform%20https://www.googleapis.com/auth/generative-language.tuning

Credentials saved to file: [/content/.config/application_default_credentials.json]

These credentials will be used by any library that requests Application Default Credentials (ADC).

REST API mit CURL aufrufen

Dieser Abschnitt enthält curl-Beispielanweisungen zum Aufrufen der REST API. Sie erfahren, wie Sie einen Abstimmungsjob erstellen, seinen Status prüfen und anschließend einen Inferenzaufruf durchführen.

Variablen festlegen

Legen Sie Variablen für wiederkehrende Werte fest, die für die restlichen REST API-Aufrufe verwendet werden. Der Code legt mithilfe der Python-Bibliothek os Umgebungsvariablen fest, auf die in allen Codezellen zugegriffen werden kann.

Dies gilt speziell für die Colab-Notebook-Umgebung. Der Code in der nächsten Codezelle entspricht der Ausführung der folgenden Befehle in einem Bash-Terminal.

export access_token=$(gcloud auth application-default print-access-token)
export project_id=my-project-id
export base_url=https://generativelanguage.googleapis.com
import os

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

os.environ['access_token'] = access_token
os.environ['project_id'] = "[Enter your project-id here]"
os.environ['base_url'] = "https://generativelanguage.googleapis.com"

Abgestimmte Modelle auflisten

Überprüfen Sie die Einrichtung der Authentifizierung, indem Sie die aktuell verfügbaren abgestimmten Modelle auflisten.


curl -X GET ${base_url}/v1beta/tunedModels \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}"

Abgestimmtes Modell erstellen

Zum Erstellen eines abgestimmten Modells müssen Sie Ihr Dataset im Feld training_data an das Modell übergeben.

In diesem Beispiel stimmen Sie ein Modell so ab, dass die nächste Zahl in der Sequenz generiert wird. Wenn die Eingabe beispielsweise 1 ist, sollte das Modell 2 ausgeben. Wenn die Eingabe one hundred ist, sollte die Ausgabe one hundred one sein.


curl -X POST $base_url/v1beta/tunedModels \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '
      {
        "display_name": "number generator model",
        "base_model": "models/gemini-1.0-pro-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 2,
            "learning_rate": 0.001,
            "epoch_count":5,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    "text_input": "1",
                    "output": "2",
                },{
                    "text_input": "3",
                    "output": "4",
                },{
                    "text_input": "-3",
                    "output": "-2",
                },{
                    "text_input": "twenty two",
                    "output": "twenty three",
                },{
                    "text_input": "two hundred",
                    "output": "two hundred one",
                },{
                    "text_input": "ninety nine",
                    "output": "one hundred",
                },{
                    "text_input": "8",
                    "output": "9",
                },{
                    "text_input": "-98",
                    "output": "-97",
                },{
                    "text_input": "1,000",
                    "output": "1,001",
                },{
                    "text_input": "10,100,000",
                    "output": "10,100,001",
                },{
                    "text_input": "thirteen",
                    "output": "fourteen",
                },{
                    "text_input": "eighty",
                    "output": "eighty one",
                },{
                    "text_input": "one",
                    "output": "two",
                },{
                    "text_input": "three",
                    "output": "four",
                },{
                    "text_input": "seven",
                    "output": "eight",
                }
              ]
            }
          }
        }
      }' | tee tunemodel.json
{
  "name": "tunedModels/number-generator-model-dzlmi0gswwqb/operations/bvl8dymw0fhw",
  "metadata": {
    "@type": "type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata",
    "totalSteps": 38,
    "tunedModel": "tunedModels/number-generator-model-dzlmi0gswwqb"
  }
}
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2280    0   296  100  1984    611   4098 --:--:-- --:--:-- --:--:--  4720

Status des abgestimmten Modells abrufen

Der Status des Modells wird während des Trainings auf CREATING gesetzt und ändert sich nach Abschluss in ACTIVE.

Im Folgenden finden Sie Python-Code, mit dem der generierte Modellname aus der JSON-Antwort geparst wird. Wenn Sie die Abfrage in einem Terminal ausführen, können Sie versuchen, die Antwort mit einem Bash-JSON-Parser zu parsen.

import json

first_page = json.load(open('tunemodel.json'))
os.environ['modelname'] = first_page['metadata']['tunedModel']

print(os.environ['modelname'])
tunedModels/number-generator-model-dzlmi0gswwqb

Stellen Sie eine weitere GET-Anfrage mit dem Modellnamen, um die Modellmetadaten abzurufen, die das Statusfeld enthalten.


curl -X GET ${base_url}/v1beta/${modelname} \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" | grep state
"state": "ACTIVE",
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5921    0  5921    0     0  13164      0 --:--:-- --:--:-- --:--:-- 13157

Inferenz ausführen

Sobald der Abstimmungsjob abgeschlossen ist, können Sie ihn zum Generieren von Text mit dem Textdienst verwenden. Versuchen Sie, eine römische Zahl einzugeben, z. B. 63 (LXIII)


curl -X POST $base_url/v1beta/$modelname:generateContent \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '{
        "contents": [{
        "parts": [{
          "text": "LXIII"
          }]
        }]
        }' 2> /dev/null
{
  "candidates": [
    {
      "content": {
        "parts": [
          {
            "text": "LXIV"
          }
        ],
        "role": "model"
      },
      "finishReason": "STOP",
      "index": 0,
      "safetyRatings": [
        {
          "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
          "probability": "NEGLIGIBLE"
        },
        {
          "category": "HARM_CATEGORY_HATE_SPEECH",
          "probability": "NEGLIGIBLE"
        },
        {
          "category": "HARM_CATEGORY_HARASSMENT",
          "probability": "NEGLIGIBLE"
        },
        {
          "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
          "probability": "NEGLIGIBLE"
        }
      ]
    }
  ],
  "promptFeedback": {
    "safetyRatings": [
      {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_HARASSMENT",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "probability": "NEGLIGIBLE"
      }
    ]
  }
}

Die Ausgabe des Modells ist möglicherweise nicht korrekt. Wenn das abgestimmte Modell nicht Ihren Anforderungen entspricht, können Sie versuchen, mehr hochwertige Beispiele hinzuzufügen, die Hyperparameter zu optimieren oder Ihren Beispielen eine Präambel hinzuzufügen. Sie können sogar ein weiteres abgestimmtes Modell erstellen, das auf dem ersten abgestimmten Modell basiert.

Weitere Tipps zur Verbesserung der Leistung finden Sie im Leitfaden zur Abstimmung.

REST API mit Python-Anfragen aufrufen

Sie können die restliche API mit jeder Bibliothek aufrufen, mit der Sie HTTP-Anfragen senden können. In den nächsten Beispielen wird die Python-Anfragebibliothek verwendet und einige der fortgeschritteneren Funktionen veranschaulicht.

Variablen festlegen

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

project = '[Enter your project-id here]'
base_url = "https://generativelanguage.googleapis.com"

Importieren Sie die requests-Bibliothek.

import requests
import json

Abgestimmte Modelle auflisten

Überprüfen Sie die Einrichtung der Authentifizierung, indem Sie die aktuell verfügbaren abgestimmten Modelle auflisten.

headers={
  'Authorization': 'Bearer ' + access_token,
  'Content-Type': 'application/json',
  'x-goog-user-project': project
}

result = requests.get(
  url=f'{base_url}/v1beta/tunedModels',
  headers = headers,
)
result.json()

Abgestimmtes Modell erstellen

Wie im Curl-Beispiel übergeben Sie das Dataset über das Feld training_data.

operation = requests.post(
    url = f'{base_url}/v1beta/tunedModels',
    headers=headers,
    json= {
        "display_name": "number generator",
        "base_model": "models/gemini-1.0-pro-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 4,
            "learning_rate": 0.001,
            "epoch_count":5,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    'text_input': '1',
                    'output': '2',
                },{
                    'text_input': '3',
                    'output': '4',
                },{
                    'text_input': '-3',
                    'output': '-2',
                },{
                    'text_input': 'twenty two',
                    'output': 'twenty three',
                },{
                    'text_input': 'two hundred',
                    'output': 'two hundred one',
                },{
                    'text_input': 'ninety nine',
                    'output': 'one hundred',
                },{
                    'text_input': '8',
                    'output': '9',
                },{
                    'text_input': '-98',
                    'output': '-97',
                },{
                    'text_input': '1,000',
                    'output': '1,001',
                },{
                    'text_input': '10,100,000',
                    'output': '10,100,001',
                },{
                    'text_input': 'thirteen',
                    'output': 'fourteen',
                },{
                    'text_input': 'eighty',
                    'output': 'eighty one',
                },{
                    'text_input': 'one',
                    'output': 'two',
                },{
                    'text_input': 'three',
                    'output': 'four',
                },{
                    'text_input': 'seven',
                    'output': 'eight',
                }
              ]
            }
          }
        }
      }
)
operation
<Response [200]>
operation.json()
{'name': 'tunedModels/number-generator-wl1qr34x2py/operations/41vni3zk0a47',
 'metadata': {'@type': 'type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata',
  'totalSteps': 19,
  'tunedModel': 'tunedModels/number-generator-wl1qr34x2py'} }

Legen Sie eine Variable mit dem Namen Ihres abgestimmten Modells fest, die für die restlichen Aufrufe verwendet werden soll.

name=operation.json()["metadata"]["tunedModel"]
name
'tunedModels/number-generator-wl1qr34x2py'

Status des abgestimmten Modells abrufen

Sie können den Fortschritt des Abstimmungsjobs im Statusfeld prüfen. CREATING bedeutet, dass der Abstimmungsjob noch läuft, und ACTIVE bedeutet, dass die Trainings abgeschlossen sind und das abgestimmte Modell einsatzbereit ist.

tuned_model = requests.get(
    url = f'{base_url}/v1beta/{name}',
    headers=headers,
)
tuned_model.json()

Mit dem folgenden Code wird das Statusfeld alle 5 Sekunden überprüft, bis es nicht mehr den Status CREATING hat.

import time
import pprint

op_json = operation.json()
response = op_json.get('response')
error = op_json.get('error')

while response is None and error is None:
    time.sleep(5)

    operation = requests.get(
        url = f'{base_url}/v1/{op_json["name"]}',
        headers=headers,
    )

    op_json = operation.json()
    response = op_json.get('response')
    error = op_json.get('error')

    percent = op_json['metadata'].get('completedPercent')
    if percent is not None:
      print(f"{percent:.2f}% - {op_json['metadata']['snapshots'][-1]}")
      print()

if error is not None:
    raise Exception(error)
100.00% - {'step': 19, 'epoch': 5, 'meanLoss': 1.402067, 'computeTime': '2024-03-14T15:11:23.766989274Z'}

Inferenz ausführen

Sobald der Abstimmungsjob abgeschlossen ist, können Sie damit Text auf die gleiche Weise generieren wie mit dem Basistextmodell. Versuchen Sie, eine japanische Zahl einzugeben, z. B. 6 (六).

import time

m = requests.post(
    url = f'{base_url}/v1beta/{name}:generateContent',
    headers=headers,
    json= {
         "contents": [{
             "parts": [{
                 "text": "六"
             }]
          }]
    })
import pprint
pprint.pprint(m.json())
{'candidates': [{'content': {'parts': [{'text': '七'}], 'role': 'model'},
                 'finishReason': 'STOP',
                 'index': 0,
                 'safetyRatings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                                    'probability': 'NEGLIGIBLE'},
                                   {'category': 'HARM_CATEGORY_HATE_SPEECH',
                                    'probability': 'NEGLIGIBLE'},
                                   {'category': 'HARM_CATEGORY_HARASSMENT',
                                    'probability': 'LOW'},
                                   {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                                    'probability': 'NEGLIGIBLE'}]}],
 'promptFeedback': {'safetyRatings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                                       'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_HATE_SPEECH',
                                       'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_HARASSMENT',
                                       'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                                       'probability': 'NEGLIGIBLE'}]} }

Die Ausgabe des Modells ist möglicherweise nicht korrekt. Wenn das abgestimmte Modell nicht Ihren Anforderungen entspricht, können Sie versuchen, mehr hochwertige Beispiele hinzuzufügen, die Hyperparameter zu optimieren oder Ihren Beispielen eine Präambel hinzuzufügen.

Fazit

Obwohl die Trainingsdaten keinen Verweis auf römische oder japanische Ziffern enthielten, konnte das Modell lange nach der Feinabstimmung verallgemeinern. Auf diese Weise können Sie Modelle an Ihre Anwendungsfälle anpassen.

Nächste Schritte

Informationen zur Verwendung des Abstimmungsdienstes mithilfe des Python SDK für die Gemini API finden Sie in der Kurzanleitung zur Abstimmung mit Python. Informationen zur Verwendung anderer Dienste in der Gemini API finden Sie in der Python-Kurzanleitung.