Classe driver per guidare l'inferenza del modello con TensorFlow Lite.
Nota. Se non hai bisogno di accedere a nessun elemento "sperimentale" le funzionalità API riportate di seguito, preferiamo utilizzare InterpreterApi e Interpreterfabbrica anziché utilizzare direttamente lo strumento Interpreter.
Un Interpreter
incapsula un modello TensorFlow Lite preaddestrato, in cui le operazioni
vengono eseguite per l'inferenza del modello.
Ad esempio, se un modello accetta un solo input e restituisce solo un output:
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.run(input, output);
}
Se un modello accetta più input o output:
Object[] inputs = {input0, input1, ...};
Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
FloatBuffer ith_output = FloatBuffer.allocateDirect(3 * 2 * 4); // Float tensor, shape 3x2x4.
ith_output.order(ByteOrder.nativeOrder());
map_of_indices_to_outputs.put(i, ith_output);
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
}
Se un modello prende o produce tensori di stringa:
String[] input = {"foo", "bar"}; // Input tensor shape is [2].
String[][] output = new String[3][2]; // Output tensor shape is [3, 2].
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.runForMultipleInputsOutputs(input, output);
}
Tieni presente che esiste una distinzione tra forma [] e forma[1]. Per tensori di stringa scalare genera:
String[] input = {"foo"}; // Input tensor shape is [1].
ByteBuffer outputBuffer = ByteBuffer.allocate(OUTPUT_BYTES_SIZE); // Output tensor shape is [].
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.runForMultipleInputsOutputs(input, outputBuffer);
}
byte[] outputBytes = new byte[outputBuffer.remaining()];
outputBuffer.get(outputBytes);
// Below, the `charset` can be StandardCharsets.UTF_8.
String output = new String(outputBytes, charset);
Gli ordini di input e output vengono determinati durante la conversione del modello TensorFlow in TensorFlowLite modello con Toco, così come le forme predefinite degli input.
Quando gli input vengono forniti come array (multidimensionali), i tensori di input corrispondenti
ridimensionato implicitamente in base alla forma dell'array. Quando gli input vengono forniti come Buffer
tra i due tipi, il ridimensionamento
implicito non avviene; il chiamante deve assicurarsi che la dimensione in byte di Buffer
corrisponde a quello del tensore corrispondente oppure ridimensiona il tensore tramite resizeInput(int, int[])
. È possibile ottenere informazioni su forma e tipo di Tensor tramite la classe Tensor
, disponibile tramite getInputTensor(int)
e getOutputTensor(int)
.
ATTENZIONE:Interpreter
istanze non sono sicure per i thread. Interpreter
possiede risorse che devono essere liberate esplicitamente richiamando close()
La libreria TFLite è basata sull'API NDK 19. Può funzionare con livelli API Android inferiori a 19, ma non è garantito.
Classi nidificate
classe | Interpreter.Options | Una classe di opzioni per controllare il comportamento dell'interprete di runtime. |
Costruttori pubblici
Interpreter(opzioni File modelFile, Interpreter.Options)
Inizializza un
Interpreter e specifica le opzioni per personalizzare il comportamento dell'interprete. |
|
Interpreter(ByteBuffer byteBuffer)
Inizializza un
Interpreter con un ByteBuffer di un file del modello. |
|
Interpreter(opzioni ByteBuffer byteBuffer, Interpreter.Options)
Inizializza un
Interpreter con un ByteBuffer di un file di modello e un set di
Interpreter.Options personalizzato. |
Metodi pubblici
null |
allocateTensors()
Aggiorna in modo esplicito le allocazioni per tutti i tensori, se necessario.
|
null |
close()
Rilascia le risorse associate all'istanza
InterpreterApi . |
int | |
Tensor |
getInputTensor(int inputIndex)
Recupera il Tensor associato all'indice di input fornito.
|
int |
getInputTensorCount()
Restituisce il numero di tensori di input.
|
Tensor |
getInputTensorFromSignature(String inputName, String signatureKey)
Recupera il Tensor associato al nome di input fornito e al nome del metodo di firma.
|
Lunga |
getLastNativeInferenceDurationNanoseconds()
Restituisce i tempi di inferenza nativa.
|
int |
getOutputIndex(Stringa opName)
Ottiene l'indice di un output dato il nome dell'operazione dell'output.
|
Tensor |
getOutputTensor(int outputIndex)
Recupera il Tensor associato all'indice di output fornito.
|
int |
getOutputTensorCount()
Ottiene il numero di tensori di output.
|
Tensor |