This tutorial will help you get started with the Gemini API tuning service using either the Python SDK or the REST API with curl. The examples show how to tune the text model behind the Gemini API text generation service.
Set up authentication
The Gemini API lets you tune models on your own data. Because the data and the tuned models are yours, this requires stricter access controls than API keys can provide.
Before you can run this tutorial, you'll need to set up OAuth for your project.
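Once OAuth is set up, a minimal sketch of wiring it into the Python SDK looks like the following. This is an assumption-laden example: it presumes you have created application default credentials (for example with gcloud auth application-default login) and that your installed google-generativeai version accepts a credentials argument in genai.configure; the scope string is also an assumption, so use the scope from your own OAuth configuration.
import google.auth
import google.generativeai as genai

# Load the application default credentials created during OAuth setup.
# The scope below is an assumption; substitute the scope from your setup.
creds, _ = google.auth.default(
    scopes=["https://www.googleapis.com/auth/generative-language.tuning"]
)

# Configure the SDK with OAuth credentials instead of an API key.
genai.configure(credentials=creds)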
List tuned models
You can check your existing tuned models with the genai.list_tuned_models
method.
import google.generativeai as genai

# Print the name of every tuned model in your project.
for model_info in genai.list_tuned_models():
    print(model_info.name)
Create a tuned model
To create a tuned model, you need to pass your dataset to the
genai.create_tuned_model
method. You can do this by defining the input and output values directly in
the call, or by importing data from a file into a dataframe and passing that
to the method.
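For the file-based path, here is a brief sketch. The file name increment.csv and its column names are assumptions for illustration; match them to your own data, and note that whether a dataframe can be passed directly as training_data may depend on your SDK version.
import pandas as pd

# Hypothetical CSV with "text_input" and "output" columns.
df = pd.read_csv("increment.csv")

# Keep just the columns the tuning service expects; this dataframe (or a list
# of dicts built from it) is then passed as training_data.
training_data = df[["text_input", "output"]]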
For this example, you will tune a model to generate the next number in the
sequence. For example, if the input is 1, the model should output 2. If the
input is one hundred, the output should be one hundred one.
import time

import google.generativeai as genai

base_model = "models/gemini-1.5-flash-001-tuning"
training_data = [
    {"text_input": "1", "output": "2"},
    # ... more examples ...
    {"text_input": "seven", "output": "eight"},
]

operation = genai.create_tuned_model(
    # You can use a tuned model here too. Set `source_model="tunedModels/..."`
    display_name="increment",
    source_model=base_model,
    epoch_count=20,
    batch_size=4,
    learning_rate=0.001,
    training_data=training_data,
)

for status in operation.wait_bar():
    time.sleep(10)

result = operation.result()
print(result)
# You can plot the loss curve with:
#   snapshots = pd.DataFrame(result.tuning_task.snapshots)
#   sns.lineplot(data=snapshots, x='epoch', y='mean_loss')

model = genai.GenerativeModel(model_name=result.name)
result = model.generate_content("III")
print(result.text)  # IV
The optimal values for epoch count, batch size, and learning rate are dependent on your dataset and other constraints of your use case. To learn more about these values, see Advanced tuning settings and Hyperparameters.
Since tuning a model can take significant time, this API doesn't wait for the
tuning to complete. Instead, it returns a google.api_core.operation.Operation
object that lets you check the status of the tuning job, wait for it to
complete, and inspect the result.
Your tuned model is immediately added to the list of tuned models, but its state is set to "creating" while the model is tuned.
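As a small sketch, you can also poll the operation without blocking:
# Check whether the tuning operation has finished without waiting on it.
if not operation.done():
    print("Tuning is still running...")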
Check tuning progress
You can check on the progress of the tuning operation using the wait_bar()
method:
for status in operation.wait_bar():
    time.sleep(10)
You can also use operation.metadata
to check the total number of tuning steps
and operation.update()
to refresh the status of the operation.
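For example (a brief sketch; the exact metadata fields depend on your SDK version):
operation.update()         # refresh the status of the operation from the service
print(operation.metadata)  # includes the total number of tuning steps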
You can cancel your tuning job any time using the cancel()
method.
operation.cancel()
Try the model
You can create a genai.GenerativeModel
with the name of the tuned model and call generate_content
to test its performance.
model = genai.GenerativeModel(model_name="tunedModels/my-increment-model")
result = model.generate_content("III")
print(result.text) # "IV"
Update the description
You can update the description of your tuned model any time using the
genai.update_tuned_model
method.
genai.update_tuned_model('tunedModels/my-increment-model', {"description": "This is my model."})
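To confirm the change, you can fetch the model again; a minimal sketch:
model_info = genai.get_tuned_model("tunedModels/my-increment-model")
print(model_info.description)  # "This is my model."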
Delete the model
You can clean up your tuned model list by deleting models you no longer need.
Use the genai.delete_tuned_model
method to delete a model. If you canceled any tuning jobs, you may want to
delete those models as well, since their performance can be unpredictable.
genai.delete_tuned_model("tunedModels/my-increment-model")
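For bulk cleanup, here is a hedged sketch that removes every tuned model not in an active state. The state check is an assumption about how the state is represented; inspect model_info.state in your own output before deleting anything.
import google.generativeai as genai

for model_info in genai.list_tuned_models():
    # Assumed check: keep only models that finished tuning successfully.
    if "ACTIVE" not in str(model_info.state):
        print(f"Deleting {model_info.name} (state: {model_info.state})")
        genai.delete_tuned_model(model_info.name)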