TensorLabel

classe publique TensorLabel

TensorLabel est un wrapper d'utilitaire pour les TensorBuffers avec des libellés significatifs sur un axe.

Par exemple, un modèle de classification d'images peut avoir un Tensor de sortie ayant la forme {1, 10}, où 1 correspond à la taille de lot et 10 au nombre de catégories. En fait, sur le deuxième axe, nous pourrions étiqueter chaque sous-Tensor avec le nom ou la description de chaque catégorie correspondante. TensorLabel pourrait aider à convertir le Tensor simple de TensorBuffer en un mappage entre étiquettes prédéfinies et sous-Tensors. Dans ce cas, si vous fournissez 10 étiquettes pour le 2e axe, TensorLabel pourrait convertir le Tensor {1, 10} d'origine en une carte à 10 éléments, dont chaque valeur est un Tensor de forme {} (scalaire). Exemple d'utilisation :

   TensorBuffer outputTensor = ...;
   List<String> labels = FileUtil.loadLabels(context, labelFilePath);
   // labels the first axis with size greater than one
   TensorLabel labeled = new TensorLabel(labels, outputTensor);
   // If each sub-tensor has effectively size 1, we can directly get a float value
   Map<String, Float> probabilities = labeled.getMapWithFloatValue();
   // Or get sub-tensors, when each sub-tensor has elements more than 1
   Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
 

Remarque: Actuellement, nous n'acceptons que la conversion Tensor-to-map pour le premier libellé dont la taille est supérieure à 1.

Constructeurs publics

TensorLabel(Map<IntegerList<String>> axisLabel, TensorBuffer tensorBuffer)
Crée un objet TensorLabel capable d'étiqueter les axes des Tensors multidimensionnels.
TensorLabel(List<String> axisLabel, TensorBuffer tensorBuffer)
Crée un objet TensorLabel capable d'étiqueter des Tensors multidimensionnels sur un axe.

Méthodes publiques

List<Category>
getCategoryList()
Récupère une liste d'Category à partir de l'objet TensorLabel.
Map<StringFloat>
getMapWithFloatValue()
Récupère une carte qui mappe le libellé sur "float".
Map<StringTensorBuffer>
getMapWithTensorBuffer()
Récupère la carte avec une paire de l'étiquette et le TensorBuffer correspondant.

Méthodes héritées

Constructeurs publics

public TensorLabel (Map<IntegerList<String> axisÉtiquettes, TensorBuffer tensorBuffer)

Crée un objet TensorLabel capable d'étiqueter les axes des Tensors multidimensionnels.

Paramètres
axisLabels Une carte, dont la clé est l'identifiant de l'axe (à partir de 0) et la valeur est les libellés correspondants. Remarque: La taille des étiquettes doit être identique à celle du Tensor sur cet axe.
tensorBuffer TensorBuffer à étiqueter.
Génère
NullPointerException si axisLabels ou tensorBuffer est nul, ou si toute valeur dans axisLabels est nulle.
IllegalArgumentException si une clé dans axisLabels est en dehors de la plage (par rapport à la forme de tensorBuffer), ou si toute valeur (libellés) a une taille différente avec le tensorBuffer de la dimension donnée.

public TensorLabel (List<String> axisLabel, TensorBuffer tensorBuffer)

Crée un objet TensorLabel capable d'étiqueter des Tensors multidimensionnels sur un axe.

Remarque: Les libellés sont appliqués sur le premier axe dont la taille est supérieure à 1. Par exemple, si la forme du Tensor est [1, 10, 3], les étiquettes seront appliquées sur l'axe 1 (identifiant à partir de 0) et la taille de axisLabels devrait également être égale à 10.

Paramètres
axisLabels Une liste d'étiquettes dont la taille doit être identique à celle du Tensor sur l'axe à étiqueter.
tensorBuffer TensorBuffer à étiqueter.

Méthodes publiques

public List<Category> getCategoryList ()

Récupère une liste d'Category à partir de l'objet TensorLabel.

L'axe de l'étiquette doit être effectivement le dernier (ce qui signifie que chaque sous-Tensor spécifié par cet axe doit avoir une taille plate de 1), de sorte que chaque sous-tenseur étiqueté puisse être converti en un score de valeur flottante. Exemple: Un TensorLabel ayant la forme {2, 5, 3} et l'axe 2 est valide. Si l'axe est 1 ou 0, il ne peut pas être converti en Category.

getMapWithFloatValue() est une alternative, mais renvoie Map comme résultat.

Génère
IllegalStateException si la taille d'un sous-Tensor sur chaque étiquette n'est pas égale à 1.

public Map<StringFloat> getMapWithFloatValue ()

Récupère une carte qui mappe le libellé sur "float". N'autorisez le mappage que sur le premier axe dont la taille est supérieure à 1, et l'axe doit être effectivement le dernier (ce qui signifie que chaque sous-tenseur spécifié par cet axe doit avoir une taille plate de 1).

getCategoryList() est une autre API permettant d'obtenir le résultat.

Génère
IllegalStateException si la taille d'un sous-Tensor sur chaque étiquette n'est pas égale à 1.

public Map<StringTensorBuffer> getMapWithTensorBuffer ()

Récupère la carte avec une paire de l'étiquette et le TensorBuffer correspondant. N'autorisez le mappage que sur le premier axe dont la taille est actuellement supérieure à 1.