Task Library BertNLClassifier
API는 NLClassifier
와 매우 유사합니다.
는 입력 텍스트를 여러 카테고리로 분류합니다. 단, 이 API는
BERT 관련 모델에 맞게 특별히 개발되었습니다.
TFLite 모델 외부의 문장 조각 토큰화
BertNLClassifier API의 주요 특징
단일 문자열을 입력으로 받아서 해당 문자열로 분류를 수행하고 <label, score="">가 출력됨 분류 결과로 도출됩니다.</label,>
지원되는 BertNLClassifier 모델
다음 모델은 BertNLClassifier
API와 호환됩니다.
모델 호환성을 충족하는 커스텀 모델 요구사항을 충족하는 방법을 안내합니다.
Java에서 추론 실행
1단계: Gradle 종속 항목 및 기타 설정 가져오기
.tflite
모델 파일을 Android 모듈의 assets 디렉터리에 복사합니다.
지정할 수도 있습니다 파일을 압축하지 않도록 지정합니다.
모듈의 build.gradle
파일에 TensorFlow Lite 라이브러리를 추가합니다.
android {
// Other settings
// Specify tflite file should not be compressed for the app apk
aaptOptions {
noCompress "tflite"
}
}
dependencies {
// Other dependencies
// Import the Task Text Library dependency
implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4'
}
2단계: API를 사용하여 추론 실행
// Initialization
BertNLClassifierOptions options =
BertNLClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setNumThreads(4).build())
.build();
BertNLClassifier classifier =
BertNLClassifier.createFromFileAndOptions(context, modelFile, options);
// Run inference
List<Category> results = classifier.classify(input);
출처 코드 를 참조하세요.
Swift에서 추론 실행
1단계: CocoaPods 가져오기
Podfile에 TensorFlowLiteTaskText 포드 추가하기
target 'MySwiftAppWithTaskAPI' do
use_frameworks!
pod 'TensorFlowLiteTaskText', '~> 0.4.4'
end
2단계: API를 사용하여 추론 실행
// Initialization
let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier(
modelPath: bertModelPath)
// Run inference
let categories = bertNLClassifier.classify(text: input)
출처 코드 를 참조하세요.
C++에서 추론 실행
// Initialization
BertNLClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<BertNLClassifier> classifier = BertNLClassifier::CreateFromOptions(options).value();
// Run inference with your input, `input_text`.
std::vector<core::Category> categories = classifier->Classify(input_text);
출처 코드 를 참조하세요.
Python에서 추론 실행
1단계: pip 패키지 설치
pip install tflite-support
2단계: 모델 사용
# Imports
from tflite_support.task import text
# Initialization
classifier = text.BertNLClassifier.create_from_file(model_path)
# Run inference
text_classification_result = classifier.classify(text)
출처
코드
BertNLClassifier
구성 옵션을 참조하세요.
결과 예시
다음은 MobileBert 모델을 사용할 수 없습니다.
입력: '매력적이지만 여정에 영향을 미치는 경우가 많습니다.'
출력:
category[0]: 'negative' : '0.00006'
category[1]: 'positive' : '0.99994'
간단한 CLI 데모 도구 BertNLClassifier 모델을 학습시킬 수 있습니다
모델 호환성 요구사항
BetNLClassifier
API에는 필수 TFLite 모델이 포함된 TFLite 모델이 필요합니다.
메타데이터.
메타데이터는 다음 요구사항을 충족해야 합니다.
워드피스/센서피스 토큰나이저의 input_process_units
이름이 'ids', 'mask'인 입력 텐서 3개 및 'segment_ids' 출력은 tokenizer
float32 유형의 출력 텐서 1개, 라벨 파일 선택 가능 만약 라벨 파일이 첨부되었습니다. 파일은 라벨이 1개인 일반 텍스트 파일이어야 합니다. 라벨 수는 카테고리 수와 일치해야 합니다. 모델이 출력됩니다.