O TensorLabel é um wrapper util para TensorBuffers com rótulos significativos em um eixo.
Por exemplo, um modelo de classificação de imagem pode ter um tensor de saída com a forma {1, 10},
em que 1 é o tamanho do lote e 10 é o número de categorias. No segundo eixo, poderíamos
identifique cada subtensor com o nome ou a descrição de cada categoria correspondente. TensorLabel
pode ajudar a converter o tensor simples em TensorBuffer
em um mapa a partir de
rótulos predefinidos para subtensores. Nesse caso, se fossem fornecidos 10 rótulos para o segundo eixo, TensorLabel
poderia converter o tensor {1, 10} original em um mapa com 10 elementos. Cada valor poderia ser convertido em
é o tensor em forma {} (escalar). Exemplo de uso:
TensorBuffer outputTensor = ...; List<String> labels = FileUtil.loadLabels(context, labelFilePath); // labels the first axis with size greater than one TensorLabel labeled = new TensorLabel(labels, outputTensor); // If each sub-tensor has effectively size 1, we can directly get a float value Map<String, Float> probabilities = labeled.getMapWithFloatValue(); // Or get sub-tensors, when each sub-tensor has elements more than 1 Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
Observação: no momento, só aceitamos a conversão de tensor para mapa para o primeiro rótulo com tamanho maior de 1.
Construtores públicos
TensorLabel(Map<Integer, List<String>> axisLabels, TensorBuffer tensorBuffer)
Cria um objeto TensorLabel capaz de rotular os eixos de tensores multidimensionais.
|
|
TensorLabel(List<String> axisLabels, TensorBuffer tensorBuffer)
Cria um objeto TensorLabel capaz de rotular em um eixo de tensores multidimensionais.
|
Métodos públicos
Lista<Categoria> |
getCategoryList()
Recebe uma lista de
Category do objeto TensorLabel . |
Map<String, Float> |
getMapWithFloatValue()
Recebe um mapa que mapeia o rótulo para flutuar.
|
Map<String, TensorBuffer> |
getMapWithTensorBuffer()
Recebe o mapa com um par do rótulo e o TensorBuffer correspondente.
|
Métodos herdados
Construtores públicos
públicas TensorLabel (Map<Número inteiro, Lista<String>> axisLabels, TensorBuffer tensorBuffer)
Cria um objeto TensorLabel capaz de rotular os eixos de tensores multidimensionais.
Parâmetros
axisLabels | Um mapa, cuja chave é o ID do eixo (a partir de 0) e o valor corresponde rótulos. Observação: o tamanho dos rótulos precisa ser o mesmo do tensor no eixo. |
---|---|
tensorBuffer | O TensorBuffer a ser rotulado. |
Gera
NullPointerException | se axisLabels ou tensorBuffer for nulo, ou qualquer
o valor em axisLabels é nulo. |
---|---|
IllegalArgumentException | se alguma chave em axisLabels estiver fora do intervalo (em comparação com
a forma de tensorBuffer , ou qualquer valor (rótulos) tiver tamanho diferente com tensorBuffer na dimensão especificada.
|
públicas TensorLabel (List<String> axisLabels, TensorBuffer tensorBuffer)
Cria um objeto TensorLabel capaz de rotular em um eixo de tensores multidimensionais.
Observação: os rótulos são aplicados no primeiro eixo cujo tamanho é maior que 1. Por exemplo, se
a forma do tensor for [1, 10, 3], os rótulos serão aplicados no eixo 1 (ID a partir de
0), e o tamanho de axisLabels
também precisa ser 10.
Parâmetros
axisLabels | Uma lista de rótulos, com tamanho igual ao do tensor em eixo a ser rotulado. |
---|---|
tensorBuffer | O TensorBuffer a ser rotulado. |
Métodos públicos
públicas Lista<Categoria> getCategoryList ()
Recebe uma lista de Category
do objeto TensorLabel
.
O eixo do identificador deve ser efetivamente o último eixo (que significa que cada subtensor
especificado por esse eixo deve ter um tamanho simples de 1), de modo que cada subtensor rotulado possa ser
convertida em uma pontuação de valor flutuante. Exemplo: uma TensorLabel
com o formato {2, 5, 3}
e o eixo 2 é válido. Se o eixo for 1 ou 0, ele não poderá ser convertido em Category
.
getMapWithFloatValue()
é uma alternativa, mas retorna um Map
como
o resultado.
Gera
IllegalStateException | se o tamanho de um subtensor em cada rótulo não for 1. |
---|
públicas Map<String, Float> getMapWithFloatValue ()
Recebe um mapa que mapeia o rótulo para flutuar. Permitir o mapeamento apenas no primeiro eixo com tamanho maior que 1, e o eixo deve ser efetivamente o último eixo (o que significa que cada subtensor especificados por este eixo devem ter um tamanho simples de 1).
getCategoryList()
é uma API alternativa para receber o resultado.
Gera
IllegalStateException | se o tamanho de um subtensor em cada rótulo não for 1. |
---|
públicas Map<String, TensorBuffer> getMapWithTensorBuffer ()
Recebe o mapa com um par do rótulo e o TensorBuffer correspondente. Permita apenas o mapeamento no primeiro eixo com tamanho maior que 1 no momento.