Task Library BertNLClassifier
API는 NLClassifier
와 매우 유사합니다.
는 입력 텍스트를 여러 카테고리로 분류합니다. 단, 이 API는
BERT 관련 모델에 맞게 특별히 개발되었습니다.
TFLite 모델 외부의 문장 조각 토큰화
BertNLClassifier API의 주요 특징
단일 문자열을 입력으로 받아서 해당 문자열로 분류를 수행하고 <label, score="">가 출력됨 분류 결과로 도출됩니다.</label,>
지원되는 BertNLClassifier 모델
다음 모델은 BertNLClassifier
API와 호환됩니다.
모델 호환성을 충족하는 커스텀 모델 요구사항을 충족하는 방법을 안내합니다.
Java에서 추론 실행
1단계: Gradle 종속 항목 및 기타 설정 가져오기
.tflite
모델 파일을 Android 모듈의 assets 디렉터리에 복사합니다.
지정할 수도 있습니다 파일을 압축하지 않도록 지정합니다.
모듈의 build.gradle
파일에 TensorFlow Lite 라이브러리를 추가합니다.
android {
// Other settings
// Specify tflite file should not be compressed for the app apk
aaptOptions {
noCompress "tflite"
}
}
dependencies {
// Other dependencies
// Import the Task Text Library dependency
implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4'
}
2단계: API를 사용하여 추론 실행
// Initialization
BertNLClassifierOptions options =
BertNLClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setNumThreads(4).build())
.build();
BertNLClassifier classifier =
BertNLClassifier.createFromFileAndOptions(context, modelFile, options);
// Run inference
List<Category> results = classifier.classify(input);
출처 코드 를 참조하세요.
Swift에서 추론 실행
1단계: CocoaPods 가져오기
Podfile에 TensorFlowLiteTaskText 포드 추가하기
target 'MySwiftAppWithTaskAPI' do
use_frameworks!
pod 'TensorFlowLiteTaskText', '~> 0.4.4'
end
2단계: API를 사용하여 추론 실행
// Initialization
let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier(
modelPath: bertModelPath)
// Run inference
let categories = bertNLClassifier.classify(text: input)
출처 코드 를 참조하세요.
C++에서 추론 실행
// Initialization
BertNLClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<BertNLClassifier> classifier = BertNLClassifier::CreateFromOptions(options).value();
// Run inference with your input, `input_text`.
std::vector<core::Category> categories = classifier->Classify(input_text);
출처 코드 를 참조하세요.
Python에서 추론 실행
1단계: pip 패키지 설치
pip install tflite-support
2단계: 모델 사용
# Imports
from tflite_support.task import text
# Initialization
classifier = text.BertNLClassifier.create_from_file(model_path)
# Run inference
text_classification_result = classifier.classify(text)
출처
코드
BertNLClassifier
구성 옵션을 참조하세요.
결과 예시
다음은 MobileBert 모델을 사용할 수 없습니다.
입력: '매력적이지만 여정에 영향을 미치는 경우가 많습니다.'
출력:
category[0]: 'negative' : '0.00006'
category[1]: 'positive' : '0.99994'
간단한 CLI 데모 도구 BertNLClassifier 모델을 학습시킬 수 있습니다
모델 호환성 요구사항
BetNLClassifier
API에는 필수 TFLite 모델이 포함된 TFLite 모델이 필요합니다.
메타데이터.
메타데이터는 다음 요구사항을 충족해야 합니다.
워드피스/센서피스 토큰나이저의 input_process_units
이름이 'ids', 'mask'인 입력 텐서 3개 및 'segment_ids' 출력은 tokenizer
float32 유형의 출력 텐서 1개, 라벨 파일 선택 가능 만약 라벨 파일이 첨부되어 있다면 파일은 라벨이 1개인 일반 텍스트 파일이어야 합니다. 라벨 수는 카테고리 수와 일치해야 합니다. 모델이 출력됩니다.