Training a Text Classifier Using Embeddings



In this notebook, you'll learn to use the embeddings produced by the Gemini API to train a classifier that predicts which newsgroup topic a post belongs to.


You can run this quickstart in Google Colab.

To complete this quickstart in your own development environment, ensure that your environment meets the following requirements:

  • Python 3.9+
  • An installation of jupyter to run the notebook.


First, download and install the Gemini API Python library.

pip install -U -q google.generativeai
import re
import tqdm
import keras
import numpy as np
import pandas as pd

import google.generativeai as genai

# Used to securely store your API key
from google.colab import userdata

import seaborn as sns
import matplotlib.pyplot as plt

from keras import layers
from matplotlib.ticker import MaxNLocator
from sklearn.datasets import fetch_20newsgroups
import sklearn.metrics as skmetrics

Grab an API Key

Before you can use the Gemini API, you must first obtain an API key. If you don't already have one, create a key with one click in Google AI Studio.


In Colab, add the key to the secrets manager under the "🔑" in the left panel. Give it the name API_KEY.

Once you have the API key, pass it to the SDK. You can do this in two ways:

  • Put the key in the GOOGLE_API_KEY environment variable (the SDK will automatically pick it up from there).
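  • Pass the key to genai.configure(api_key=...).
# Or use `os.getenv('API_KEY')` to fetch an environment variable.
API_KEY = userdata.get('API_KEY')

genai.configure(api_key=API_KEY)

# List the models that support embeddings
for m in genai.list_models():
  if 'embedContent' in m.supported_generation_methods:
    print(m.name)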


The 20 Newsgroups Text Dataset contains about 18,000 newsgroup posts on 20 topics, divided into training and test sets. The split between the training and test sets is based on whether messages were posted before or after a specific date. For this tutorial, you will use subsets of the training and test datasets, which you will preprocess and organize into Pandas DataFrames.

newsgroups_train = fetch_20newsgroups(subset='train')
newsgroups_test = fetch_20newsgroups(subset='test')

# View list of class names for dataset
newsgroups_train.target_names

Here is an example of what a data point from the training set looks like.

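# Print the first training post, starting at its "Lines" header
idx = newsgroups_train.data[0].index('Lines')
print(newsgroups_train.data[0][idx:])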
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.


- IL
   ---- brought to you by your neighborhood Lerxst ----

Now you will preprocess the data for this tutorial. Remove sensitive information such as names and email addresses, along with redundant parts of the text such as "From: " and "\nSubject: ". Then organize the information into a Pandas DataFrame so it is easier to read.

def preprocess_newsgroup_data(newsgroup_dataset):
  # Apply functions to remove names, emails, and extraneous words from data points in newsgroup_dataset.data
  newsgroup_dataset.data = [re.sub(r'[\w\.-]+@[\w\.-]+', '', d) for d in newsgroup_dataset.data] # Remove email
  newsgroup_dataset.data = [re.sub(r"\([^()]*\)", "", d) for d in newsgroup_dataset.data] # Remove names
  newsgroup_dataset.data = [d.replace("From: ", "") for d in newsgroup_dataset.data] # Remove "From: "
  newsgroup_dataset.data = [d.replace("\nSubject: ", "") for d in newsgroup_dataset.data] # Remove "\nSubject: "

  # Cut off each text entry after 5,000 characters
  newsgroup_dataset.data = [d[0:5000] if len(d) > 5000 else d for d in newsgroup_dataset.data]

  # Put data points into dataframe
  df_processed = pd.DataFrame(newsgroup_dataset.data, columns=['Text'])
  df_processed['Label'] = newsgroup_dataset.target
  # Match label to target name index
  df_processed['Class Name'] = ''
  for idx, row in df_processed.iterrows():
    df_processed.at[idx, 'Class Name'] = newsgroup_dataset.target_names[row['Label']]

  return df_processed
# Apply preprocessing function to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)
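To see what each cleaning step does, here is a minimal standalone sketch that applies the same regular expressions and replacements to a single string. The helper name clean_post and the sample header text are hypothetical, made up for illustration; the tutorial's preprocess_newsgroup_data applies these steps to every post in the dataset.

```python
import re

def clean_post(text: str, max_chars: int = 5000) -> str:
    """Hypothetical helper: the cleaning steps of preprocess_newsgroup_data, for one string."""
    text = re.sub(r'[\w\.-]+@[\w\.-]+', '', text)  # drop email addresses
    text = re.sub(r"\([^()]*\)", "", text)         # drop parenthesized text (usually a name)
    text = text.replace("From: ", "")             # drop the "From: " header prefix
    text = text.replace("\nSubject: ", "")        # drop the "\nSubject: " header prefix
    return text[:max_chars]                        # truncate to at most max_chars characters

sample = "From: lerxst@wam.umd.edu (where's my thing)\nSubject: WHAT car is this!?\nLines: 15"
print(repr(clean_post(sample)))
```

Note that the header line collapses into the subject text, and the parenthesized-name rule removes any parenthesized text, not only names.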


Next, you will sample the data: take 100 data points per class from the training dataset (and 25 per class from the test dataset), and drop most of the categories to keep this tutorial fast. Keep the science categories to compare.

def sample_data(df, num_samples, classes_to_keep):
  # Take num_samples random rows from each label
  df = df.groupby('Label', as_index = False).apply(lambda x: x.sample(num_samples)).reset_index(drop=True)

  # Keep only the classes whose name contains classes_to_keep
  df = df[df['Class Name'].str.contains(classes_to_keep)].copy()

  # Reset the encoding of the labels after sampling and dropping certain categories
  df['Class Name'] = df['Class Name'].astype('category')
  df['Encoded Label'] = df['Class Name'].cat.codes

  return df
TRAIN_NUM_SAMPLES = 100
TEST_NUM_SAMPLES = 25
CLASSES_TO_KEEP = 'sci' # Class name should contain 'sci' in it to keep science categories
df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)
df_train.value_counts('Class Name')
Class Name
sci.crypt          100
sci.electronics    100
sci.med            100
sci.space          100
dtype: int64
df_test.value_counts('Class Name')
Class Name
sci.crypt          25
sci.electronics    25
sci.med            25
sci.space          25
dtype: int64
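The two ideas in sample_data, drawing the same number of rows from every label and then re-encoding the surviving labels as 0..k-1, can be checked on a tiny DataFrame. This sketch uses hypothetical toy data and swaps in pandas GroupBy.sample (pandas 1.1+) for the groupby(...).apply(lambda x: x.sample(...)) pattern; both draw a fixed number of rows per label.

```python
import pandas as pd

# Hypothetical toy stand-in for df_train: original 20-class integer labels.
toy = pd.DataFrame({
    'Label': [0, 0, 0, 11, 11, 11, 14, 14, 14],
    'Class Name': ['alt.atheism'] * 3 + ['sci.crypt'] * 3 + ['sci.space'] * 3,
})

# Stratified sampling: draw 2 rows from every label.
sampled = toy.groupby('Label').sample(n=2, random_state=0)

# Keep only the science classes, then re-encode labels as 0..k-1.
sci = sampled[sampled['Class Name'].str.contains('sci')].copy()
sci['Class Name'] = sci['Class Name'].astype('category')
sci['Encoded Label'] = sci['Class Name'].cat.codes

print(sci['Class Name'].value_counts())
print(dict(zip(sci['Class Name'], sci['Encoded Label'])))
```

Re-encoding matters because the classifier's output layer has one unit per kept class, so the sparse labels must run from 0 to the number of kept classes minus one rather than keeping the original values such as 11 and 14.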

Create the embeddings

In this section, you will generate embeddings for a piece of text using the Gemini API's embedding model. To learn more about embeddings, visit the embeddings guide.

API changes to Embeddings embedding-001

The new embeddings model adds a task_type parameter and an optional title parameter (valid only with task_type=RETRIEVAL_DOCUMENT).

These new parameters apply only to the newest embeddings models. The task types are:

Task Type             Description
RETRIEVAL_QUERY       Specifies the given text is a query in a search/retrieval setting.
RETRIEVAL_DOCUMENT    Specifies the given text is a document in a search/retrieval setting.
SEMANTIC_SIMILARITY   Specifies the given text will be used for Semantic Textual Similarity (STS).
CLASSIFICATION        Specifies that the embeddings will be used for classification.
CLUSTERING            Specifies that the embeddings will be used for clustering.
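from tqdm.auto import tqdm
tqdm.pandas()  # Enables Series.progress_apply

from google.api_core import retry

def make_embed_text_fn(model):

  # Retry transient API errors for up to 300 seconds
  @retry.Retry(timeout=300.0)
  def embed_fn(text: str) -> list[float]:
    # Set the task_type to CLASSIFICATION.
    embedding = genai.embed_content(model=model,
                                    content=text,
                                    task_type="classification")
    return embedding['embedding']

  return embed_fn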

def create_embeddings(model, df):
  df['Embeddings'] = df['Text'].progress_apply(make_embed_text_fn(model))
  return df
model = 'models/embedding-001'
df_train = create_embeddings(model, df_train)
df_test = create_embeddings(model, df_test)
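Each row's Embeddings value is a plain list of floats, and later cells stack these lists into a 2-D feature matrix with np.stack. The sketch below shows that data flow without calling the API: fake_embed is a hypothetical stand-in for embed_fn that returns a deterministic random vector, and EMBEDDING_DIM = 8 is an arbitrary small size (the embedding-001 vectors in this tutorial have 768 dimensions).

```python
import numpy as np
import pandas as pd

EMBEDDING_DIM = 8  # hypothetical small size; embedding-001 vectors here are 768-dimensional

def fake_embed(text: str) -> list[float]:
    """Hypothetical stand-in for embed_fn: a deterministic vector per text, no API call."""
    rng = np.random.default_rng(len(text))  # seed from the text length, for repeatability
    return rng.normal(size=EMBEDDING_DIM).tolist()

toy_df = pd.DataFrame({'Text': ['first post', 'second post', 'third']})
toy_df['Embeddings'] = toy_df['Text'].apply(fake_embed)  # one list of floats per row

# Stack the per-row lists into a (num_rows, EMBEDDING_DIM) matrix, as done for x_train.
toy_x = np.stack(toy_df['Embeddings'])
print(toy_x.shape)
```

The length of the first embedding is also how the tutorial derives the classifier's input size.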

Build a simple classification model

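Here you will define a simple model with one hidden layer and an output layer with one unit per class. The outputs correspond to the probability that a piece of text belongs to each class of news, and the predicted class is the one with the highest probability. You don't need to shuffle the data points yourself: Keras shuffles the training data before each epoch when you call Model.fit (shuffle=True by default).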

def build_classification_model(input_size: int, num_classes: int) -> keras.Model:
  inputs = x = keras.Input(shape=(input_size,))
  x = layers.Dense(input_size, activation='relu')(x)
  # Softmax turns the outputs into class probabilities that sum to 1
  x = layers.Dense(num_classes, activation='softmax')(x)
  return keras.Model(inputs=[inputs], outputs=x)
# Derive the embedding size from the first training element.
embedding_size = len(df_train['Embeddings'].iloc[0])

# Give your model a different name, as you have already used the variable name 'model'
classifier = build_classification_model(embedding_size, len(df_train['Class Name'].unique()))

# The model outputs probabilities, not logits, so from_logits=False
classifier.compile(loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                   optimizer = keras.optimizers.Adam(learning_rate=0.001),
                   metrics=['accuracy'])

classifier.summary()
Model: "model"
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 768)]             0         
 dense (Dense)               (None, 768)               590592    
 dense_1 (Dense)             (None, 4)                 3076      
Total params: 593668 (2.26 MB)
Trainable params: 593668 (2.26 MB)
Non-trainable params: 0 (0.00 Byte)
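The output layer and the loss have to agree: a softmax output already produces probabilities (pair it with from_logits=False), while an output layer with no activation produces raw logits (pair it with from_logits=True, so the loss applies the softmax internally). This NumPy sketch is an illustration of the math, not Keras code; it computes sparse categorical cross-entropy both ways on made-up logits and shows the two paths give the same loss and the same predicted class.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sparse_ce_from_probs(y_true, probs):
    # Mean negative log-probability assigned to the correct class (from_logits=False).
    return float(-np.mean(np.log(probs[np.arange(len(y_true)), y_true])))

def sparse_ce_from_logits(y_true, logits):
    # Same loss computed directly from logits via log-softmax (from_logits=True).
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return float(-np.mean(log_probs[np.arange(len(y_true)), y_true]))

# Made-up scores for 2 examples over 4 classes, with their true labels.
toy_logits = np.array([[2.0, 0.5, -1.0, 0.0],
                       [0.1, 0.2, 3.0, -0.5]])
toy_labels = np.array([0, 2])

toy_probs = softmax(toy_logits)
print(toy_probs.sum(axis=1))           # each row sums to 1
print(np.argmax(toy_probs, axis=1))    # same classes as np.argmax(toy_logits, axis=1)
print(sparse_ce_from_probs(toy_labels, toy_probs), sparse_ce_from_logits(toy_labels, toy_logits))
```

Mixing the two (for example, a softmax or sigmoid output with from_logits=True) makes the loss apply a second softmax on top of values that are already squashed, which is what the UserWarning in a training log points out.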

Train the model to classify newsgroups

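Finally, you can train a simple model. Use a small number of epochs to avoid overfitting; the EarlyStopping callback also ends training once training accuracy has not improved for three consecutive epochs. The first epoch takes longer than the rest because of one-time setup work such as building the training graph; the embeddings were already computed in the previous section.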


# Split the x and y components of the train and validation subsets.
y_train = df_train['Encoded Label']
x_train = np.stack(df_train['Embeddings'])
y_val = df_test['Encoded Label']
x_val = np.stack(df_test['Embeddings'])

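# Train the model for the desired number of epochs.
NUM_EPOCHS = 20
callback = keras.callbacks.EarlyStopping(monitor='accuracy', patience=3)

# batch_size defaults to 32, giving 13 steps per epoch for 400 examples
history = classifier.fit(x=x_train,
                         y=y_train,
                         validation_data=(x_val, y_val),
                         callbacks=[callback],
                         epochs=NUM_EPOCHS)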
Epoch 1/20
/usr/local/lib/python3.10/dist-packages/keras/src/ UserWarning: "`sparse_categorical_crossentropy` received `from_logits=True`, but the `output` argument was produced by a Softmax activation and thus does not represent logits. Was this intended?
  output, from_logits = _get_logits(
13/13 [==============================] - 1s 30ms/step - loss: 1.2141 - accuracy: 0.6675 - val_loss: 0.9801 - val_accuracy: 0.8800
Epoch 2/20
13/13 [==============================] - 0s 12ms/step - loss: 0.7580 - accuracy: 0.9400 - val_loss: 0.6061 - val_accuracy: 0.9300
Epoch 3/20
13/13 [==============================] - 0s 13ms/step - loss: 0.4249 - accuracy: 0.9525 - val_loss: 0.3902 - val_accuracy: 0.9200
Epoch 4/20
13/13 [==============================] - 0s 13ms/step - loss: 0.2561 - accuracy: 0.9625 - val_loss: 0.2597 - val_accuracy: 0.9400
Epoch 5/20
13/13 [==============================] - 0s 13ms/step - loss: 0.1693 - accuracy: 0.9700 - val_loss: 0.2145 - val_accuracy: 0.9300
Epoch 6/20
13/13 [==============================] - 0s 13ms/step - loss: 0.1240 - accuracy: 0.9850 - val_loss: 0.1801 - val_accuracy: 0.9600
Epoch 7/20
13/13 [==============================] - 0s 21ms/step - loss: 0.0931 - accuracy: 0.9875 - val_loss: 0.1623 - val_accuracy: 0.9400
Epoch 8/20
13/13 [==============================] - 0s 16ms/step - loss: 0.0736 - accuracy: 0.9925 - val_loss: 0.1418 - val_accuracy: 0.9600
Epoch 9/20
13/13 [==============================] - 0s 20ms/step - loss: 0.0613 - accuracy: 0.9925 - val_loss: 0.1315 - val_accuracy: 0.9700
Epoch 10/20
13/13 [==============================] - 0s 20ms/step - loss: 0.0479 - accuracy: 0.9975 - val_loss: 0.1235 - val_accuracy: 0.9600
Epoch 11/20
13/13 [==============================] - 0s 19ms/step - loss: 0.0399 - accuracy: 0.9975 - val_loss: 0.1219 - val_accuracy: 0.9700
Epoch 12/20
13/13 [==============================] - 0s 21ms/step - loss: 0.0326 - accuracy: 0.9975 - val_loss: 0.1158 - val_accuracy: 0.9700
Epoch 13/20
13/13 [==============================] - 0s 19ms/step - loss: 0.0263 - accuracy: 1.0000 - val_loss: 0.1127 - val_accuracy: 0.9700
Epoch 14/20
13/13 [==============================] - 0s 17ms/step - loss: 0.0229 - accuracy: 1.0000 - val_loss: 0.1123 - val_accuracy: 0.9700
Epoch 15/20
13/13 [==============================] - 0s 20ms/step - loss: 0.0195 - accuracy: 1.0000 - val_loss: 0.1063 - val_accuracy: 0.9700
Epoch 16/20
13/13 [==============================] - 0s 17ms/step - loss: 0.0172 - accuracy: 1.0000 - val_loss: 0.1070 - val_accuracy: 0.9700
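The log stops at epoch 16 of 20 because of the patience rule. This is a simplified re-implementation of that rule for illustration, not Keras's own code; it assumes min_delta=0, no baseline, and "higher is better" for accuracy. The values are the training accuracies copied from the log above.

```python
def early_stopping_epoch(values, patience, mode="max"):
    """Return the 1-based epoch after which a patience-based early stop would trigger."""
    best = None
    wait = 0  # epochs since the last strict improvement
    for epoch, value in enumerate(values, start=1):
        improved = best is None or (value > best if mode == "max" else value < best)
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(values)  # never triggered: trained for every epoch

# Training accuracy per epoch, from the log above.
train_accuracy = [0.6675, 0.9400, 0.9525, 0.9625, 0.9700, 0.9850, 0.9875, 0.9925,
                  0.9925, 0.9975, 0.9975, 0.9975, 1.0000, 1.0000, 1.0000, 1.0000]
print(early_stopping_epoch(train_accuracy, patience=3))
```

Training accuracy reaches 1.0 at epoch 13 and cannot improve further, so three epochs later training stops. Monitoring a validation metric such as val_loss is a common alternative when you want stopping to reflect generalization rather than fit to the training set.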

Evaluate model performance

Use Keras Model.evaluate to get the loss and accuracy on the test dataset.

classifier.evaluate(x=x_val, y=y_val, return_dict=True)
4/4 [==============================] - 0s 4ms/step - loss: 0.1070 - accuracy: 0.9700
{'loss': 0.10700511932373047, 'accuracy': 0.9700000286102295}

You can also visualize the classifier's performance. Use plot_history to plot the loss and accuracy trends over the epochs for the training and validation sets.

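def plot_history(history):
  """
    Plot training and validation learning curves.

    Args:
      history: model history with all the metric measures
  """
  fig, (ax1, ax2) = plt.subplots(1,2)
  fig.set_size_inches(20, 8)

  # Plot loss
  ax1.set_title('Loss')
  ax1.plot(history.history['loss'], label = 'Train')
  ax1.plot(history.history['val_loss'], label = 'Validation')
  ax1.set_xlabel('Epoch')
  ax1.set_ylabel('Loss')
  ax1.xaxis.set_major_locator(MaxNLocator(integer=True))  # whole-number epoch ticks
  ax1.legend()

  # Plot accuracy
  ax2.set_title('Accuracy')
  ax2.plot(history.history['accuracy'], label = 'Train')
  ax2.plot(history.history['val_accuracy'], label = 'Validation')
  ax2.set_xlabel('Epoch')
  ax2.set_ylabel('Accuracy')
  ax2.xaxis.set_major_locator(MaxNLocator(integer=True))  # whole-number epoch ticks
  ax2.legend()

  plt.show()

plot_history(history)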


Another way to view model performance, beyond loss and accuracy, is a confusion matrix. It shows how the classifier errs: each row counts the test examples of one true class, and each column shows which class they were predicted as, so you can see what misclassified points get classified as. To build the confusion matrix for this multi-class classification problem, get the actual values in the test set and the predicted values.

Start by generating the predicted class for each example in the validation set using Model.predict().

y_hat = classifier.predict(x=x_val)
y_hat = np.argmax(y_hat, axis=1)
4/4 [==============================] - 0s 4ms/step
labels_dict = dict(zip(df_test['Class Name'], df_test['Encoded Label']))
labels_dict
{'sci.crypt': 0, 'sci.electronics': 1, 'sci.med': 2, 'sci.space': 3}
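cm = skmetrics.confusion_matrix(y_val, y_hat)
# Order the class names by their encoded label so they line up with the matrix rows and columns
disp = skmetrics.ConfusionMatrixDisplay(confusion_matrix=cm,
                                        display_labels=sorted(labels_dict, key=labels_dict.get))
disp.plot(xticks_rotation='vertical')
plt.title('Confusion matrix for newsgroup test dataset');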
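To see what skmetrics.confusion_matrix counts, here is a small NumPy sketch on made-up labels (toy data, not the newsgroup results). The function name confusion_matrix_np is hypothetical; it follows the same convention as scikit-learn, with rows for the true class and columns for the predicted class.

```python
import numpy as np

def confusion_matrix_np(y_true, y_pred, num_classes):
    """Hypothetical helper: count (true class, predicted class) pairs."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # rows: true class, columns: predicted class
    return cm

toy_true = np.array([0, 0, 1, 2, 2, 3])
toy_pred = np.array([0, 1, 1, 2, 3, 3])
toy_cm = confusion_matrix_np(toy_true, toy_pred, num_classes=4)

toy_accuracy = np.trace(toy_cm) / toy_cm.sum()           # correct predictions lie on the diagonal
toy_recall = np.diag(toy_cm) / toy_cm.sum(axis=1)        # per-class share of true examples found
print(toy_cm)
print(toy_accuracy, toy_recall)
```

Accuracy is just the diagonal's share of all counts, while the off-diagonal cells show which pairs of classes get confused, which is the extra information the confusion matrix adds.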


Next steps

To learn more about how you can use embeddings, see these other tutorials: