Tutorial fine-tuning

Lihat di ai.google.dev Berjalan di Google Colab Lihat sumber di GitHub

Di notebook ini, Anda akan mempelajari cara mulai menggunakan layanan penyesuaian Gemini API menggunakan perintah CURL atau Python request API untuk memanggil Gemini API. Di sini, Anda akan mempelajari cara menyesuaikan model teks di balik layanan pembuatan teks Gemini API.

Menyiapkan autentikasi

Gemini API memungkinkan Anda menyesuaikan model menggunakan data Anda sendiri. Karena ini adalah data Anda dan model yang Anda sesuaikan, hal ini memerlukan kontrol akses yang lebih ketat daripada yang dapat diberikan oleh Kunci API.

Sebelum dapat menjalankan tutorial ini, Anda harus menyiapkan OAuth untuk project Anda.

Di Colab, cara termudah untuk mendapatkan penyiapan adalah menyalin konten file client_secret.json Anda ke "Secret manager" Colab (di bawah ikon kunci di panel kiri) dengan nama rahasia CLIENT_SECRET.

Perintah gcloud ini mengubah file client_secret.json menjadi kredensial yang dapat digunakan untuk melakukan autentikasi dengan layanan.

try:
  from google.colab import userdata
  import pathlib
  pathlib.Path('client_secret.json').write_text(userdata.get('CLIENT_SECRET'))

  # Use `--no-browser` in colab
  !gcloud auth application-default login --no-browser --client-id-file client_secret.json --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'
except ImportError:
  !gcloud auth application-default login --client-id-file client_secret.json --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'
You are authorizing client libraries without access to a web browser. Please run the following command on a machine with a web browser and copy its output back here. Make sure the installed gcloud version is 372.0.0 or newer.

gcloud auth application-default login --remote-bootstrap="https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=87071151422-n1a3cb6c7fvkfg4gmhdtmn5ulol2l4be.apps.googleusercontent.com&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fgenerative-language.tuning&state=QIyNibWSaTIsozjmvZEkVBo6EcoW0G&access_type=offline&code_challenge=76c1ZiGvKN8cvlYfj3BmbCwE4e7tvrlwaX3REUX25gY&code_challenge_method=S256&token_usage=remote"

Enter the output of the above command: https://localhost:8085/?state=QIyNibWSaTIsozjmvZEkVBo6EcoW0G&code=4/0AeaYSHBKrY911S466QjKQIFODoOPXlO1mWyTYYdrbELIDV6Hw2DKRAyro62BugroSvIWsA&scope=https://www.googleapis.com/auth/cloud-platform%20https://www.googleapis.com/auth/generative-language.tuning

Credentials saved to file: [/content/.config/application_default_credentials.json]

These credentials will be used by any library that requests Application Default Credentials (ADC).

Menetapkan variabel

CURL

Menetapkan variabel untuk nilai berulang yang akan digunakan untuk panggilan REST API lainnya. Kode ini menggunakan library os Python untuk menetapkan variabel lingkungan yang dapat diakses di semua sel kode.

Setelan ini dikhususkan untuk lingkungan notebook Colab. Kode dalam sel kode berikutnya setara dengan menjalankan perintah berikut di terminal bash.

export access_token=$(gcloud auth application-default print-access-token)
export project_id=my-project-id
export base_url=https://generativelanguage.googleapis.com
import os

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

os.environ['access_token'] = access_token
os.environ['project_id'] = "[Enter your project-id here]"
os.environ['base_url'] = "https://generativelanguage.googleapis.com"

Python

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

project = '[Enter your project-id here]'
base_url = "https://generativelanguage.googleapis.com"

Impor library requests.

import requests
import json

Mencantumkan model yang telah disesuaikan

Verifikasi penyiapan autentikasi Anda dengan mencantumkan model yang telah disesuaikan yang tersedia.

CURL


curl -X GET ${base_url}/v1beta/tunedModels \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}"

Python

headers={
  'Authorization': 'Bearer ' + access_token,
  'Content-Type': 'application/json',
  'x-goog-user-project': project
}

result = requests.get(
  url=f'{base_url}/v1beta/tunedModels',
  headers = headers,
)
result.json()

Buat model yang di-tuning

Untuk membuat model yang disesuaikan, Anda harus meneruskan set data ke model tersebut di kolom training_data.

Untuk contoh ini, Anda akan menyesuaikan model untuk menghasilkan angka berikutnya dalam urutan. Misalnya, jika inputnya adalah 1, model akan menghasilkan 2. Jika inputnya adalah one hundred, output-nya harus one hundred one.

CURL


curl -X POST $base_url/v1beta/tunedModels \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '
      {
        "display_name": "number generator model",
        "base_model": "models/gemini-1.0-pro-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 2,
            "learning_rate": 0.001,
            "epoch_count":5,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    "text_input": "1",
                    "output": "2",
                },{
                    "text_input": "3",
                    "output": "4",
                },{
                    "text_input": "-3",
                    "output": "-2",
                },{
                    "text_input": "twenty two",
                    "output": "twenty three",
                },{
                    "text_input": "two hundred",
                    "output": "two hundred one",
                },{
                    "text_input": "ninety nine",
                    "output": "one hundred",
                },{
                    "text_input": "8",
                    "output": "9",
                },{
                    "text_input": "-98",
                    "output": "-97",
                },{
                    "text_input": "1,000",
                    "output": "1,001",
                },{
                    "text_input": "10,100,000",
                    "output": "10,100,001",
                },{
                    "text_input": "thirteen",
                    "output": "fourteen",
                },{
                    "text_input": "eighty",
                    "output": "eighty one",
                },{
                    "text_input": "one",
                    "output": "two",
                },{
                    "text_input": "three",
                    "output": "four",
                },{
                    "text_input": "seven",
                    "output": "eight",
                }
              ]
            }
          }
        }
      }' | tee tunemodel.json
{
"name": "tunedModels/number-generator-model-dzlmi0gswwqb/operations/bvl8dymw0fhw",
"metadata": {
  "@type": "type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata",
  "totalSteps": 38,
  "tunedModel": "tunedModels/number-generator-model-dzlmi0gswwqb"
}
}
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                              Dload  Upload   Total   Spent    Left  Speed
100  2280    0   296  100  1984    611   4098 --:--:-- --:--:-- --:--:--  4720

Python

operation = requests.post(
    url = f'{base_url}/v1beta/tunedModels',
    headers=headers,
    json= {
        "display_name": "number generator",
        "base_model": "models/gemini-1.0-pro-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 4,
            "learning_rate": 0.001,
            "epoch_count":5,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    'text_input': '1',
                    'output': '2',
                },{
                    'text_input': '3',
                    'output': '4',
                },{
                    'text_input': '-3',
                    'output': '-2',
                },{
                    'text_input': 'twenty two',
                    'output': 'twenty three',
                },{
                    'text_input': 'two hundred',
                    'output': 'two hundred one',
                },{
                    'text_input': 'ninety nine',
                    'output': 'one hundred',
                },{
                    'text_input': '8',
                    'output': '9',
                },{
                    'text_input': '-98',
                    'output': '-97',
                },{
                    'text_input': '1,000',
                    'output': '1,001',
                },{
                    'text_input': '10,100,000',
                    'output': '10,100,001',
                },{
                    'text_input': 'thirteen',
                    'output': 'fourteen',
                },{
                    'text_input': 'eighty',
                    'output': 'eighty one',
                },{
                    'text_input': 'one',
                    'output': 'two',
                },{
                    'text_input': 'three',
                    'output': 'four',
                },{
                    'text_input': 'seven',
                    'output': 'eight',
                }
              ]
            }
          }
        }
      }
)
operation
<Response [200]>
operation.json()
{'name': 'tunedModels/number-generator-wl1qr34x2py/operations/41vni3zk0a47',
'metadata': {'@type': 'type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata',
  'totalSteps': 19,
  'tunedModel': 'tunedModels/number-generator-wl1qr34x2py'} }

Tetapkan variabel dengan nama model yang telah Anda sesuaikan untuk digunakan pada panggilan lainnya.

name=operation.json()["metadata"]["tunedModel"]
name
'tunedModels/number-generator-wl1qr34x2py'

Mendapatkan status model yang di-tuning

Status model ditetapkan ke CREATING selama pelatihan dan akan berubah menjadi ACTIVE setelah pelatihan selesai.

CURL

Di bawah ini adalah sedikit kode Python untuk mengurai nama model yang dihasilkan dari JSON respons. Jika Anda menjalankannya di terminal, Anda dapat mencoba menggunakan parser JSON bash untuk mengurai respons.

import json

first_page = json.load(open('tunemodel.json'))
os.environ['modelname'] = first_page['metadata']['tunedModel']

print(os.environ['modelname'])
tunedModels/number-generator-model-dzlmi0gswwqb

Lakukan permintaan GET lainnya dengan nama model untuk mendapatkan metadata model yang menyertakan kolom status.


curl -X GET ${base_url}/v1beta/${modelname} \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" | grep state
"state": "ACTIVE",
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                              Dload  Upload   Total   Spent    Left  Speed
100  5921    0  5921    0     0  13164      0 --:--:-- --:--:-- --:--:-- 13157

Python

tuned_model = requests.get(
    url = f'{base_url}/v1beta/{name}',
    headers=headers,
)
tuned_model.json()

Kode di bawah memeriksa kolom status setiap 5 detik hingga tidak lagi dalam status CREATING.

import time
import pprint

op_json = operation.json()
response = op_json.get('response')
error = op_json.get('error')

while response is None and error is None:
    time.sleep(5)

    operation = requests.get(
        url = f'{base_url}/v1/{op_json["name"]}',
        headers=headers,
    )

    op_json = operation.json()
    response = op_json.get('response')
    error = op_json.get('error')

    percent = op_json['metadata'].get('completedPercent')
    if percent is not None:
      print(f"{percent:.2f}% - {op_json['metadata']['snapshots'][-1]}")
      print()

if error is not None:
    raise Exception(error)
100.00% - {'step': 19, 'epoch': 5, 'meanLoss': 1.402067, 'computeTime': '2024-03-14T15:11:23.766989274Z'}

Menjalankan inferensi

Setelah tugas penyesuaian selesai, Anda dapat menggunakannya untuk membuat teks dengan layanan teks.

CURL

Coba masukkan angka Romawi, misalnya, 63 (LXIII):


curl -X POST $base_url/v1beta/$modelname:generateContent \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '{
        "contents": [{
        "parts": [{
          "text": "LXIII"
          }]
        }]
        }' 2> /dev/null
{
"candidates": [
  {
    "content": {
      "parts": [
        {
          "text": "LXIV"
        }
      ],
      "role": "model"
    },
    "finishReason": "STOP",
    "index": 0,
    "safetyRatings": [
      {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_HARASSMENT",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "probability": "NEGLIGIBLE"
      }
    ]
  }
],
"promptFeedback": {
  "safetyRatings": [
    {
      "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
      "probability": "NEGLIGIBLE"
    },
    {
      "category": "HARM_CATEGORY_HATE_SPEECH",
      "probability": "NEGLIGIBLE"
    },
    {
      "category": "HARM_CATEGORY_HARASSMENT",
      "probability": "NEGLIGIBLE"
    },
    {
      "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
      "probability": "NEGLIGIBLE"
    }
  ]
}
}

Output dari model Anda mungkin benar atau mungkin tidak. Jika model yang disesuaikan tidak berperforma sesuai standar yang diperlukan, Anda dapat mencoba menambahkan lebih banyak contoh berkualitas tinggi, menyesuaikan hyperparameter, atau menambahkan preamble ke contoh Anda. Anda bahkan dapat membuat model lain yang telah disesuaikan berdasarkan model pertama yang Anda buat.

Lihat panduan penyesuaian untuk mendapatkan panduan lebih lanjut tentang cara meningkatkan performa.

Python

Coba masukkan angka Jepang, misalnya, 6 (六):

import time

m = requests.post(
    url = f'{base_url}/v1beta/{name}:generateContent',
    headers=headers,
    json= {
        "contents": [{
            "parts": [{
                "text": "六"
            }]
          }]
    })
import pprint
pprint.pprint(m.json())
{'candidates': [{'content': {'parts': [{'text': '七'}], 'role': 'model'},
                'finishReason': 'STOP',
                'index': 0,
                'safetyRatings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                                    'probability': 'NEGLIGIBLE'},
                                  {'category': 'HARM_CATEGORY_HATE_SPEECH',
                                    'probability': 'NEGLIGIBLE'},
                                  {'category': 'HARM_CATEGORY_HARASSMENT',
                                    'probability': 'LOW'},
                                  {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                                    'probability': 'NEGLIGIBLE'}]}],
'promptFeedback': {'safetyRatings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                                      'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_HATE_SPEECH',
                                      'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_HARASSMENT',
                                      'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                                      'probability': 'NEGLIGIBLE'}]} }

Output dari model Anda mungkin benar atau mungkin tidak. Jika model yang disesuaikan tidak berperforma sesuai standar yang diperlukan, Anda dapat mencoba menambahkan lebih banyak contoh berkualitas tinggi, menyesuaikan hyperparameter, atau menambahkan preamble ke contoh Anda.

Kesimpulan

Meskipun data pelatihan tidak berisi referensi ke angka Romawi atau Jepang, model ini dapat melakukan generalisasi dengan baik setelah fine-tuning. Dengan cara ini, Anda dapat meningkatkan kualitas model agar sesuai dengan kasus penggunaan Anda.

Langkah berikutnya

Untuk mempelajari cara menggunakan layanan penyesuaian dengan bantuan Python SDK untuk Gemini API, buka panduan memulai penyesuaian dengan Python. Untuk mempelajari cara menggunakan layanan lain di Gemini API, buka panduan panduan memulai Python.