Use external tools with generate_text


In some use cases, you may want to stop a model's generation to insert specific results. For example, language models can struggle with complicated arithmetic, such as word problems. This tutorial shows an example of using an external tool with the genai.generate_text method to get the correct answer to a word problem.

This particular example uses the numexpr tool to perform the arithmetic, but you can use the same procedure to integrate other tools specific to your use case. The steps are:

  1. Choose a start tag and an end tag to demarcate the text that is sent to the tool.
  2. Write a prompt that tells the model how to use the tags in its output.
  3. Include the end tag in the stop_sequences passed to generate_text.
  4. From the model's output, take the text between the start and end tags as the input to the tool.
  5. Run the tool and append its output to the prompt.
  6. Call generate_text again so the model continues from the tool's output.
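The steps above can be sketched offline, without calling the API. In this minimal sketch, fake_generate is a hypothetical scripted stand-in for the model, tiny_calc is a hypothetical stand-in for the tool, and run_with_tool is an illustrative name. None of them come from the google.generativeai library.

```python
import ast
import operator

START_TAG, END_TAG = "<calc>", "</calc>"  # step 1: tags that mark tool input

_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv}


def tiny_calc(expression):
    """Hypothetical stand-in tool: evaluates +, -, *, / on plain numbers."""
    def ev(node):
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](ev(node.left), ev(node.right))
        raise ValueError(f"unsupported expression: {expression!r}")
    return ev(ast.parse(expression.strip(), mode="eval").body)


# Hypothetical scripted model: each item is what a model might emit before
# generation stops at END_TAG. A real model would follow the prompt (step 2).
_script = iter(["There are <calc>2 * 8", " cats.\nThey need <calc>16 * 3", " mittens."])


def fake_generate(prompt, stop_sequences):
    return next(_script)


def run_with_tool(prompt, generate, max_steps=10):
    transcript = ""
    for _ in range(max_steps):
        # Steps 3 and 6: generate, stopping at the end tag, then continue.
        chunk = generate(prompt + transcript, stop_sequences=[END_TAG])
        if START_TAG not in chunk:
            transcript += chunk
            break
        text, expression = chunk.rsplit(START_TAG, 1)  # step 4
        value = tiny_calc(expression)                  # step 5
        transcript += f"{text}{START_TAG}{expression}{END_TAG} = {value}"
    return transcript


solution = run_with_tool("Solve the problem.\n", fake_generate)
print(solution)
```

Because the tool's value is written into the transcript before the next call, the model never gets the chance to guess the arithmetic itself.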

Setup

pip install -q google-generativeai
import google.generativeai as genai
genai.configure(api_key='YOUR API KEY')

from google.api_core import retry

@retry.Retry()
def generate_text(*args, **kwargs):
  return genai.generate_text(*args, **kwargs)
models = [m for m in genai.list_models() if 'generateText' in m.supported_generation_methods]
model = models[0].name
print(model)
models/text-bison-001

Try to solve the problem directly

Here is the word problem you will solve:

question = """
I have 77 houses, each with 31 cats.
Each cat owns 14 mittens, and 6 hats.
Each mitten was knit from 141m of yarn, each hat from 55m.
How much yarn was needed to make all the items?
"""
prompt_template = """
You are an expert at solving word problems. Here's one:

{question}

Work through it step by step, and show your work.
One step per line.

Your solution:
"""

Try it as is:

completion = generate_text(
    model=model,
    prompt=prompt_template.format(question=question),
    # The maximum length of the response
    max_output_tokens=800,
)

print(completion.result)
In the houses there are 77 * 31 = 2387 cats.
So they need 2387 * 14 = 33418 mittens.
And they need 2387 * 6 = 14322 hats.
In total they need 33418 * 141 + 14322 * 55 = 5554525m of yarn.
The answer: 5554525.

This prompt usually produces an incorrect result. The model generally gets the steps right, but the arithmetic is wrong.

The answer should be:

answer = 77*31*14*141 + 77*31*6*55
answer
5499648
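To see where that total comes from, the same computation can be broken into the intermediate quantities a step-by-step solution is expected to produce. This is a minimal sketch, and the variable names are illustrative rather than taken from the original notebook.

```python
# Quantities stated in the word problem (names are illustrative).
houses, cats_per_house = 77, 31
mittens_per_cat, hats_per_cat = 14, 6
yarn_per_mitten_m, yarn_per_hat_m = 141, 55

# Intermediate steps.
cats = houses * cats_per_house               # 2387
mittens = cats * mittens_per_cat             # 33418
hats = cats * hats_per_cat                   # 14322
mitten_yarn_m = mittens * yarn_per_mitten_m  # 4711938
hat_yarn_m = hats * yarn_per_hat_m           # 787710

total_yarn_m = mitten_yarn_m + hat_yarn_m
print(total_yarn_m)  # 5499648
```

The model's direct attempt above reached the right counts of cats, mittens, and hats, so its error is confined to the final multiplication and sum.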

Ask the model to use a calculator

In the next attempt, give the model instructions for accessing a calculator. You can do that by specifying start and end tags the model can use to indicate where a calculation is needed. Add something like the following to the prompt:

calc_prompt_template = """
You are an expert at solving word problems. Here's a question:

{question}

-------------------

When solving this problem, use the calculator for any arithmetic.

To use the calculator, put an expression between <calc></calc> tags.
The answer will be printed after the </calc> tag.

For example: 2 houses  * 8 cats/house = <calc>2 * 8</calc> = 16 cats

-------------------

Work through it step by step, and show your work.
One step per line.

Your solution:
"""

calc_prompt = calc_prompt_template.format(question=question)

To give the model access to the output of this "calculator", you need to pause generation and then insert the result. Use the stop_sequences argument to stop at the </calc> tag:

completion = generate_text(
    model=model,
    prompt=calc_prompt,
    stop_sequences=["</calc>"],
    # The maximum length of the response
    max_output_tokens=800,
    candidate_count=1,
)

result = completion.result
print(result)
In each house, there are <calc>31 * 14
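The effect of a stop sequence can be imitated locally with plain string handling. The sketch below is an illustration only: truncate_at_stop is a hypothetical helper, not part of the library, and the text after the tag in the sample string is invented to stand for what the model might have written had it not been stopped.

```python
def truncate_at_stop(text, stop_sequences):
    """Cut text at the earliest stop sequence, excluding the sequence itself."""
    cut = len(text)
    for stop in stop_sequences:
        index = text.find(stop)
        if index != -1:
            cut = min(cut, index)
    return text[:cut]


# Hypothetical unstopped output; the part after </calc> is made up.
full_output = "In each house, there are <calc>31 * 14</calc> = 434 mittens."
stopped = truncate_at_stop(full_output, ["</calc>"])
print(stopped)  # In each house, there are <calc>31 * 14
```

Stopping before the model writes its own number is the point: the value after the tag should come from the tool, not from the model.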

The stop sequence is not included in the result. Split off the expression, run it through the calculator, and append the calculator's output back onto the result:

# Use re to clear units from the calculator expressions
import re
# Use numexpr since `eval` is unsafe.
import numexpr


def calculator(result):
  result, expression = result.rsplit('<calc>', 1)

  # Strip any units like "cats / house"
  clean_expression = re.sub("[a-zA-Z]([ /a-zA-Z]*[a-zA-Z])?",'', expression)

  # `eval` is unsafe; use numexpr instead
  result = f"{result}<calc>{expression}</calc> = {str(numexpr.evaluate(clean_expression))}"
  return result
print(calculator(result))
In each house, there are <calc>31 * 14</calc> = 434
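It can help to look at what the unit-stripping pattern does on its own before relying on it. The sketch below applies the same regular expression to a few sample expressions. strip_units is an illustrative name and the inputs are made up for the demonstration.

```python
import re

# Same pattern as in calculator(): a letter, optionally followed by a run of
# letters, spaces, and slashes that ends in a letter.
UNIT_PATTERN = "[a-zA-Z]([ /a-zA-Z]*[a-zA-Z])?"


def strip_units(expression):
    return re.sub(UNIT_PATTERN, "", expression)


for sample in ["2 houses * 8 cats/house", "33418 * 141m", "1e3 * 2"]:
    print(repr(sample), "->", repr(strip_units(sample)))
```

The first two samples reduce to plain arithmetic. The third shows a limitation: the pattern removes every letter, so scientific notation such as 1e3 becomes 13, and function names would be dropped as well. That is acceptable for the integer arithmetic in this word problem but worth keeping in mind when adapting the approach.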

Now append that to the prompt and run the model again so it can continue where it left off:

continue_prompt=calc_prompt +"\n"+ "-"*80 + "\n" + calculator(result)

completion = generate_text(
    model=model,
    prompt=continue_prompt,
    stop_sequences=["</calc>"],
    # The maximum length of the response
    max_output_tokens=800,
    candidate_count=1,
)

print(completion.result)
mittens.
In each house, there are <calc>31 * 6

This time, the model continued from the text of the last calculation and moved on to the next one. Now run it in a loop to solve the word problem completely:

def solve(question=question):
  results = []

  for n in range(10):
    prompt = calc_prompt_template.format(question=question)

    prompt += " ".join(results)

    completion = generate_text(
        model=model,
        prompt=prompt,
        stop_sequences=["</calc>"],
        # The maximum length of the response
        max_output_tokens=800,
    )

    result = completion.result
    if '<calc>' in result:
      result = calculator(result)

    results.append(result)
    print('-'*40)
    print(result)
    if str(answer) in result:
      break
    if "<calc>" not in result:
      break

  is_good = any(str(answer) in r for r in results)

  print("*"*100)
  if is_good:
    print("Success!")
  else:
    print("Failure!")
  print("*"*100)

  return is_good
solve(question);
----------------------------------------
The total number of cats is <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
The total number of mittens is <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens.
The total amount of yarn needed for the mittens is <calc>33418 * 141</calc> = 4711938
----------------------------------------
m.
The total number of hats is <calc>2387 * 6</calc> = 14322
----------------------------------------
hats.
 The total amount of yarn needed for the hats is <calc>14322 * 55</calc> = 787710
----------------------------------------
m.
In total, <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!

You can run this a few times to estimate the solve rate:

good = []

for n in range(10):
  good.append(solve(question))
----------------------------------------
There are <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
They need <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens.
The mittens need <calc>33418 * 141</calc> = 4711938
----------------------------------------
m of yarn.
They need <calc>2387 * 6</calc> = 14322
----------------------------------------
hats.
The hats need <calc>14322 * 55</calc> = 787710
----------------------------------------
m of yarn.
 They need a total of <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!
----------------------------------------
There are <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
So for the mittens, we need <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens.
That means we need <calc>33418 * 141</calc> = 4711938
----------------------------------------
m of yarn for mittens.
For the hats, we need <calc>2387 * 6</calc> = 14322
----------------------------------------
hats.
That means we need <calc>14322 * 55</calc> = 787710
----------------------------------------
m of yarn for hats.
 In total we need <calc>787710 + 4711938</calc> = 5499648
********************************************************************************
Success!
----------------------------------------
In the 77 houses I have <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
They need <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens.
The mittens need <calc>33418 * 141</calc> = 4711938
----------------------------------------
m of yarn.
They need <calc>2387 * 6</calc> = 14322
----------------------------------------
hats.
The hats need <calc>14322 * 55</calc> = 787710
----------------------------------------
m of yarn.
 So, in total I need <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!
----------------------------------------
The number of cats is <calc>77 * 31</calc> = 2387
----------------------------------------
. Each cat needs <calc>14 * 141</calc> = 1974
----------------------------------------
m of yarn for mittens. So we need <calc>1974 * 2387</calc> = 4711938
----------------------------------------
m of yarn for mittens. Each cat needs <calc>6 * 55</calc> = 330
----------------------------------------
m of yarn for hats. So we need <calc>330 * 2387</calc> = 787710
----------------------------------------
m of yarn for hats. So in total we need <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!
----------------------------------------
There are <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
Each cat needs <calc>14 * 141</calc> = 1974
----------------------------------------
yarn for mittens.
All cats need <calc>2387 * 1974</calc> = 4711938
----------------------------------------
yarn for mittens.
Each cat needs <calc>6 * 55</calc> = 330
----------------------------------------
yarn for hats.
All cats need <calc>2387 * 330</calc> = 787710
----------------------------------------
yarn for hats.
 All in all, you need <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!
----------------------------------------
There are <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
Each cat needs <calc>14 + 6</calc> = 20
----------------------------------------
items.
So we need <calc>20 * 2387</calc> = 47740
----------------------------------------
items in total.
Each mitten needs <calc>141</calc> = 141
----------------------------------------
m of yarn.
So all the mittens need <calc>141 * 47740</calc> = 6731340
----------------------------------------
m of yarn.
 Each hat needs <calc>55</calc> = 55
----------------------------------------
m of yarn.
So all the hats need <calc>55 * 47740</calc> = 2625700
----------------------------------------
m of yarn.
 In total, we need <calc>6731340 + 2625700</calc> = 9357040
----------------------------------------
m of yarn. There are <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
Each cat needs <calc>14 + 6</calc> = 20
********************************************************************************
Failure!
----------------------------------------
There are <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
There are <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens.
There are <calc>2387 * 6</calc> = 14322
----------------------------------------
hats.
There was <calc>141 * 33418</calc> = 4711938
----------------------------------------
m of yarn needed for mittens.
There was <calc>55 * 14322</calc> = 787710
----------------------------------------
m of yarn needed for hats.
 So there was <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!
----------------------------------------
There are <calc>77 * 31</calc> = 2387
----------------------------------------
cats in total. 
They need <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens. 
That's <calc>33418 * 141</calc> = 4711938
----------------------------------------
meters of yarn for mittens. 
They need <calc>2387 * 6</calc> = 14322
----------------------------------------
hats. 
That's <calc>14322 * 55</calc> = 787710
----------------------------------------
meters of yarn for hats. 
So, they need <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!
----------------------------------------
There are 77 houses * 31 cats / house = <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
So we need <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens.
So we need <calc>33418 * 141</calc> = 4711938
----------------------------------------
m of yarn for mittens.
So we need <calc>2387 * 6</calc> = 14322
----------------------------------------
hats.
 So we need <calc>14322 * 55</calc> = 787710
----------------------------------------
m of yarn for hats.
In total, we need <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!
----------------------------------------
In total there are 77 houses * 31 cats / house = <calc>77 * 31</calc> = 2387
----------------------------------------
cats. In total 2387 cats * 14 mittens / cat = <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens. In total 33418 mittens * 141m / mitten = <calc>33418 * 141</calc> = 4711938
----------------------------------------
m of yarn for mittens. In total 2387 cats * 6 hats / cat = <calc>2387 * 6</calc> = 14322
----------------------------------------
hats. In total 14322 hats * 55m / hat = <calc>14322 * 55</calc> = 787710
----------------------------------------
m of yarn for hats. In total we need 4711938 m of yarn for mittens + 787710 m of yarn for hats = <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!
import numpy as np
np.mean(good)
0.9
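Taking the mean of a list of booleans gives the fraction of successful runs, since True counts as 1 and False as 0. With only ten runs the estimate is coarse. The sketch below recomputes the rate without NumPy and adds a rough standard error using the usual binomial approximation. That approximation is an addition for illustration, not part of the original notebook, and the list of outcomes simply mirrors the nine successes and one failure shown above.

```python
import math

# Outcomes mirroring the run above: nine successes and one failure.
good = [True] * 5 + [False] + [True] * 4

rate = sum(good) / len(good)  # True counts as 1, False as 0
# Rough binomial standard error; crude for a sample this small.
std_err = math.sqrt(rate * (1 - rate) / len(good))

print(f"solve rate: {rate:.2f} +/- {std_err:.2f}")
```

Running more trials narrows that uncertainty if a tighter estimate is needed.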