API for creating and training a text classification model.
mediapipe_model_maker.text_classifier.TextClassifier(
    model_spec: Any, label_names: Sequence[str], shuffle: bool
)
| Args | |
|---|---|
| `model_spec` | Specification for the model. |
| `label_names` | A list of label names for the classes. |
| `shuffle` | Whether the dataset should be shuffled. |
Methods
create
@classmethod
create(
    train_data: mediapipe_model_maker.text_classifier.Dataset,
    validation_data: mediapipe_model_maker.text_classifier.Dataset,
    options: mediapipe_model_maker.text_classifier.TextClassifierOptions
) -> 'TextClassifier'
Factory function that creates and trains a text classifier.
Note that train_data and validation_data are expected to share the same
label_names since they should be split from the same dataset.
| Args | |
|---|---|
| `train_data` | Training data. |
| `validation_data` | Validation data. |
| `options` | Options for creating and training the text classifier. |
| Returns | 
|---|
| A text classifier. | 
| Raises | 
|---|
| `ValueError` if `train_data` and `validation_data` do not have the same `label_names`, or if `options` contains an unknown `supported_model`. |
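For orientation, here is a minimal end-to-end sketch of training with `create`. The file paths and column names are placeholders, and the data-loading helpers (`CSVParams`, `Dataset.from_csv`) and model selection follow the Model Maker text classifier guide; treat them as assumptions rather than a verbatim recipe.

```python
from mediapipe_model_maker import text_classifier

# Load training/validation splits from CSV (paths and column names are placeholders).
csv_params = text_classifier.CSVParams(text_column="text", label_column="label")
train_data = text_classifier.Dataset.from_csv(filename="train.csv", csv_params=csv_params)
validation_data = text_classifier.Dataset.from_csv(filename="validation.csv", csv_params=csv_params)

# Pick a supported model; an average word-embedding classifier is used here for brevity.
options = text_classifier.TextClassifierOptions(
    supported_model=text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
)

# Factory method: builds and trains the classifier.
model = text_classifier.TextClassifier.create(train_data, validation_data, options)
```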
evaluate
evaluate(
    data: mediapipe_model_maker.model_util.dataset.Dataset,
    batch_size: int = 32,
    desired_precisions: Optional[Sequence[float]] = None,
    desired_recalls: Optional[Sequence[float]] = None
) -> Any
Overrides Classifier.evaluate().
| Args | |
|---|---|
| `data` | Evaluation dataset. Must be a TextClassifier Dataset. |
| `batch_size` | Number of samples per evaluation step. |
| `desired_precisions` | If specified, adds a RecallAtPrecision metric per desired_precisions[i] entry, which tracks the recall given the constraint on precision. Only supported for binary classification. |
| `desired_recalls` | If specified, adds a PrecisionAtRecall metric per desired_recalls[i] entry, which tracks the precision given the constraint on recall. Only supported for binary classification. |
| Returns | 
|---|
| The loss value and accuracy. | 
| Raises | 
|---|
| `ValueError` if `data` is not a TextClassifier Dataset. |
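A brief usage sketch, reusing the hypothetical `model` and `validation_data` from the `create` example above.

```python
# Returns the loss value and accuracy on the evaluation dataset.
metrics = model.evaluate(validation_data, batch_size=32)

# For a binary classification task, recall-at-precision can be tracked as well:
# metrics = model.evaluate(validation_data, desired_precisions=[0.9])
```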
export_labels
export_labels(
    export_dir: str, label_filename: str = 'labels.txt'
)
Exports classification labels into a label file.
| Args | |
|---|---|
| `export_dir` | The directory to save exported files. |
| `label_filename` | File name for the exported labels. The full export path is {export_dir}/{label_filename}. |
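Usage sketch; the directory name is a placeholder.

```python
# Writes the label file to exported_model/labels.txt.
model.export_labels(export_dir="exported_model")
```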
export_model
export_model(
    model_name: str = 'model.tflite',
    quantization_config: Optional[mediapipe_model_maker.quantization.QuantizationConfig] = None
)
Converts and saves the model to a TFLite file with metadata included.
Note that only the TFLite file is needed for deployment. This function also
saves a metadata.json file to the same directory as the TFLite file which
can be used to interpret the metadata content in the TFLite file.
| Args | |
|---|---|
| `model_name` | File name to save TFLite model with metadata. The full export path is {self._hparams.export_dir}/{model_name}. |
| `quantization_config` | The configuration for model quantization. |
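A hedged sketch: the no-argument call follows the signature above, while the quantized variant assumes `QuantizationConfig.for_dynamic()` is available in `mediapipe_model_maker.quantization` as a dynamic-range preset.

```python
from mediapipe_model_maker import quantization

# Default export: writes model.tflite (plus metadata.json) to self._hparams.export_dir.
model.export_model()

# Export a second copy with dynamic-range quantization (assumed helper).
model.export_model(
    model_name="model_quantized.tflite",
    quantization_config=quantization.QuantizationConfig.for_dynamic(),
)
```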
export_tflite
export_tflite(
    export_dir: str,
    tflite_filename: str = 'model.tflite',
    quantization_config: Optional[mediapipe_model_maker.quantization.QuantizationConfig] = None,
    preprocess: Optional[Callable[..., bool]] = None
)
Converts the model to requested formats.
| Args | |
|---|---|
| `export_dir` | The directory to save exported files. |
| `tflite_filename` | File name to save TFLite model. The full export path is {export_dir}/{tflite_filename}. |
| `quantization_config` | The configuration for model quantization. |
| `preprocess` | A callable to preprocess the representative dataset for quantization. The callable takes three arguments in order: feature, label, and is_training. |
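Usage sketch with placeholder paths.

```python
# Convert and save the TFLite model to exported_model/model.tflite.
model.export_tflite(export_dir="exported_model", tflite_filename="model.tflite")
```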
load_bert_classifier
@classmethod
load_bert_classifier(
    options: mediapipe_model_maker.text_classifier.TextClassifierOptions,
    saved_model_path: str,
    label_names: Sequence[str]
) -> 'TextClassifier'
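No docstring is given above. Judging from the signature, the sketch below assumes this classmethod restores a BERT-based classifier from a SavedModel directory; the model choice (`MOBILEBERT_CLASSIFIER`), path, and label names are all illustrative assumptions.

```python
# Assumed usage: options should describe the BERT model the SavedModel was built from.
bert_options = text_classifier.TextClassifierOptions(
    supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER
)
model = text_classifier.TextClassifier.load_bert_classifier(
    options=bert_options,
    saved_model_path="exported_model/saved_model",  # placeholder path
    label_names=["negative", "positive"],           # placeholder labels
)
```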
save_model
save_model(
    model_name: str = 'saved_model'
)
Saves the model in SavedModel format.
For more information, see https://www.tensorflow.org/guide/saved_model
| Args | |
|---|---|
| `model_name` | Name of the saved model. |
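Usage sketch:

```python
# Saves the trained model in TensorFlow SavedModel format under the name 'saved_model'.
model.save_model(model_name="saved_model")
```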
summary
summary()
Prints a summary of the model.
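Usage is a single call on a trained or loaded instance:

```python
# Prints a summary of the model to stdout.
model.summary()
```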