TensorLabel is a utility wrapper for TensorBuffers that attaches meaningful labels to an axis.
For example, an image classification model may have an output tensor with shape {1, 10},
where 1 is the batch size and 10 is the number of categories. Each sub-tensor along the second
axis can be labeled with the name or description of its category. TensorLabel converts the plain
tensor in a TensorBuffer into a map from predefined labels to sub-tensors. In this case, given
10 labels for the second axis, TensorLabel converts the original {1, 10} tensor into a 10-element
map, each value of which is a tensor of shape {} (a scalar). Usage example:
TensorBuffer outputTensor = ...;
List<String> labels = FileUtil.loadLabels(context, labelFilePath);
// Labels the first axis with size greater than one.
TensorLabel labeled = new TensorLabel(labels, outputTensor);
// If each sub-tensor has effectively size 1, we can directly get a float value.
Map<String, Float> probabilities = labeled.getMapWithFloatValue();
// Or get sub-tensors, when each sub-tensor has more than 1 element.
Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
Note: tensor-to-map conversion is currently supported only for the first axis whose size is greater than 1.
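To make the "first axis whose size is greater than 1" rule concrete, here is a minimal Java sketch that finds that axis for a given shape. The class FirstLabeledAxis and its method are hypothetical illustrations, not part of the TensorFlow Lite Support Library, and throwing on an unlabelable shape is an assumption of this sketch.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the "first axis with size greater than 1" rule.
 * It is not the TensorFlow Lite Support implementation.
 */
final class FirstLabeledAxis {
    private FirstLabeledAxis() {}

    /** Returns the index of the first axis whose size is greater than 1. */
    static int find(int[] shape) {
        for (int axis = 0; axis < shape.length; axis++) {
            if (shape[axis] > 1) {
                return axis;
            }
        }
        // No axis can be labeled, e.g. shape {} or {1, 1}.
        throw new IllegalArgumentException(
                "No axis with size greater than 1 in shape " + Arrays.toString(shape));
    }

    public static void main(String[] args) {
        // The batch axis of size 1 is skipped, so the category axis (index 1) is chosen.
        System.out.println(find(new int[] {1, 10}));
        // Leading size-1 axes are skipped until an axis of size 5 is found at index 2.
        System.out.println(find(new int[] {1, 1, 5, 2}));
    }
}
```

For the {1, 10} classifier output above, this rule picks axis 1, which is why a plain list of 10 labels is enough.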
Public Constructors
TensorLabel(Map<Integer, List<String>> axisLabels, TensorBuffer tensorBuffer)
Creates a TensorLabel object that labels the given axes of a multi-dimensional tensor. axisLabels maps each axis index (starting from 0) to the labels for that axis.
TensorLabel(List<String> axisLabels, TensorBuffer tensorBuffer)
Creates a TensorLabel object that labels one axis of a multi-dimensional tensor: the first axis whose size is greater than 1.
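The map-based constructor takes labels keyed by axis index. The following minimal sketch shows what such a map looks like and the kind of consistency check it implies. It assumes, as the {1, 10} example suggests, that each labeled axis needs exactly one label per element along it. AxisLabelCheck is a hypothetical helper, not a library class, and the library's own validation may differ.

```java
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch: checks a Map<Integer, List<String>> of axis labels against a shape.
 * Assumes axis indices start at 0 and each axis needs one label per element along it.
 */
final class AxisLabelCheck {
    private AxisLabelCheck() {}

    static void validate(Map<Integer, List<String>> axisLabels, int[] shape) {
        for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
            int axis = entry.getKey();
            // The key must name an existing axis of the tensor.
            if (axis < 0 || axis >= shape.length) {
                throw new IllegalArgumentException("Axis " + axis + " is out of range.");
            }
            // Assumed rule: label count equals the size of that axis.
            int labelCount = entry.getValue().size();
            if (labelCount != shape[axis]) {
                throw new IllegalArgumentException(
                        "Axis " + axis + " has size " + shape[axis] + " but " + labelCount + " labels.");
            }
        }
    }

    public static void main(String[] args) {
        int[] shape = {1, 3};
        // Label axis 1, the same axis the List-based constructor would pick for this shape.
        Map<Integer, List<String>> axisLabels = Map.of(1, List.of("cat", "dog", "bird"));
        validate(axisLabels, shape);
        System.out.println("labels match shape");
    }
}
```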
Public Methods
List<Category> getCategoryList()
Gets a list of Category objects (each a label with its score) from the TensorLabel object.
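As an illustration of what a category list holds, this minimal sketch pairs each label with the scalar score at the same position, keeping label order. LabelScore is a hypothetical stand-in for the library's Category class, and the one-score-per-label requirement is an assumption of the sketch.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch of pairing labels with scalar scores, in label order.
 * LabelScore stands in for the library's Category class; it is not that class.
 */
final class CategorySketch {
    private CategorySketch() {}

    static final class LabelScore {
        final String label;
        final float score;

        LabelScore(String label, float score) {
            this.label = label;
            this.score = score;
        }

        @Override
        public String toString() {
            return label + "=" + score;
        }
    }

    /** Pairs labels.get(i) with scores[i]; assumes exactly one score per label. */
    static List<LabelScore> toList(List<String> labels, float[] scores) {
        if (labels.size() != scores.length) {
            throw new IllegalArgumentException("Need exactly one score per label.");
        }
        List<LabelScore> result = new ArrayList<>(labels.size());
        for (int i = 0; i < scores.length; i++) {
            result.add(new LabelScore(labels.get(i), scores[i]));
        }
        return result;
    }

    public static void main(String[] args) {
        // Flat scores of a {1, 3} output tensor.
        System.out.println(toList(List.of("cat", "dog", "bird"), new float[] {0.7f, 0.2f, 0.1f}));
    }
}
```

A list form like this is convenient when callers want to sort results by score, which a map keyed by label does not make easy.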
Map<String, Float> getMapWithFloatValue()
Gets a map from each label to its float value. Each labeled sub-tensor must have effectively one element.
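The following minimal sketch shows why a label-to-float map only makes sense when every labeled sub-tensor holds a single value. It operates on a plain row-major float[] plus a shape rather than a real TensorBuffer. FloatMapSketch is hypothetical and is not the library's implementation.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch of a label-to-float map for a tensor whose labeled axis is effectively
 * the last one (each labeled sub-tensor holds one element). Not the library implementation.
 */
final class FloatMapSketch {
    private FloatMapSketch() {}

    static Map<String, Float> toFloatMap(int[] shape, float[] flatData, List<String> labels) {
        // Labeled axis: the first axis whose size is greater than 1.
        int axis = -1;
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] > 1) {
                axis = i;
                break;
            }
        }
        if (axis < 0 || labels.size() != shape[axis]) {
            throw new IllegalArgumentException("Labels do not match a labelable axis.");
        }
        // Elements per sub-tensor = product of the axis sizes after the labeled axis.
        int subTensorSize = 1;
        for (int i = axis + 1; i < shape.length; i++) {
            subTensorSize *= shape[i];
        }
        if (subTensorSize != 1) {
            throw new IllegalArgumentException("Each labeled sub-tensor must hold exactly one value.");
        }
        if (flatData.length != labels.size()) {
            throw new IllegalArgumentException("Data length does not match the shape.");
        }
        // All axes before `axis` have size 1, so flat element i belongs to label i.
        Map<String, Float> result = new LinkedHashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            result.put(labels.get(i), flatData[i]);
        }
        return result;
    }

    public static void main(String[] args) {
        int[] shape = {1, 3, 1};
        float[] probabilities = {0.7f, 0.2f, 0.1f};
        System.out.println(toFloatMap(shape, probabilities, List.of("cat", "dog", "bird")));
    }
}
```

LinkedHashMap is used here so the map iterates in label order; whether the library guarantees an ordering is not stated in this reference.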
Map<String, TensorBuffer> getMapWithTensorBuffer()
Gets a map from each label to its corresponding sub-tensor, as a TensorBuffer.
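When sub-tensors hold more than one element, a label-to-sub-tensor map is needed instead. This minimal sketch splits a row-major tensor along the first axis with size greater than 1. Because every earlier axis has size 1, each label's sub-tensor is one contiguous run of values. Plain float[] arrays stand in for TensorBuffer, and SubTensorSketch is a hypothetical helper, not the library's implementation.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch: splits a row-major tensor into one sub-tensor per label on the first
 * axis with size greater than 1. Plain float[] arrays stand in for TensorBuffer.
 */
final class SubTensorSketch {
    private SubTensorSketch() {}

    static Map<String, float[]> toSubTensors(int[] shape, float[] flatData, List<String> labels) {
        // Labeled axis: the first axis whose size is greater than 1.
        int axis = -1;
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] > 1) {
                axis = i;
                break;
            }
        }
        if (axis < 0 || labels.size() != shape[axis]) {
            throw new IllegalArgumentException("Labels do not match a labelable axis.");
        }
        // Elements per sub-tensor = product of the axis sizes after the labeled axis.
        int stride = 1;
        for (int i = axis + 1; i < shape.length; i++) {
            stride *= shape[i];
        }
        if (flatData.length != stride * shape[axis]) {
            throw new IllegalArgumentException("Data length does not match the shape.");
        }
        // Axes before `axis` have size 1, so sub-tensor i is the contiguous run [i*stride, (i+1)*stride).
        Map<String, float[]> result = new LinkedHashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            result.put(labels.get(i), Arrays.copyOfRange(flatData, i * stride, (i + 1) * stride));
        }
        return result;
    }

    public static void main(String[] args) {
        // Shape {1, 2, 3}: two labeled sub-tensors of shape {3} each.
        float[] data = {1f, 2f, 3f, 4f, 5f, 6f};
        Map<String, float[]> parts = toSubTensors(new int[] {1, 2, 3}, data, List.of("left", "right"));
        parts.forEach((label, values) -> System.out.println(label + " -> " + Arrays.toString(values)));
    }
}
```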