TensorLabel

公共类 TensorLabel

TensorLabel 是 TensorBuffer 的实用程序封装容器,轴上具有有意义的标签。

例如,一个图像分类模型的输出张量可能为 {1, 10}, 其中 1 是批次大小,10 是类别数。事实上,在第 2 个轴上 使用每个相应类别的名称或说明来标记每个子张量。TensorLabel 有助于将 TensorBuffer 中的普通张量转换为 传递给子张量。在本例中,如果为第 2 个轴提供 10 个标签,TensorLabel 可以将原始 {1, 10} 张量转换为一个包含 10 个元素的映射,其每个值 是形状为 {}(标量)的张量。用法示例:

   TensorBuffer outputTensor = ...;
   List<String> labels = FileUtil.loadLabels(context, labelFilePath);
   // labels the first axis with size greater than one
   TensorLabel labeled = new TensorLabel(labels, outputTensor);
   // If each sub-tensor has effectively size 1, we can directly get a float value
   Map<String, Float> probabilities = labeled.getMapWithFloatValue();
   // Or get sub-tensors, when each sub-tensor has elements more than 1
   Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
 

注意:目前,我们仅支持大小更大的第一个标签进行张量到映射的转换 大于 1。

公共构造函数

TensorLabel(Map<IntegerList<String>> axisLabel, TensorBuffer tensorBuffer)
创建一个 TensorLabel 对象,该对象能够在多维张量的轴上添加标签。
TensorLabel(List<String> axisLabel, TensorBuffer tensorBuffer)
创建一个 TensorLabel 对象,该对象能够在多维张量的一个轴上添加标签。

公共方法

列表<类别>
getCategoryList()
TensorLabel 对象获取 Category 列表。
Map<StringFloat>
getMapWithFloatValue()
获取将标签映射到浮点数的地图。
Map<StringTensorBuffer>
getMapWithTensorBuffer()
获取包含一对标签和相应 TensorBuffer 的映射。

继承的方法

公共构造函数

<ph type="x-smartling-placeholder"></ph> 公开 TensorLabel (Map<整数List<String>> axisLabel, TensorBuffer tensorBuffer)

创建一个 TensorLabel 对象,该对象能够在多维张量的轴上添加标签。

参数
axisLabels 一个映射,其键为轴 ID(从 0 开始),值对应于 标签。注意:标签的大小应与该轴上的张量大小相同。
tensorBuffer 要加标签的 TensorBuffer。
抛出
NullPointerException 如果 axisLabelstensorBuffer 为 null 或任意值,则会发生此错误 axisLabels 的值为 null。
IllegalArgumentException 如果 axisLabels 中有任何键超出范围(相较于 tensorBuffer 的形状,或者给定维度上的 tensorBuffer 大小不同(标签)。

<ph type="x-smartling-placeholder"></ph> 公开 TensorLabel (列表<字符串> axisLabel, TensorBuffer tensorBuffer)

创建一个 TensorLabel 对象,该对象能够在多维张量的一个轴上添加标签。

注意:标签会应用于大小大于 1 的第一个轴。例如,如果 张量的形状为 [1, 10, 3],标签将应用于轴 1(id 从 0),并且 axisLabels 的大小也应为 10。

参数
axisLabels 标签列表,其大小应与上的张量大小相同 要加标签的轴。
tensorBuffer 要加标签的 TensorBuffer。

公共方法

<ph type="x-smartling-placeholder"></ph> 公开 列表<类别> getCategoryList ()

TensorLabel 对象获取 Category 列表。

标签的轴实际上应该是最后一个轴(这意味着每个子张量 应具有 1) 的平面大小,以便每个标记的子张量都可以 转换为浮点值得分。示例:形状为 {2, 5, 3}TensorLabel 轴 2 有效如果 axis 为 1 或 0,则无法转换为 Category

getMapWithFloatValue() 是一个替代方案,但会返回 Map,作为 结果。

抛出
IllegalStateException 如果每个标签上子张量的大小不是 1,则会发生此错误。

<ph type="x-smartling-placeholder"></ph> 公开 Map<StringFloat> getMapWithFloatValue ()

获取将标签映射到浮点数的地图。仅允许在第一个尺寸更大的轴上进行映射 大于 1,且该轴实际上应是最后一个轴(这意味着每个子张量 应具有 1) 的平面尺寸。

getCategoryList() 是用于获取结果的替代 API。

抛出
IllegalStateException 如果每个标签上子张量的大小不是 1,则会发生此错误。

<ph type="x-smartling-placeholder"></ph> 公开 Map<StringTensorBuffer> getMapWithTensorBuffer ()

获取包含一对标签和相应 TensorBuffer 的映射。只允许 映射。