L'API della libreria di attività BertNLClassifier
è molto simile all'API NLClassifier
che classifica il testo di input in diverse categorie, ad eccezione del fatto che questa API è
su misura per i modelli correlati a BERT che richiedono un testo
Tokenizzazioni di segmenti di frase al di fuori del modello TFLite.
Funzionalità chiave dell'API BertNLClassifier
Prende una singola stringa come input, esegue la classificazione con la stringa e restituisce <label, score=""> coppie come risultati di classificazione.</label,>
Esegue un testo fuori grafico o Fondamentale tokenizzazioni sul testo di input.
Modelli BertNLClassifier supportati
I seguenti modelli sono compatibili con l'API BertNLClassifier
.
Modelli BERT creati da TensorFlow Lite Model Maker per il testo Classificazione.
Modelli personalizzati che soddisfano la compatibilità dei modelli requisiti.
Esegui l'inferenza in Java
Passaggio 1: importa la dipendenza da Gradle e altre impostazioni
Copia il file del modello .tflite
nella directory degli asset del modulo per Android
in cui verrà eseguito il modello. Specifica che il file non deve essere compresso.
aggiungi la libreria TensorFlow Lite al file build.gradle
del modulo:
android {
// Other settings
// Specify tflite file should not be compressed for the app apk
aaptOptions {
noCompress "tflite"
}
}
dependencies {
// Other dependencies
// Import the Task Text Library dependency
implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4'
}
Passaggio 2: esegui l'inferenza utilizzando l'API
// Initialization
BertNLClassifierOptions options =
BertNLClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setNumThreads(4).build())
.build();
BertNLClassifier classifier =
BertNLClassifier.createFromFileAndOptions(context, modelFile, options);
// Run inference
List<Category> results = classifier.classify(input);
Consulta la fonte codice per ulteriori dettagli.
Esegui l'inferenza in Swift
Passaggio 1: importa CocoaPods
Aggiungi il pod TensorFlowLiteTaskText in Podfile
target 'MySwiftAppWithTaskAPI' do
use_frameworks!
pod 'TensorFlowLiteTaskText', '~> 0.4.4'
end
Passaggio 2: esegui l'inferenza utilizzando l'API
// Initialization
let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier(
modelPath: bertModelPath)
// Run inference
let categories = bertNLClassifier.classify(text: input)
Consulta la fonte codice per ulteriori dettagli.
Esegui l'inferenza in C++
// Initialization
BertNLClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<BertNLClassifier> classifier = BertNLClassifier::CreateFromOptions(options).value();
// Run inference with your input, `input_text`.
std::vector<core::Category> categories = classifier->Classify(input_text);
Consulta la fonte codice per ulteriori dettagli.
Esegui l'inferenza in Python
Passaggio 1: installa il pacchetto pip
pip install tflite-support
Passaggio 2: utilizzo del modello
# Imports
from tflite_support.task import text
# Initialization
classifier = text.BertNLClassifier.create_from_file(model_path)
# Run inference
text_classification_result = classifier.classify(text)
Consulta la fonte
codice
per altre opzioni di configurazione di BertNLClassifier
.
Risultati di esempio
Ecco un esempio di risultati di classificazione delle recensioni di film utilizzando il metodo MobileBert di Model Maker.
Input: "È un viaggio affascinante e che spesso influenza"
Output:
category[0]: 'negative' : '0.00006'
category[1]: 'positive' : '0.99994'
Prova il semplice strumento dimostrativo dell'interfaccia a riga di comando per BertNLClassifier con il tuo modello e dati di test.
Requisiti di compatibilità del modello
L'API BetNLClassifier
prevede un modello TFLite con modello TFLite obbligatorio
Metadati.
I metadati devono soddisfare i seguenti requisiti:
input_process_units per tokenizzatore di pezzi di testo/frasi
3 tensori di input con i nomi "ids", "mask" e "segment_id" per l'output il tokenizzatore
1 tensore di output di tipo float32, con un file di etichette allegato facoltativamente. Se il file dell'etichetta è allegato, il file deve essere un file di testo normale con una sola etichetta per riga e il numero di etichette deve corrispondere al numero di categorie come come output del modello.