TensorLabel

公共类 TensorLabel

TensorLabel 是 TensorBuffer 的一个实用程序封装容器,在轴上具有有意义的标签。

例如,某个图像分类模型可能具有一个形状为 {1, 10} 的输出张量,其中 1 是批次大小,10 是类别数量。实际上,在第二轴上,我们可以使用每个相应类别的名称或说明来标记每个子张量。TensorLabel 有助于将 TensorBuffer 中的普通张量转换为从预定义标签到子张量的映射。在本例中,如果为第 2 个轴提供了 10 个标签,TensorLabel 可以将原始的 {1, 10} 张量转换为包含 10 个元素的映射,其中每个值都是形状为 {}(标量)的张量。用法示例:

   TensorBuffer outputTensor = ...;
   List<String> labels = FileUtil.loadLabels(context, labelFilePath);
   // labels the first axis with size greater than one
   TensorLabel labeled = new TensorLabel(labels, outputTensor);
   // If each sub-tensor has effectively size 1, we can directly get a float value
   Map<String, Float> probabilities = labeled.getMapWithFloatValue();
   // Or get sub-tensors, when each sub-tensor has elements more than 1
   Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
 

注意:目前,我们仅支持对大小大于 1 的第一个标签进行张量到映射转换。

公共构造函数

TensorLabel(Map<Integer, List<String>> axisLabel、TensorBuffer tensorBuffer)
创建一个 TensorLabel 对象,它能够在多维张量的轴上加标签。
TensorLabel(List<String> axisLabel、TensorBuffer tensorBuffer)
创建一个 TensorLabel 对象,它能够在多维张量的一个轴上加标签。

公共方法

列表<类别>
getCategoryList()
TensorLabel 对象获取 Category 的列表。
地图<String浮点>
getMapWithFloatValue()
获取将标签映射到浮点数的映射。
Map<StringTensorBuffer>
getMapWithTensorBuffer()
获取包含一对标签和相应 TensorBuffer 的地图。

继承的方法

公共构造函数

public TensorLabel (Map<Integer, List<String>> axisLabel、TensorBuffer tensorBuffer)

创建一个 TensorLabel 对象,它能够在多维张量的轴上加标签。

参数
axisLabels 映射,其键为轴 ID(从 0 开始),值为相应的标签。注意:标签的大小应与该轴上的张量大小相同。
tensorBuffer 要加标签的 TensorBuffer。
抛出
NullPointerException 如果 axisLabelstensorBuffer 为 null,或者 axisLabels 中的任何值为 null。
IllegalArgumentException 如果 axisLabels 中的任何键超出范围(与 tensorBuffer 的形状相比),或者任何值(标签)的大小与给定维度上的 tensorBuffer 不同。

public TensorLabel (List<String> axisLabel, TensorBuffer tensorBuffer)

创建一个 TensorLabel 对象,它能够在多维张量的一个轴上加标签。

注意:这些标签会应用到大小大于 1 的第一个轴。例如,如果张量的形状为 [1, 10, 3],则标签将应用于轴 1(ID 从 0 开始),并且 axisLabels 的大小也应为 10。

参数
axisLabels 标签列表,其大小应与待加标签轴上的张量大小相同。
tensorBuffer 要加标签的 TensorBuffer。

公共方法

public List<Category> getCategoryList ()

TensorLabel 对象获取 Category 的列表。

标签的轴实际上应该是最后一个轴(这意味着该轴指定的每个子张量的扁平大小都应为 1),以便每个加标签的子张量都可以转换为浮点值得分。示例:形状为 {2, 5, 3} 且轴 2 的 TensorLabel 有效。如果轴为 1 或 0,则无法转换为 Category

getMapWithFloatValue() 是一种替代方法,但会返回 Map 作为结果。

抛出
IllegalStateException 如果每个标签上子张量的大小不是 1,则使用此函数。

public Map<StringFloat> getMapWithFloatValue ()

获取将标签映射到浮点数的映射。仅允许在大小大于 1 的第一个轴上进行映射,并且该轴实际上应该是最后一个轴(这意味着该轴指定的每个子张量的平面大小都应为 1)。

getCategoryList() 是用于获取结果的替代 API。

抛出
IllegalStateException 如果每个标签上子张量的大小不是 1,则使用此函数。

public Map<StringTensorBuffer> getMapWithTensorBuffer ()

获取包含一对标签和相应 TensorBuffer 的地图。目前,仅允许在第一个轴上映射大小大于 1 的轴。