Addestramento di un classificatore di testo utilizzando gli incorporamenti

Visualizza su ai.google.dev Prova un blocco note di Colab Visualizza blocco note su GitHub

Panoramica

In questo blocco note imparerai a usare gli incorporamenti prodotti dall'API Gemini per addestrare un modello in grado di classificare diversi tipi di post di newsgroup in base all'argomento.

In questo tutorial, addestrerai un classificatore per prevedere a quale classe appartiene un post di un newsgroup.

Prerequisiti

Puoi eseguire questa guida rapida in Google Colab.

Per completare questa guida rapida nel tuo ambiente di sviluppo, assicurati che l'ambiente soddisfi i seguenti requisiti:

  • Python 3.9 o versioni successive
  • Un'installazione di jupyter per eseguire il blocco note.

Configurazione

Innanzitutto, scarica e installa la libreria Python dell'API Gemini.

pip install -U -q google.generativeai
import re
import tqdm
import keras
import numpy as np
import pandas as pd

import google.generativeai as genai

# Used to securely store your API key
from google.colab import userdata

import seaborn as sns
import matplotlib.pyplot as plt

from keras import layers
from matplotlib.ticker import MaxNLocator
from sklearn.datasets import fetch_20newsgroups
import sklearn.metrics as skmetrics

Procurati una chiave API

Prima di poter utilizzare l'API Gemini, devi ottenere una chiave API. Se non ne hai già una, crea una chiave con un solo clic in Google AI Studio.

Ottenere una chiave API

In Colab, aggiungi la chiave a Secret Manager nella sezione " {/8}" nel riquadro a sinistra. Assegnagli il nome API_KEY.

Una volta ottenuta la chiave API, passala all'SDK. A tale scopo, puoi procedere in uno dei due seguenti modi:

  • Inserisci la chiave nella variabile di ambiente GOOGLE_API_KEY (l'SDK la acquisirà automaticamente da lì).
  • Passa la chiave a genai.configure(api_key=...)
di Gemini Advanced.
genai.configure(api_key=GOOGLE_API_KEY)
for m in genai.list_models():
  if 'embedContent' in m.supported_generation_methods:
    print(m.name)
models/embedding-001
models/embedding-001

Set di dati

Il set di dati di testo 20 Newsgroups contiene 18.000 post di newsgroup su 20 argomenti,suddivisi in set di addestramento e test. La suddivisione tra i set di dati di addestramento e test si basa sui messaggi pubblicati prima e dopo una data specifica. Per questo tutorial, utilizzerai i sottoinsiemi dei set di dati di addestramento e test. Dovrai pre-elaborare e organizzare i dati in dataframe Pandas.

newsgroups_train = fetch_20newsgroups(subset='train')
newsgroups_test = fetch_20newsgroups(subset='test')

# View list of class names for dataset
newsgroups_train.target_names
['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

Ecco un esempio di un punto dati del set di addestramento.

idx = newsgroups_train.data[0].index('Lines')
print(newsgroups_train.data[0][idx:])
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,

- IL
   ---- brought to you by your neighborhood Lerxst ----

Ora inizierai a pre-elaborare i dati per questo tutorial. Rimuovi eventuali informazioni sensibili come nomi, email o parti ridondanti del testo, ad esempio "From: " e "\nSubject: ". Organizza le informazioni in un DataFrame Pandas in modo che siano più leggibili.

def preprocess_newsgroup_data(newsgroup_dataset):
  # Apply functions to remove names, emails, and extraneous words from data points in newsgroups.data
  newsgroup_dataset.data = [re.sub(r'[\w\.-]+@[\w\.-]+', '', d) for d in newsgroup_dataset.data] # Remove email
  newsgroup_dataset.data = [re.sub(r"\([^()]*\)", "", d) for d in newsgroup_dataset.data] # Remove names
  newsgroup_dataset.data = [d.replace("From: ", "") for d in newsgroup_dataset.data] # Remove "From: "
  newsgroup_dataset.data = [d.replace("\nSubject: ", "") for d in newsgroup_dataset.data] # Remove "\nSubject: "

  # Cut off each text entry after 5,000 characters
  newsgroup_dataset.data = [d[0:5000] if len(d) > 5000 else d for d in newsgroup_dataset.data]

  # Put data points into dataframe
  df_processed = pd.DataFrame(newsgroup_dataset.data, columns=['Text'])
  df_processed['Label'] = newsgroup_dataset.target
  # Match label to target name index
  df_processed['Class Name'] = ''
  for idx, row in df_processed.iterrows():
    df_processed.at[idx, 'Class Name'] = newsgroup_dataset.target_names[row['Label']]

  return df_processed
# Apply preprocessing function to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)

df_train.head()

Successivamente, campionarai alcuni dati prendendo 100 punti dati nel set di dati di addestramento e eliminando alcune categorie da eseguire in questo tutorial. Scegli le categorie delle scienze da confrontare.

def sample_data(df, num_samples, classes_to_keep):
  df = df.groupby('Label', as_index = False).apply(lambda x: x.sample(num_samples)).reset_index(drop=True)

  df = df[df['Class Name'].str.contains(classes_to_keep)]

  # Reset the encoding of the labels after sampling and dropping certain categories
  df['Class Name'] = df['Class Name'].astype('category')
  df['Encoded Label'] = df['Class Name'].cat.codes

  return df
TRAIN_NUM_SAMPLES = 100
TEST_NUM_SAMPLES = 25
CLASSES_TO_KEEP = 'sci' # Class name should contain 'sci' in it to keep science categories
df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)
df_train.value_counts('Class Name')
Class Name
sci.crypt          100
sci.electronics    100
sci.med            100
sci.space          100
dtype: int64
df_test.value_counts('Class Name')
Class Name
sci.crypt          25
sci.electronics    25
sci.med            25
sci.space          25
dtype: int64

Crea gli incorporamenti

In questa sezione, vedrai come generare incorporamenti per una porzione di testo utilizzando gli incorporamenti dell'API Gemini. Per saperne di più sugli incorporamenti, consulta la guida agli incorporamenti.

Modifiche all'API Embeddings, incorporamento-001

Per il nuovo modello degli incorporamenti, sono disponibili un nuovo parametro per il tipo di attività e il titolo facoltativo (valido solo con task_type=RETRIEVAL_DOCUMENT).

Questi nuovi parametri si applicano solo ai modelli di incorporamenti più recenti.I tipi di attività sono:

Tipo di attività Descrizione
RETRIEVAL_QUERY Specifica che il testo specificato è una query in un'impostazione di ricerca/recupero.
RETRIEVAL_DOCUMENT Specifica che il testo specificato è un documento in un'impostazione di ricerca/recupero.
SEMANTIC_SIMILARITY Specifica il testo specificato che verrà utilizzato per la somiglianza testuale semantica (STS).
CLASSIFICAZIONE Specifica che gli incorporamenti verranno utilizzati per la classificazione.
CLUSTERING Specifica che gli incorporamenti verranno utilizzati per il clustering.
from tqdm.auto import tqdm
tqdm.pandas()

from google.api_core import retry

def make_embed_text_fn(model):

  @retry.Retry(timeout=300.0)
  def embed_fn(text: str) -> list[float]:
    # Set the task_type to CLASSIFICATION.
    embedding = genai.embed_content(model=model,
                                    content=text,
                                    task_type="classification")
    return embedding['embedding']

  return embed_fn

def create_embeddings(model, df):
  df['Embeddings'] = df['Text'].progress_apply(make_embed_text_fn(model))
  return df
model = 'models/embedding-001'
df_train = create_embeddings(model, df_train)
df_test = create_embeddings(model, df_test)
0%|          | 0/400 [00:00<?, ?it/s]
0%|          | 0/100 [00:00<?, ?it/s]
df_train.head()

Creare un modello di classificazione semplice

Qui definirai un modello semplice con uno strato nascosto e un singolo output di probabilità di classe. La previsione corrisponderà alla probabilità che una porzione di testo sia una particolare classe di notizie. Quando crei il modello, Keras esegue automaticamente lo shuffling dei punti dati.

def build_classification_model(input_size: int, num_classes: int) -> keras.Model:
  inputs = x = keras.Input(input_size)
  x = layers.Dense(input_size, activation='relu')(x)
  x = layers.Dense(num_classes, activation='sigmoid')(x)
  return keras.Model(inputs=[inputs], outputs=x)
# Derive the embedding size from the first training element.
embedding_size = len(df_train['Embeddings'].iloc[0])

# Give your model a different name, as you have already used the variable name 'model'
classifier = build_classification_model(embedding_size, len(df_train['Class Name'].unique()))
classifier.summary()

classifier.compile(loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                   optimizer = keras.optimizers.Adam(learning_rate=0.001),
                   metrics=['accuracy'])
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 768)]             0         
                                                                 
 dense (Dense)               (None, 768)               590592    
                                                                 
 dense_1 (Dense)             (None, 4)                 3076      
                                                                 
=================================================================
Total params: 593668 (2.26 MB)
Trainable params: 593668 (2.26 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
embedding_size
768

Addestra il modello per classificare i newsgroup

Infine, puoi addestrare un modello semplice. Utilizza un numero ridotto di epoche per evitare l'overfitting. La prima epoca richiede molto più tempo delle altre, perché gli incorporamenti devono essere calcolati una sola volta.

NUM_EPOCHS = 20
BATCH_SIZE = 32

# Split the x and y components of the train and validation subsets.
y_train = df_train['Encoded Label']
x_train = np.stack(df_train['Embeddings'])
y_val = df_test['Encoded Label']
x_val = np.stack(df_test['Embeddings'])

# Train the model for the desired number of epochs.
callback = keras.callbacks.EarlyStopping(monitor='accuracy', patience=3)

history = classifier.fit(x=x_train,
                         y=y_train,
                         validation_data=(x_val, y_val),
                         callbacks=[callback],
                         batch_size=BATCH_SIZE,
                         epochs=NUM_EPOCHS,)
Epoch 1/20
/usr/local/lib/python3.10/dist-packages/keras/src/backend.py:5729: UserWarning: "`sparse_categorical_crossentropy` received `from_logits=True`, but the `output` argument was produced by a Softmax activation and thus does not represent logits. Was this intended?
  output, from_logits = _get_logits(
13/13 [==============================] - 1s 30ms/step - loss: 1.2141 - accuracy: 0.6675 - val_loss: 0.9801 - val_accuracy: 0.8800
Epoch 2/20
13/13 [==============================] - 0s 12ms/step - loss: 0.7580 - accuracy: 0.9400 - val_loss: 0.6061 - val_accuracy: 0.9300
Epoch 3/20
13/13 [==============================] - 0s 13ms/step - loss: 0.4249 - accuracy: 0.9525 - val_loss: 0.3902 - val_accuracy: 0.9200
Epoch 4/20
13/13 [==============================] - 0s 13ms/step - loss: 0.2561 - accuracy: 0.9625 - val_loss: 0.2597 - val_accuracy: 0.9400
Epoch 5/20
13/13 [==============================] - 0s 13ms/step - loss: 0.1693 - accuracy: 0.9700 - val_loss: 0.2145 - val_accuracy: 0.9300
Epoch 6/20
13/13 [==============================] - 0s 13ms/step - loss: 0.1240 - accuracy: 0.9850 - val_loss: 0.1801 - val_accuracy: 0.9600
Epoch 7/20
13/13 [==============================] - 0s 21ms/step - loss: 0.0931 - accuracy: 0.9875 - val_loss: 0.1623 - val_accuracy: 0.9400
Epoch 8/20
13/13 [==============================] - 0s 16ms/step - loss: 0.0736 - accuracy: 0.9925 - val_loss: 0.1418 - val_accuracy: 0.9600
Epoch 9/20
13/13 [==============================] - 0s 20ms/step - loss: 0.0613 - accuracy: 0.9925 - val_loss: 0.1315 - val_accuracy: 0.9700
Epoch 10/20
13/13 [==============================] - 0s 20ms/step - loss: 0.0479 - accuracy: 0.9975 - val_loss: 0.1235 - val_accuracy: 0.9600
Epoch 11/20
13/13 [==============================] - 0s 19ms/step - loss: 0.0399 - accuracy: 0.9975 - val_loss: 0.1219 - val_accuracy: 0.9700
Epoch 12/20
13/13 [==============================] - 0s 21ms/step - loss: 0.0326 - accuracy: 0.9975 - val_loss: 0.1158 - val_accuracy: 0.9700
Epoch 13/20
13/13 [==============================] - 0s 19ms/step - loss: 0.0263 - accuracy: 1.0000 - val_loss: 0.1127 - val_accuracy: 0.9700
Epoch 14/20
13/13 [==============================] - 0s 17ms/step - loss: 0.0229 - accuracy: 1.0000 - val_loss: 0.1123 - val_accuracy: 0.9700
Epoch 15/20
13/13 [==============================] - 0s 20ms/step - loss: 0.0195 - accuracy: 1.0000 - val_loss: 0.1063 - val_accuracy: 0.9700
Epoch 16/20
13/13 [==============================] - 0s 17ms/step - loss: 0.0172 - accuracy: 1.0000 - val_loss: 0.1070 - val_accuracy: 0.9700

Valuta le prestazioni del modello

Utilizzo di Keras Model.evaluate per ottenere la perdita e la precisione sul set di dati di test.

classifier.evaluate(x=x_val, y=y_val, return_dict=True)
4/4 [==============================] - 0s 4ms/step - loss: 0.1070 - accuracy: 0.9700
{'loss': 0.10700511932373047, 'accuracy': 0.9700000286102295}

Un modo per valutare le prestazioni del modello è visualizzare le prestazioni del classificatore. Utilizza plot_history per visualizzare le tendenze di perdita e accuratezza nelle epoche.

def plot_history(history):
  """
    Plotting training and validation learning curves.

    Args:
      history: model history with all the metric measures
  """
  fig, (ax1, ax2) = plt.subplots(1,2)
  fig.set_size_inches(20, 8)

  # Plot loss
  ax1.set_title('Loss')
  ax1.plot(history.history['loss'], label = 'train')
  ax1.plot(history.history['val_loss'], label = 'test')
  ax1.set_ylabel('Loss')

  ax1.set_xlabel('Epoch')
  ax1.legend(['Train', 'Validation'])

  # Plot accuracy
  ax2.set_title('Accuracy')
  ax2.plot(history.history['accuracy'],  label = 'train')
  ax2.plot(history.history['val_accuracy'], label = 'test')
  ax2.set_ylabel('Accuracy')
  ax2.set_xlabel('Epoch')
  ax2.legend(['Train', 'Validation'])

  plt.show()

plot_history(history)

png

Un altro modo per visualizzare le prestazioni del modello, oltre alla semplice misurazione della perdita e dell'accuratezza, è utilizzare una matrice di confusione. La matrice di confusione consente di valutare le prestazioni del modello di classificazione oltre l'accuratezza. Puoi vedere come vengono classificati i punti classificati in modo errato. Per creare la matrice di confusione per questo problema di classificazione multiclasse, ottieni i valori effettivi nel set di test e i valori previsti.

Per iniziare, genera la classe prevista per ogni esempio nel set di convalida utilizzando Model.predict().

y_hat = classifier.predict(x=x_val)
y_hat = np.argmax(y_hat, axis=1)
4/4 [==============================] - 0s 4ms/step
labels_dict = dict(zip(df_test['Class Name'], df_test['Encoded Label']))
labels_dict
{'sci.crypt': 0, 'sci.electronics': 1, 'sci.med': 2, 'sci.space': 3}
cm = skmetrics.confusion_matrix(y_val, y_hat)
disp = skmetrics.ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=labels_dict.keys())
disp.plot(xticks_rotation='vertical')
plt.title('Confusion matrix for newsgroup test dataset');
plt.grid(False)

png

Passaggi successivi

Per saperne di più su come utilizzare gli incorporamenti, guarda questi altri tutorial: