Instructivo de ajuste

Ver en ai.google.dev Ejecutar en Google Colab Ver el código fuente en GitHub

En este notebook, aprenderás a usar el servicio de ajuste de la API de Gemini usando comandos CURL o la API de solicitudes de Python para llamar a la API de Gemini. Aquí, aprenderás a ajustar el modelo de texto detrás del servicio de generación de texto de la API de Gemini.

Configura la autenticación

La API de Gemini te permite ajustar modelos con tus propios datos. Dado que se trata de tus datos y tus modelos ajustados, esto requiere controles de acceso más estrictos que los que pueden proporcionar las claves de API.

Antes de ejecutar este instructivo, deberás configurar OAuth en tu proyecto.

En Colab, la configuración más fácil es copiar el contenido del archivo client_secret.json en el "Administrador de secretos" de Colab (debajo del ícono de clave en el panel izquierdo) con el nombre del secreto CLIENT_SECRET.

Este comando de gcloud convierte el archivo client_secret.json en credenciales que se pueden usar para autenticarse con el servicio.

try:
  from google.colab import userdata
  import pathlib
  pathlib.Path('client_secret.json').write_text(userdata.get('CLIENT_SECRET'))

  # Use `--no-browser` in colab
  !gcloud auth application-default login --no-browser --client-id-file client_secret.json --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'
except ImportError:
  !gcloud auth application-default login --client-id-file client_secret.json --scopes='https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/generative-language.tuning'
You are authorizing client libraries without access to a web browser. Please run the following command on a machine with a web browser and copy its output back here. Make sure the installed gcloud version is 372.0.0 or newer.

gcloud auth application-default login --remote-bootstrap="https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=87071151422-n1a3cb6c7fvkfg4gmhdtmn5ulol2l4be.apps.googleusercontent.com&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fgenerative-language.tuning&state=QIyNibWSaTIsozjmvZEkVBo6EcoW0G&access_type=offline&code_challenge=76c1ZiGvKN8cvlYfj3BmbCwE4e7tvrlwaX3REUX25gY&code_challenge_method=S256&token_usage=remote"

Enter the output of the above command: https://localhost:8085/?state=QIyNibWSaTIsozjmvZEkVBo6EcoW0G&code=4/0AeaYSHBKrY911S466QjKQIFODoOPXlO1mWyTYYdrbELIDV6Hw2DKRAyro62BugroSvIWsA&scope=https://www.googleapis.com/auth/cloud-platform%20https://www.googleapis.com/auth/generative-language.tuning

Credentials saved to file: [/content/.config/application_default_credentials.json]

These credentials will be used by any library that requests Application Default Credentials (ADC).

Configura las variables

CURL

Establece variables de valores recurrentes para usar en el resto de las llamadas a la API de REST. El código usa la biblioteca os de Python para establecer variables de entorno a las que se puede acceder en todas las celdas de código.

Esto es específico del entorno del notebook de Colab. El código de la siguiente celda de código es equivalente a ejecutar los siguientes comandos en una terminal de Bash.

export access_token=$(gcloud auth application-default print-access-token)
export project_id=my-project-id
export base_url=https://generativelanguage.googleapis.com
import os

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

os.environ['access_token'] = access_token
os.environ['project_id'] = "[Enter your project-id here]"
os.environ['base_url'] = "https://generativelanguage.googleapis.com"

Python

access_token = !gcloud auth application-default print-access-token
access_token = '\n'.join(access_token)

project = '[Enter your project-id here]'
base_url = "https://generativelanguage.googleapis.com"

Importa la biblioteca requests.

import requests
import json

Enumerar modelos ajustados

Para verificar tu configuración de autenticación, enumera los modelos ajustados disponibles.

CURL


curl -X GET ${base_url}/v1beta/tunedModels \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}"

Python

headers={
  'Authorization': 'Bearer ' + access_token,
  'Content-Type': 'application/json',
  'x-goog-user-project': project
}

result = requests.get(
  url=f'{base_url}/v1beta/tunedModels',
  headers = headers,
)
result.json()

Crear modelo ajustado

Para crear un modelo ajustado, debes pasar el conjunto de datos al modelo en el campo training_data.

En este ejemplo, ajustarás un modelo para generar el siguiente número en la secuencia. Por ejemplo, si la entrada es 1, el modelo debería generar 2. Si la entrada es one hundred, el resultado debe ser one hundred one.

CURL


curl -X POST $base_url/v1beta/tunedModels \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '
      {
        "display_name": "number generator model",
        "base_model": "models/gemini-1.0-pro-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 2,
            "learning_rate": 0.001,
            "epoch_count":5,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    "text_input": "1",
                    "output": "2",
                },{
                    "text_input": "3",
                    "output": "4",
                },{
                    "text_input": "-3",
                    "output": "-2",
                },{
                    "text_input": "twenty two",
                    "output": "twenty three",
                },{
                    "text_input": "two hundred",
                    "output": "two hundred one",
                },{
                    "text_input": "ninety nine",
                    "output": "one hundred",
                },{
                    "text_input": "8",
                    "output": "9",
                },{
                    "text_input": "-98",
                    "output": "-97",
                },{
                    "text_input": "1,000",
                    "output": "1,001",
                },{
                    "text_input": "10,100,000",
                    "output": "10,100,001",
                },{
                    "text_input": "thirteen",
                    "output": "fourteen",
                },{
                    "text_input": "eighty",
                    "output": "eighty one",
                },{
                    "text_input": "one",
                    "output": "two",
                },{
                    "text_input": "three",
                    "output": "four",
                },{
                    "text_input": "seven",
                    "output": "eight",
                }
              ]
            }
          }
        }
      }' | tee tunemodel.json
{
"name": "tunedModels/number-generator-model-dzlmi0gswwqb/operations/bvl8dymw0fhw",
"metadata": {
  "@type": "type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata",
  "totalSteps": 38,
  "tunedModel": "tunedModels/number-generator-model-dzlmi0gswwqb"
}
}
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                              Dload  Upload   Total   Spent    Left  Speed
100  2280    0   296  100  1984    611   4098 --:--:-- --:--:-- --:--:--  4720

Python

operation = requests.post(
    url = f'{base_url}/v1beta/tunedModels',
    headers=headers,
    json= {
        "display_name": "number generator",
        "base_model": "models/gemini-1.0-pro-001",
        "tuning_task": {
          "hyperparameters": {
            "batch_size": 4,
            "learning_rate": 0.001,
            "epoch_count":5,
          },
          "training_data": {
            "examples": {
              "examples": [
                {
                    'text_input': '1',
                    'output': '2',
                },{
                    'text_input': '3',
                    'output': '4',
                },{
                    'text_input': '-3',
                    'output': '-2',
                },{
                    'text_input': 'twenty two',
                    'output': 'twenty three',
                },{
                    'text_input': 'two hundred',
                    'output': 'two hundred one',
                },{
                    'text_input': 'ninety nine',
                    'output': 'one hundred',
                },{
                    'text_input': '8',
                    'output': '9',
                },{
                    'text_input': '-98',
                    'output': '-97',
                },{
                    'text_input': '1,000',
                    'output': '1,001',
                },{
                    'text_input': '10,100,000',
                    'output': '10,100,001',
                },{
                    'text_input': 'thirteen',
                    'output': 'fourteen',
                },{
                    'text_input': 'eighty',
                    'output': 'eighty one',
                },{
                    'text_input': 'one',
                    'output': 'two',
                },{
                    'text_input': 'three',
                    'output': 'four',
                },{
                    'text_input': 'seven',
                    'output': 'eight',
                }
              ]
            }
          }
        }
      }
)
operation
<Response [200]>
operation.json()
{'name': 'tunedModels/number-generator-wl1qr34x2py/operations/41vni3zk0a47',
'metadata': {'@type': 'type.googleapis.com/google.ai.generativelanguage.v1beta.CreateTunedModelMetadata',
  'totalSteps': 19,
  'tunedModel': 'tunedModels/number-generator-wl1qr34x2py'} }

Establece una variable con el nombre de tu modelo ajustado para usarla para el resto de las llamadas.

name=operation.json()["metadata"]["tunedModel"]
name
'tunedModels/number-generator-wl1qr34x2py'

Obtén el estado del modelo ajustado

El estado del modelo se establece en CREATING durante el entrenamiento y cambiará a ACTIVE una vez que se complete.

CURL

A continuación, se muestra un fragmento de código de Python para analizar el nombre del modelo generado a partir del JSON de respuesta. Si ejecutas esto en una terminal, puedes intentar usar un analizador JSON de Bash para analizar la respuesta.

import json

first_page = json.load(open('tunemodel.json'))
os.environ['modelname'] = first_page['metadata']['tunedModel']

print(os.environ['modelname'])
tunedModels/number-generator-model-dzlmi0gswwqb

Realiza otra solicitud GET con el nombre del modelo para obtener los metadatos del modelo que incluyen el campo de estado.


curl -X GET ${base_url}/v1beta/${modelname} \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" | grep state
"state": "ACTIVE",
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                              Dload  Upload   Total   Spent    Left  Speed
100  5921    0  5921    0     0  13164      0 --:--:-- --:--:-- --:--:-- 13157

Python

tuned_model = requests.get(
    url = f'{base_url}/v1beta/{name}',
    headers=headers,
)
tuned_model.json()

El siguiente código verifica el campo de estado cada 5 segundos hasta que ya no esté en el estado CREATING.

import time
import pprint

op_json = operation.json()
response = op_json.get('response')
error = op_json.get('error')

while response is None and error is None:
    time.sleep(5)

    operation = requests.get(
        url = f'{base_url}/v1/{op_json["name"]}',
        headers=headers,
    )

    op_json = operation.json()
    response = op_json.get('response')
    error = op_json.get('error')

    percent = op_json['metadata'].get('completedPercent')
    if percent is not None:
      print(f"{percent:.2f}% - {op_json['metadata']['snapshots'][-1]}")
      print()

if error is not None:
    raise Exception(error)
100.00% - {'step': 19, 'epoch': 5, 'meanLoss': 1.402067, 'computeTime': '2024-03-14T15:11:23.766989274Z'}

Ejecuta la inferencia

Una vez que finalice tu trabajo de ajuste, puedes usarlo para generar texto con el servicio de texto.

CURL

Intenta ingresar un número romano, por ejemplo, 63 (LXIII):


curl -X POST $base_url/v1beta/$modelname:generateContent \
    -H 'Content-Type: application/json' \
    -H "Authorization: Bearer ${access_token}" \
    -H "x-goog-user-project: ${project_id}" \
    -d '{
        "contents": [{
        "parts": [{
          "text": "LXIII"
          }]
        }]
        }' 2> /dev/null
{
"candidates": [
  {
    "content": {
      "parts": [
        {
          "text": "LXIV"
        }
      ],
      "role": "model"
    },
    "finishReason": "STOP",
    "index": 0,
    "safetyRatings": [
      {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_HARASSMENT",
        "probability": "NEGLIGIBLE"
      },
      {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "probability": "NEGLIGIBLE"
      }
    ]
  }
],
"promptFeedback": {
  "safetyRatings": [
    {
      "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
      "probability": "NEGLIGIBLE"
    },
    {
      "category": "HARM_CATEGORY_HATE_SPEECH",
      "probability": "NEGLIGIBLE"
    },
    {
      "category": "HARM_CATEGORY_HARASSMENT",
      "probability": "NEGLIGIBLE"
    },
    {
      "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
      "probability": "NEGLIGIBLE"
    }
  ]
}
}

El resultado del modelo puede ser correcto o no. Si el modelo ajustado no cumple con los estándares requeridos, puedes intentar agregar más ejemplos de alta calidad, ajustar los hiperparámetros o agregar un preámbulo a tus ejemplos. Incluso puedes crear otro modelo ajustado basado en el primero que creaste.

Consulta la guía de ajuste para obtener más indicaciones sobre cómo mejorar el rendimiento.

Python

Intenta ingresar un número japonés, por ejemplo, el 6 (GFRG):

import time

m = requests.post(
    url = f'{base_url}/v1beta/{name}:generateContent',
    headers=headers,
    json= {
        "contents": [{
            "parts": [{
                "text": "六"
            }]
          }]
    })
import pprint
pprint.pprint(m.json())
{'candidates': [{'content': {'parts': [{'text': '七'}], 'role': 'model'},
                'finishReason': 'STOP',
                'index': 0,
                'safetyRatings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                                    'probability': 'NEGLIGIBLE'},
                                  {'category': 'HARM_CATEGORY_HATE_SPEECH',
                                    'probability': 'NEGLIGIBLE'},
                                  {'category': 'HARM_CATEGORY_HARASSMENT',
                                    'probability': 'LOW'},
                                  {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                                    'probability': 'NEGLIGIBLE'}]}],
'promptFeedback': {'safetyRatings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
                                      'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_HATE_SPEECH',
                                      'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_HARASSMENT',
                                      'probability': 'NEGLIGIBLE'},
                                      {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
                                      'probability': 'NEGLIGIBLE'}]} }

El resultado del modelo puede ser correcto o no. Si el modelo ajustado no cumple con los estándares requeridos, puedes intentar agregar más ejemplos de alta calidad, ajustar los hiperparámetros o agregar un preámbulo a tus ejemplos.

Conclusión

Aunque los datos de entrenamiento no contenían ninguna referencia a los numerales romanos ni japoneses, el modelo pudo generalizarse mucho después del ajuste. De esta manera, puedes ajustar los modelos para que se adapten a tus casos de uso.

Próximos pasos

Para aprender a usar el servicio de ajuste con la ayuda del SDK de Python para la API de Gemini, visita la guía de inicio rápido de ajuste con Python. Para aprender a usar otros servicios en la API de Gemini, visita la guía de inicio rápido de Python.