TensorLabel est un wrapper d'utilitaire pour les TensorBuffers avec des libellés significatifs sur un axe.
Par exemple, un modèle de classification d'images peut avoir un Tensor de sortie ayant la forme {1, 10}, où 1 correspond à la taille de lot et 10 au nombre de catégories. En fait, sur le deuxième axe, nous pourrions étiqueter chaque sous-Tensor avec le nom ou la description de chaque catégorie correspondante. TensorLabel
pourrait aider à convertir le Tensor simple de TensorBuffer
en un mappage entre étiquettes prédéfinies et sous-Tensors. Dans ce cas, si vous fournissez 10 étiquettes pour le 2e axe, TensorLabel
pourrait convertir le Tensor {1, 10} d'origine en une carte à 10 éléments, dont chaque valeur est un Tensor de forme {} (scalaire). Exemple d'utilisation :
TensorBuffer outputTensor = ...; List<String> labels = FileUtil.loadLabels(context, labelFilePath); // labels the first axis with size greater than one TensorLabel labeled = new TensorLabel(labels, outputTensor); // If each sub-tensor has effectively size 1, we can directly get a float value Map<String, Float> probabilities = labeled.getMapWithFloatValue(); // Or get sub-tensors, when each sub-tensor has elements more than 1 Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
Remarque: Actuellement, nous n'acceptons que la conversion Tensor-to-map pour le premier libellé dont la taille est supérieure à 1.
Constructeurs publics
TensorLabel(Map<Integer, List<String>> axisLabel, TensorBuffer tensorBuffer)
Crée un objet TensorLabel capable d'étiqueter les axes des Tensors multidimensionnels.
|
|
TensorLabel(List<String> axisLabel, TensorBuffer tensorBuffer)
Crée un objet TensorLabel capable d'étiqueter des Tensors multidimensionnels sur un axe.
|
Méthodes publiques
List<Category> |
getCategoryList()
Récupère une liste d'
Category à partir de l'objet TensorLabel . |
Map<String, Float> |
getMapWithFloatValue()
Récupère une carte qui mappe le libellé sur "float".
|
Map<String, TensorBuffer> |
getMapWithTensorBuffer()
Récupère la carte avec une paire de l'étiquette et le TensorBuffer correspondant.
|
Méthodes héritées
Constructeurs publics
public TensorLabel (Map<Integer, List<String> axisÉtiquettes, TensorBuffer tensorBuffer)
Crée un objet TensorLabel capable d'étiqueter les axes des Tensors multidimensionnels.
Paramètres
axisLabels | Une carte, dont la clé est l'identifiant de l'axe (à partir de 0) et la valeur est les libellés correspondants. Remarque: La taille des étiquettes doit être identique à celle du Tensor sur cet axe. |
---|---|
tensorBuffer | TensorBuffer à étiqueter. |
Génère
NullPointerException | si axisLabels ou tensorBuffer est nul, ou si toute valeur dans axisLabels est nulle. |
---|---|
IllegalArgumentException | si une clé dans axisLabels est en dehors de la plage (par rapport à la forme de tensorBuffer ), ou si toute valeur (libellés) a une taille différente avec le tensorBuffer de la dimension donnée.
|
public TensorLabel (List<String> axisLabel, TensorBuffer tensorBuffer)
Crée un objet TensorLabel capable d'étiqueter des Tensors multidimensionnels sur un axe.
Remarque: Les libellés sont appliqués sur le premier axe dont la taille est supérieure à 1. Par exemple, si la forme du Tensor est [1, 10, 3], les étiquettes seront appliquées sur l'axe 1 (identifiant à partir de 0) et la taille de axisLabels
devrait également être égale à 10.
Paramètres
axisLabels | Une liste d'étiquettes dont la taille doit être identique à celle du Tensor sur l'axe à étiqueter. |
---|---|
tensorBuffer | TensorBuffer à étiqueter. |
Méthodes publiques
public List<Category> getCategoryList ()
Récupère une liste d'Category
à partir de l'objet TensorLabel
.
L'axe de l'étiquette doit être effectivement le dernier (ce qui signifie que chaque sous-Tensor spécifié par cet axe doit avoir une taille plate de 1), de sorte que chaque sous-tenseur étiqueté puisse être converti en un score de valeur flottante. Exemple: Un TensorLabel
ayant la forme {2, 5, 3}
et l'axe 2 est valide. Si l'axe est 1 ou 0, il ne peut pas être converti en Category
.
getMapWithFloatValue()
est une alternative, mais renvoie Map
comme résultat.
Génère
IllegalStateException | si la taille d'un sous-Tensor sur chaque étiquette n'est pas égale à 1. |
---|
public Map<String, Float> getMapWithFloatValue ()
Récupère une carte qui mappe le libellé sur "float". N'autorisez le mappage que sur le premier axe dont la taille est supérieure à 1, et l'axe doit être effectivement le dernier (ce qui signifie que chaque sous-tenseur spécifié par cet axe doit avoir une taille plate de 1).
getCategoryList()
est une autre API permettant d'obtenir le résultat.
Génère
IllegalStateException | si la taille d'un sous-Tensor sur chaque étiquette n'est pas égale à 1. |
---|
public Map<String, TensorBuffer> getMapWithTensorBuffer ()
Récupère la carte avec une paire de l'étiquette et le TensorBuffer correspondant. N'autorisez le mappage que sur le premier axe dont la taille est actuellement supérieure à 1.