BERT 자연어 분류기 통합

Task Library BertNLClassifier API는 NLClassifier와 매우 유사합니다. 는 입력 텍스트를 여러 카테고리로 분류합니다. 단, 이 API는 BERT 관련 모델에 맞게 특별히 개발되었습니다. TFLite 모델 외부의 문장 조각 토큰화

BertNLClassifier API의 주요 특징

  • 단일 문자열을 입력으로 받아서 해당 문자열로 분류를 수행하고 <label, score="">가 출력됨 분류 결과로 도출됩니다.</label,>

  • 그래프 외부의 Word 중을 실행합니다. 또는 문장 토큰화에 대해 살펴봤습니다

지원되는 BertNLClassifier 모델

다음 모델은 BertNLClassifier API와 호환됩니다.

Java에서 추론 실행

1단계: Gradle 종속 항목 및 기타 설정 가져오기

.tflite 모델 파일을 Android 모듈의 assets 디렉터리에 복사합니다. 지정할 수도 있습니다 파일을 압축하지 않도록 지정합니다. 모듈의 build.gradle 파일에 TensorFlow Lite 라이브러리를 추가합니다.

android {
    // Other settings

    // Specify tflite file should not be compressed for the app apk
    aaptOptions {
        noCompress "tflite"
    }

}

dependencies {
    // Other dependencies

    // Import the Task Text Library dependency
    implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4'
}

2단계: API를 사용하여 추론 실행

// Initialization
BertNLClassifierOptions options =
    BertNLClassifierOptions.builder()
        .setBaseOptions(BaseOptions.builder().setNumThreads(4).build())
        .build();
BertNLClassifier classifier =
    BertNLClassifier.createFromFileAndOptions(context, modelFile, options);

// Run inference
List<Category> results = classifier.classify(input);

출처 코드 를 참조하세요.

Swift에서 추론 실행

1단계: CocoaPods 가져오기

Podfile에 TensorFlowLiteTaskText 포드 추가하기

target 'MySwiftAppWithTaskAPI' do
  use_frameworks!
  pod 'TensorFlowLiteTaskText', '~> 0.4.4'
end

2단계: API를 사용하여 추론 실행

// Initialization
let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier(
      modelPath: bertModelPath)

// Run inference
let categories = bertNLClassifier.classify(text: input)

출처 코드 를 참조하세요.

C++에서 추론 실행

// Initialization
BertNLClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<BertNLClassifier> classifier = BertNLClassifier::CreateFromOptions(options).value();

// Run inference with your input, `input_text`.
std::vector<core::Category> categories = classifier->Classify(input_text);

출처 코드 를 참조하세요.

Python에서 추론 실행

1단계: pip 패키지 설치

pip install tflite-support

2단계: 모델 사용

# Imports
from tflite_support.task import text

# Initialization
classifier = text.BertNLClassifier.create_from_file(model_path)

# Run inference
text_classification_result = classifier.classify(text)

출처 코드 BertNLClassifier 구성 옵션을 참조하세요.

결과 예시

다음은 MobileBert 모델을 사용할 수 없습니다.

입력: '매력적이지만 여정에 영향을 미치는 경우가 많습니다.'

출력:

category[0]: 'negative' : '0.00006'
category[1]: 'positive' : '0.99994'

간단한 CLI 데모 도구 BertNLClassifier 모델을 학습시킬 수 있습니다

모델 호환성 요구사항

BetNLClassifier API에는 필수 TFLite 모델이 포함된 TFLite 모델이 필요합니다. 메타데이터.

메타데이터는 다음 요구사항을 충족해야 합니다.

  • 워드피스/센서피스 토큰나이저의 input_process_units

  • 이름이 'ids', 'mask'인 입력 텐서 3개 및 'segment_ids' 출력은 tokenizer

  • float32 유형의 출력 텐서 1개, 라벨 파일 선택 가능 만약 라벨 파일이 첨부되었습니다. 파일은 라벨이 1개인 일반 텍스트 파일이어야 합니다. 라벨 수는 카테고리 수와 일치해야 합니다. 모델이 출력됩니다.