TensorLabel è un wrapper utente per TensorBuffers con etichette significative su un asse.
Ad esempio, un modello di classificazione delle immagini può avere un tensore di output con forma {1, 10}, dove 1 è la dimensione del batch e 10 è il numero di categorie. Infatti, sul 2° asse, potremmo etichettare ogni sottotensore con il nome o la descrizione di ogni categoria corrispondente. TensorLabel
potrebbe aiutare a convertire il semplice Tensor in TensorBuffer
in una mappa da
etichette predefinite a sottotensori. In questo caso, se vengono fornite 10 etichette per il 2° asse, TensorLabel
potrebbe convertire la mappa originale {1, 10} Tensor in una mappa di 10 elementi, ciascuno dei quali
è Tensor in forma {} (scalar). Esempio di utilizzo:
TensorBuffer outputTensor = ...; List<String> labels = FileUtil.loadLabels(context, labelFilePath); // labels the first axis with size greater than one TensorLabel labeled = new TensorLabel(labels, outputTensor); // If each sub-tensor has effectively size 1, we can directly get a float value Map<String, Float> probabilities = labeled.getMapWithFloatValue(); // Or get sub-tensors, when each sub-tensor has elements more than 1 Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
Nota: al momento supportiamo solo la conversione tensor-to-map per la prima etichetta con dimensioni maggiori di 1.
Costruttori pubblici
TensorLabel(Mappa<Integer, List<String>> axisLabels, TensorBuffer tensorBuffer)
Crea un oggetto TensorLabel in grado di etichettare sugli assi dei tensori multidimensionali.
|
|
TensorLabel(Elenco<String> axisLabels, TensorBuffer tensorBuffer)
Crea un oggetto TensorLabel in grado di etichettare su un asse di tensori multidimensionali.
|
Metodi pubblici
List<Category> |
getCategoryList()
Recupera un elenco di
Category dall'oggetto TensorLabel . |
Map<String, Float> |
getMapWithFloatValue()
Restituisce una mappa che mappa l'etichetta alla visualizzazione in virgola mobile.
|
Map<String, TensorBuffer> |
getMapWithTensorBuffer()
Restituisce la mappa con una coppia dell'etichetta e del TensorBuffer corrispondente.
|
Metodi ereditati
Costruttori pubblici
pubbliche TensorLabel (Mappa<Numero intero, Elenco<Stringa>> Etichette asse, TensorBuffer tensorBuffer)
Crea un oggetto TensorLabel in grado di etichettare sugli assi dei tensori multidimensionali.
Parametri
axisLabels | Una mappa la cui chiave è l'ID asse (a partire da 0) e il valore corrisponde alle etichette corrispondenti. Nota: la dimensione delle etichette deve corrispondere a quella del tensore su quell'asse. |
---|---|
tensorBuffer | Il TensorBuffer da etichettare. |
Lanci
NullPointerException | se axisLabels o tensorBuffer è nullo o se qualsiasi valore in axisLabels è nullo. |
---|---|
IllegalArgumentException | se una qualsiasi chiave in axisLabels non rientra nell'intervallo (rispetto alla forma di tensorBuffer oppure se qualsiasi valore (etichette) ha dimensioni diverse rispetto a tensorBuffer nella dimensione specificata.
|
Public TensorLabel (List<String> axisLabels, TensorBuffer tensorBuffer)
Crea un oggetto TensorLabel in grado di etichettare su un asse di tensori multidimensionali.
Nota: le etichette vengono applicate sul primo asse le cui dimensioni sono superiori a 1. Ad esempio, se la forma del tensore è [1, 10, 3], le etichette verranno applicate sull'asse 1 (ID a partire da 0) e anche la dimensione di axisLabels
dovrebbe essere 10.
Parametri
axisLabels | Un elenco di etichette, le cui dimensioni devono corrispondere a quelle del tensore sull'asse da etichettare. |
---|---|
tensorBuffer | Il TensorBuffer da etichettare. |
Metodi pubblici
Public Elenco<Categoria> getCategoryList ()
Recupera un elenco di Category
dall'oggetto TensorLabel
.
L'asse dell'etichetta dovrebbe essere effettivamente l'ultimo asse (il che significa che ogni sottotensore specificato da questo asse deve avere una dimensione fissa pari a 1), in modo che ogni sottotensore etichettato possa essere convertito in un punteggio con valore in virgola mobile. Esempio: un elemento TensorLabel
con forma {2, 5, 3}
e asse 2 è valido. Se l'asse è 1 o 0, non può essere convertito in un Category
.
getMapWithFloatValue()
è un'alternativa, ma restituisce Map
come risultato.
Lanci
IllegalStateException | se la dimensione di un subtensore su ciascuna etichetta è diversa da 1. |
---|
Public Mappa<String, Float> getMapWithFloatValue ()
Restituisce una mappa che mappa l'etichetta alla visualizzazione in virgola mobile. Consenti la mappatura solo sul primo asse con dimensioni maggiori di 1 e l'asse dovrebbe essere effettivamente l'ultimo asse (il che significa che ogni sottotensore specificato da questo asse deve avere una dimensione fissa pari a 1).
getCategoryList()
è un'API alternativa per ottenere il risultato.
Lanci
IllegalStateException | se la dimensione di un subtensore su ciascuna etichetta è diversa da 1. |
---|
Public Mappa<String, TensorBuffer> getMapWithTensorBuffer ()
Restituisce la mappa con una coppia dell'etichetta e del TensorBuffer corrispondente. Consenti la mappatura solo sul primo asse con dimensioni attualmente maggiori di 1.