Klasa sterownika do wnioskowania modelu za pomocą TensorFlow Lite.
Uwaga: jeśli nie potrzebujesz dostępu do Funkcje API poniżej – preferuję użycie InterpreterApi i InterpreterFactory zamiast bezpośredniego korzystania z narzędzia Interpreter.
Interpreter
zawiera wytrenowany model TensorFlow Lite, w którym operacje
są wykonywane na potrzeby wnioskowania na podstawie modelu.
Jeśli na przykład model przyjmuje tylko 1 dane wejściowe i zwraca tylko jedno dane wyjściowe:
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.run(input, output);
}
Jeśli model przyjmuje wiele danych wejściowych lub wyjściowych:
Object[] inputs = {input0, input1, ...};
Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
FloatBuffer ith_output = FloatBuffer.allocateDirect(3 * 2 * 4); // Float tensor, shape 3x2x4.
ith_output.order(ByteOrder.nativeOrder());
map_of_indices_to_outputs.put(i, ith_output);
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
}
Jeśli model przyjmuje lub generuje tensory ciągów:
String[] input = {"foo", "bar"}; // Input tensor shape is [2].
String[][] output = new String[3][2]; // Output tensor shape is [3, 2].
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.runForMultipleInputsOutputs(input, output);
}
Istnieje rozróżnienie między kształtem [] a kształtem[1]. Skalarny tensor ciągu znaków dane wyjściowe:
String[] input = {"foo"}; // Input tensor shape is [1].
ByteBuffer outputBuffer = ByteBuffer.allocate(OUTPUT_BYTES_SIZE); // Output tensor shape is [].
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
interpreter.runForMultipleInputsOutputs(input, outputBuffer);
}
byte[] outputBytes = new byte[outputBuffer.remaining()];
outputBuffer.get(outputBytes);
// Below, the `charset` can be StandardCharsets.UTF_8.
String output = new String(outputBytes, charset);
Kolejność danych wejściowych i wyjściowych jest określana podczas konwertowania modelu TensorFlow na TensorFlowLite z modelem Toco, a także domyślne kształty danych wejściowych.
Jeśli dane wejściowe są podane jako (wielowymiarowe) tablice, odpowiednie tensory wejściowe będą
zostanie niejawnie przeskalowana odpowiednio do kształtu tej tablicy. Gdy dane wejściowe są podane jako Buffer
nie jest wprowadzana żadna niejawna zmiana rozmiaru; element wywołujący musi zapewnić, że rozmiar w bajtach to Buffer
pasuje do danego tensora lub najpierw zmienia rozmiar tensora za pomocą funkcji resizeInput(int, int[])
. Informacje o kształcie i typie Tensor można uzyskać za pomocą klasy Tensor
, dostępnej za pośrednictwem getInputTensor(int)
i getOutputTensor(int)
.
OSTRZEŻENIE: instancje Interpreter
nie są bezpieczne w przypadku wątków. Interpreter
posiada zasoby, które muszą być wyraźnie zwolnione przez wywołanie metody close()
Biblioteka TFLite jest oparta na interfejsie NDK API 19. Może działać w przypadku interfejsów API Androida poniżej 19, ale nie jest to gwarantowane.
Zagnieżdżone klasy
klasa | Interpreter.Options | Klasa opcji do kontrolowania działania interpretera czasu działania. |
Konstruktorki publiczne
Interpreter(plik modelu Plik, opcje Interpreter.Options)
Inicjuje
Interpreter i określa opcje dostosowywania działania interpretera. |
|
Interpreter(ByteBuffer byteBuffer)
Inicjuje właściwość
Interpreter z elementem ByteBuffer pliku modelu. |
|
Interpreter obrazu(ByteBuffer byteBuffer, opcje Interpreter.Options)
Inicjuje obiekt
Interpreter z użyciem elementu ByteBuffer pliku modelu i zbioru
niestandardowy Interpreter.Options . |
Metody publiczne
nieważne |
allocateTensors()
W razie potrzeby jawnie aktualizuje przydziały wszystkich tensorów.
|
nieważne |
close()
Zwolnij zasoby powiązane z instancją
InterpreterApi . |
int, |
getInputIndex(nazwa opcji ciąg znaków)
Pobiera indeks danych wejściowych o nazwie operacji dla danych wejściowych.
|
Tensor |
getInputTensor(int inputIndex)
Pobiera Tensor powiązany z podanym indeksem danych wejściowych.
|
int, |
getInputTensorCount()
Pobiera liczbę tensorów wejściowych.
|
Tensor |
getInputTensorFromSignature(Ciąg nazwa wejścia, CiągSignatureKey)
Pobiera Tensor powiązany z podaną nazwą danych wejściowych i nazwą metody podpisu.
|
Długi |
getLastNativeInferenceDurationNanoseconds()
Zwraca czas wnioskowania natywnego.
|
int, |
getOutputIndex(opName ciągu)
Pobiera indeks danych wyjściowych o nazwie operacji w danych wyjściowych.
|