Python 版图片分类指南

借助 MediaPipe Image Classifier 任务,您可以对图片进行分类。您可以使用 此任务以确定图片所代表的一组类别, 学习到的知识。这些说明介绍了如何使用图像分类器 使用 Python。

您可以通过查看网页 演示。对于 功能、模型和配置选项 此任务,请参阅概览

代码示例

图像分类器的示例代码提供了该分类器的完整实现, 供您参考。此代码可帮助您测试此任务, 开始构建自己的图像分类器。您可以查看、运行和修改 图片分类器示例 代码 只需使用网络浏览器即可。

如果您要为 Raspberry Pi 实现图像分类器,请参阅 Raspberry Pi 示例 app

设置

本部分介绍了设置开发环境和 专门用于使用图像分类器的代码项目。有关 设置开发环境以使用 MediaPipe 任务,包括 平台版本要求,请参阅适用于 Python

<ph type="x-smartling-placeholder">

软件包

图像分类器任务是 mediapipe pip 软件包的任务。您可以安装 依赖项:

$ python -m pip install mediapipe
``` ### Imports

Import the following classes to access the Image Classifier task functions:

```python
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

型号

MediaPipe 图像分类器任务需要一个与此分类兼容的经过训练的模型 任务。如需详细了解适用于图片分类器的经过训练的模型,请参阅 任务概览的“模型”部分

选择并下载模型,然后将其存储在本地目录中。您可以使用 建议的 EfficientNet-Lite0 模型。

model_path = '/absolute/path/to/efficientnet_lite0_int8_2.tflite'

在 Model Name 参数中指定模型的路径,如下所示:

base_options = BaseOptions(model_asset_path=model_path)

创建任务

使用 create_from_options 函数创建任务。通过 “create_from_options”函数接受配置选项,包括正在运行的 模式、显示名称语言区域、结果数上限、置信度阈值 类别的许可名单和拒绝名单如需详细了解配置 选项,请参阅配置概览

图片分类器任务支持 3 种输入数据类型:静态图片、视频文件 和直播视频流选择与输入数据类型对应的标签页,以 了解如何创建任务并运行推理。

映像

import mediapipe as mp

BaseOptions = mp.tasks.BaseOptions
ImageClassifier = mp.tasks.vision.ImageClassifier
ImageClassifierOptions = mp.tasks.vision.ImageClassifierOptions
VisionRunningMode = mp.tasks.vision.RunningMode

options = ImageClassifierOptions(
    base_options=BaseOptions(model_asset_path='/path/to/model.tflite'),
    max_results=5,
    running_mode=VisionRunningMode.IMAGE)

with ImageClassifier.create_from_options(options) as classifier:
  # The classifier is initialized. Use it here.
  # ...
    

视频

import mediapipe as mp

BaseOptions = mp.tasks.BaseOptions
ImageClassifier = mp.tasks.vision.ImageClassifier
ImageClassifierOptions = mp.tasks.vision.ImageClassifierOptions
VisionRunningMode = mp.tasks.vision.RunningMode

options = ImageClassifierOptions(
    base_options=BaseOptions(model_asset_path='/path/to/model.tflite'),
    max_results=5,
    running_mode=VisionRunningMode.VIDEO)

with ImageClassifier.create_from_options(options) as classifier:
  # The classifier is initialized. Use it here.
  # ...
    

直播

import mediapipe as mp

BaseOptions = mp.tasks.BaseOptions
ImageClassifierResult = mp.tasks.vision.ImageClassifier.ImageClassifierResult
ImageClassifier = mp.tasks.vision.ImageClassifier
ImageClassifierOptions = mp.tasks.vision.ImageClassifierOptions
VisionRunningMode = mp.tasks.vision.RunningMode

def print_result(result: ImageClassifierResult, output_image: mp.Image, timestamp_ms: int):
    print('ImageClassifierResult result: {}'.format(result))

options = ImageClassifierOptions(
    base_options=BaseOptions(model_asset_path='/path/to/model.tflite'),
    running_mode=VisionRunningMode.LIVE_STREAM,
    max_results=5,
    result_callback=print_result)

with ImageClassifier.create_from_options(options) as classifier:
  # The classifier is initialized. Use it here.
  # ...
    

如需查看创建用于图片的图片分类器的完整示例,请参阅 示例

配置选项

此任务具有以下适用于 Python 应用的配置选项:

选项名称 说明 值范围 默认值
running_mode 设置任务的运行模式。有三个 模式:

IMAGE:单图输入的模式。

VIDEO:视频已解码帧的模式。

LIVE_STREAM:输入流媒体直播模式 例如来自相机的数据。在此模式下,resultListener 必须为 调用以设置监听器以接收结果 异步执行。
{IMAGE, VIDEO, LIVE_STREAM} IMAGE
display_names_locale 设置要用于 任务模型的元数据(如果有)。默认值为 en, 英语。您可以向自定义模型的元数据中添加本地化标签 使用 TensorFlow Lite Metadata Writer API 语言区域代码 en
max_results 将评分最高的分类结果的可选数量上限设置为 return。如果 <0,则返回所有可用的结果。 任何正数 -1
score_threshold 设置预测分数阈值,以替换 模型元数据(如果有)。低于此值的结果将被拒绝。 任意浮点数 未设置
category_allowlist 设置允许的类别名称的可选列表。如果不为空, 类别名称未包含在此集合中的分类结果 已滤除。重复或未知的类别名称会被忽略。 此选项与 category_denylist 互斥,使用 都会导致错误。 任何字符串 未设置
category_denylist 设置不允许使用的类别名称的可选列表。如果 非空,类别名称在此集中的分类结果将被滤除 。重复或未知的类别名称会被忽略。这个选项 category_allowlist 不包含,同时使用这两个元素会导致错误。 任何字符串 未设置
result_callback 设置结果监听器以接收分类结果 当图像分类器在直播中时异步执行 模式。仅在跑步模式设为“LIVE_STREAM”时才能使用 不适用 未设置

准备数据

将输入准备为图像文件或 Numpy 数组,然后将其转换为 mediapipe.Image 对象。如果您输入的是 网络摄像头,您可以使用外部库,如 OpenCV:以 Numpy 形式加载输入帧 数组。

以下示例解释并展示了如何准备数据以便进行处理, 每种可用的数据类型

映像

import mediapipe as mp

# Load the input image from an image file.
mp_image = mp.Image.create_from_file('/path/to/image')

# Load the input image from a numpy array.
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=numpy_image)
    

视频

import mediapipe as mp

# Use OpenCV’s VideoCapture to load the input video.

# Load the frame rate of the video using OpenCV’s CV_CAP_PROP_FPS
# You’ll need it to calculate the timestamp for each frame.

# Loop through each frame in the video using VideoCapture#read()

# Convert the frame received from OpenCV to a MediaPipe’s Image object.
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=numpy_frame_from_opencv)
    

直播

import mediapipe as mp

# Use OpenCV’s VideoCapture to start capturing from the webcam.

# Create a loop to read the latest frame from the camera using VideoCapture#read()

# Convert the frame received from OpenCV to a MediaPipe’s Image object.
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=numpy_frame_from_opencv)
    

运行任务

您可以调用与跑步模式对应的分类函数来触发 推理。Image Classifier API 将返回 对象。

映像

# Perform image classification on the provided single image.
classification_result = classifier.classify(mp_image)
    

视频

# Calculate the timestamp of the current frame
frame_timestamp_ms = 1000 * frame_index / video_file_fps

# Perform image classification on the video frame.
classification_result = classifier.classify_for_video(mp_image, frame_timestamp_ms)
    

直播

# Send the latest frame to perform image classification.
# Results are sent to the `result_callback` provided in the `ImageClassifierOptions`.
classifier.classify_async(mp_image, frame_timestamp_ms)
    

请注意以下几点:

  • 在视频模式或直播模式下投放广告时,您还必须 为图像分类器任务提供输入帧的时间戳。
  • 在图片或视频模型中运行时,图片分类器任务将 阻塞当前线程,直到它处理完输入图像,或者 帧。
  • 在直播模式下运行时,图像分类器任务不会阻塞 当前线程,但会立即返回。它将调用其结果 监听器,并在每次完成分类结果时生成分类结果 处理输入帧的过程。如果系统调用 classifyAsync 函数, 图像分类器任务正忙于处理另一帧,则该任务会忽略 新的输入帧。

如需查看创建用于图片的图片分类器的完整示例,请参阅 示例

处理和显示结果

运行推理时,图像分类器任务会返回 ImageClassifierResult 对象,该对象包含可能的类别列表 输入图片或帧中的对象。

以下示例展示了此任务的输出数据:

ImageClassifierResult:
 Classifications #0 (single classification head):
  head index: 0
  category #0:
   category name: "/m/01bwb9"
   display name: "Passer domesticus"
   score: 0.91406
   index: 671
  category #1:
   category name: "/m/01bwbt"
   display name: "Passer montanus"
   score: 0.00391
   index: 670

此结果是通过运行 Bird Classifier 获得的 日期:

图像分类器示例代码演示了如何显示分类 结果,请参阅代码 示例 了解详情。