TensorLabel 是 TensorBuffer 的实用程序封装容器,轴上具有有意义的标签。
例如,一个图像分类模型的输出张量可能为 {1, 10},
其中 1 是批次大小,10 是类别数。事实上,在第 2 个轴上
使用每个相应类别的名称或说明来标记每个子张量。TensorLabel 有助于将 TensorBuffer 中的普通张量转换为
传递给子张量。在本例中,如果为第 2 个轴提供 10 个标签,TensorLabel 可以将原始 {1, 10} 张量转换为一个包含 10 个元素的映射,其每个值
是形状为 {}(标量)的张量。用法示例:
TensorBuffer outputTensor = ...; List<String> labels = FileUtil.loadLabels(context, labelFilePath); // labels the first axis with size greater than one TensorLabel labeled = new TensorLabel(labels, outputTensor); // If each sub-tensor has effectively size 1, we can directly get a float value Map<String, Float> probabilities = labeled.getMapWithFloatValue(); // Or get sub-tensors, when each sub-tensor has elements more than 1 Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
注意:目前,我们仅支持大小更大的第一个标签进行张量到映射的转换 大于 1。
公共构造函数
|
TensorLabel(Map<Integer, List<String>> axisLabel, TensorBuffer tensorBuffer)
创建一个 TensorLabel 对象,该对象能够在多维张量的轴上添加标签。
|
|
|
TensorLabel(List<String> axisLabel, TensorBuffer tensorBuffer)
创建一个 TensorLabel 对象,该对象能够在多维张量的一个轴上添加标签。
|
公共方法
| 列表<类别> |
getCategoryList()
从
TensorLabel 对象获取 Category 列表。 |
| Map<String, Float> |
getMapWithFloatValue()
获取将标签映射到浮点数的地图。
|
| Map<String, TensorBuffer> |
getMapWithTensorBuffer()
获取包含一对标签和相应 TensorBuffer 的映射。
|
继承的方法
公共构造函数
<ph type="x-smartling-placeholder"></ph> 公开 TensorLabel (Map<整数, List<String>> axisLabel, TensorBuffer tensorBuffer)
创建一个 TensorLabel 对象,该对象能够在多维张量的轴上添加标签。
参数
| axisLabels | 一个映射,其键为轴 ID(从 0 开始),值对应于 标签。注意:标签的大小应与该轴上的张量大小相同。 |
|---|---|
| tensorBuffer | 要加标签的 TensorBuffer。 |
抛出
| NullPointerException | 如果 axisLabels 或 tensorBuffer 为 null 或任意值,则会发生此错误
axisLabels 的值为 null。 |
|---|---|
| IllegalArgumentException | 如果 axisLabels 中有任何键超出范围(相较于
tensorBuffer 的形状,或者给定维度上的 tensorBuffer 大小不同(标签)。
|
<ph type="x-smartling-placeholder"></ph> 公开 TensorLabel (列表<字符串> axisLabel, TensorBuffer tensorBuffer)
创建一个 TensorLabel 对象,该对象能够在多维张量的一个轴上添加标签。
注意:标签会应用于大小大于 1 的第一个轴。例如,如果
张量的形状为 [1, 10, 3],标签将应用于轴 1(id 从
0),并且 axisLabels 的大小也应为 10。
参数
| axisLabels | 标签列表,其大小应与上的张量大小相同 要加标签的轴。 |
|---|---|
| tensorBuffer | 要加标签的 TensorBuffer。 |
公共方法
<ph type="x-smartling-placeholder"></ph> 公开 列表<类别> getCategoryList ()
从 TensorLabel 对象获取 Category 列表。
标签的轴实际上应该是最后一个轴(这意味着每个子张量
应具有 1) 的平面大小,以便每个标记的子张量都可以
转换为浮点值得分。示例:形状为 {2, 5, 3} 的 TensorLabel
轴 2 有效如果 axis 为 1 或 0,则无法转换为 Category。
getMapWithFloatValue() 是一个替代方案,但会返回 Map,作为
结果。
抛出
| IllegalStateException | 如果每个标签上子张量的大小不是 1,则会发生此错误。 |
|---|
<ph type="x-smartling-placeholder"></ph> 公开 Map<String、Float> getMapWithFloatValue ()
获取将标签映射到浮点数的地图。仅允许在第一个尺寸更大的轴上进行映射 大于 1,且该轴实际上应是最后一个轴(这意味着每个子张量 应具有 1) 的平面尺寸。
getCategoryList() 是用于获取结果的替代 API。
抛出
| IllegalStateException | 如果每个标签上子张量的大小不是 1,则会发生此错误。 |
|---|
<ph type="x-smartling-placeholder"></ph> 公开 Map<String、TensorBuffer> getMapWithTensorBuffer ()
获取包含一对标签和相应 TensorBuffer 的映射。只允许 映射。