API for creating and training a text classification model.
mediapipe_model_maker.text_classifier.TextClassifier(
model_spec: Any, label_names: Sequence[str], shuffle: bool
)
Args | |
---|---|
model_spec | Specification for the model. |
label_names | A list of label names for the classes. |
shuffle | Whether the dataset should be shuffled. |
Methods
create
@classmethod
create(
train_data: mediapipe_model_maker.text_classifier.Dataset,
validation_data: mediapipe_model_maker.text_classifier.Dataset,
options: mediapipe_model_maker.text_classifier.TextClassifierOptions
) -> 'TextClassifier'
Factory function that creates and trains a text classifier.
Note that train_data and validation_data are expected to share the same label_names since they should be split from the same dataset.
Args | |
---|---|
train_data | Training data. |
validation_data | Validation data. |
options | Options for creating and training the text classifier. |
Returns | |
---|---|
A text classifier. |
Raises | |
---|---|
ValueError if train_data and validation_data do not have the same label_names, or if options contains an unknown supported_model. |
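The label check described in the Raises entry can be sketched as a standalone helper. This is a minimal illustration, not the library's implementation: `check_shared_label_names` is a hypothetical name, and comparing the names as ordered lists is an assumption about how label order is treated.

```python
from typing import Sequence


def check_shared_label_names(
    train_label_names: Sequence[str],
    validation_label_names: Sequence[str],
) -> None:
    """Raises ValueError when the two splits do not share label names.

    Hypothetical helper. The order-sensitive comparison is an assumption;
    the source only says the two datasets must have the same label_names.
    """
    if list(train_label_names) != list(validation_label_names):
        raise ValueError(
            "train_data and validation_data must share label_names, got "
            f"{list(train_label_names)} and {list(validation_label_names)}"
        )


# Splits taken from the same dataset share label names, so this passes.
check_shared_label_names(["negative", "positive"], ["negative", "positive"])
```

Running a check like this before calling create surfaces a mismatched split early, before any training time is spent.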
evaluate
evaluate(
data: mediapipe_model_maker.model_util.dataset.Dataset,
batch_size: int = 32,
desired_precisions: Optional[Sequence[float]] = None,
desired_recalls: Optional[Sequence[float]] = None
) -> Any
Overrides Classifier.evaluate().
Args | |
---|---|
data | Evaluation dataset. Must be a TextClassifier Dataset. |
batch_size | Number of samples per evaluation step. |
desired_precisions | If specified, adds one RecallAtPrecision metric for each desired_precisions[i] entry, which tracks the recall given the constraint on precision. Only supported for binary classification. |
desired_recalls | If specified, adds one PrecisionAtRecall metric for each desired_recalls[i] entry, which tracks the precision given the constraint on recall. Only supported for binary classification. |
Returns | |
---|---|
The loss value and accuracy. |
Raises | |
---|---|
ValueError if data is not a TextClassifier Dataset. |
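To make the desired_precisions option concrete, the following sketch computes a recall-at-precision value by hand for a binary problem. It is a simplified illustration of the idea, not the metric implementation used by the library: `recall_at_precision` is a hypothetical function, and it scans every distinct score as a candidate threshold.

```python
from typing import Sequence


def recall_at_precision(
    labels: Sequence[int],
    scores: Sequence[float],
    desired_precision: float,
) -> float:
    """Best recall over all thresholds whose precision meets the constraint.

    Hypothetical helper for binary labels (0 or 1). Each distinct score is
    tried as a decision threshold; examples scoring at or above it count as
    predicted positives.
    """
    total_positives = sum(labels)
    best_recall = 0.0
    for threshold in sorted(set(scores)):
        predicted = [score >= threshold for score in scores]
        true_pos = sum(1 for p, y in zip(predicted, labels) if p and y == 1)
        false_pos = sum(1 for p, y in zip(predicted, labels) if p and y == 0)
        if true_pos + false_pos == 0:
            continue  # No positive predictions, precision is undefined.
        precision = true_pos / (true_pos + false_pos)
        recall = true_pos / total_positives if total_positives else 0.0
        if precision >= desired_precision:
            best_recall = max(best_recall, recall)
    return best_recall


labels = [1, 1, 0, 1, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.2]
print(recall_at_precision(labels, scores, desired_precision=1.0))
```

A PrecisionAtRecall value follows the same pattern with the roles of the two quantities swapped: keep the thresholds whose recall meets the constraint and report the best precision among them.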
export_labels
export_labels(
export_dir: str, label_filename: str = 'labels.txt'
)
Exports classification labels into a label file.
Args | |
---|---|
export_dir | The directory to save exported files. |
label_filename | File name for the exported labels. The full export path is {export_dir}/{label_filename}. |
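The path rule for the label file can be sketched without the library. The helper below is hypothetical, and its one-label-per-line file layout is an assumption; this page only specifies where the file is written, not its contents.

```python
import os
import tempfile
from typing import Sequence


def write_label_file(
    label_names: Sequence[str],
    export_dir: str,
    label_filename: str = "labels.txt",
) -> str:
    """Writes labels to {export_dir}/{label_filename} and returns the path.

    Hypothetical helper. Assumption: one label per line.
    """
    os.makedirs(export_dir, exist_ok=True)
    label_path = os.path.join(export_dir, label_filename)
    with open(label_path, "w", encoding="utf-8") as label_file:
        label_file.write("\n".join(label_names))
    return label_path


with tempfile.TemporaryDirectory() as tmp_dir:
    path = write_label_file(["negative", "positive"], tmp_dir)
    with open(path, encoding="utf-8") as label_file:
        print(label_file.read().splitlines())
```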
export_model
export_model(
model_name: str = 'model.tflite',
quantization_config: Optional[mediapipe_model_maker.quantization.QuantizationConfig] = None
)
Converts and saves the model to a TFLite file with metadata included.
Note that only the TFLite file is needed for deployment. This function also saves a metadata.json file to the same directory as the TFLite file, which can be used to interpret the metadata content in the TFLite file.
Args | |
---|---|
model_name | File name for the saved TFLite model with metadata. The full export path is {self._hparams.export_dir}/{model_name}. |
quantization_config | The configuration for model quantization. |
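The files that export_model is described as producing can be located with a small helper. This is a sketch under stated assumptions: `expected_export_paths` is a hypothetical name, and `export_dir` stands in for the export directory configured in the training hyperparameters.

```python
import os
from typing import Dict


def expected_export_paths(
    export_dir: str, model_name: str = "model.tflite"
) -> Dict[str, str]:
    """Returns where the TFLite model and metadata.json are expected to land.

    Hypothetical helper: export_dir is assumed to match the export directory
    set in the hyperparameters used for training.
    """
    return {
        "tflite": os.path.join(export_dir, model_name),
        # The metadata file is saved to the same directory as the model.
        "metadata": os.path.join(export_dir, "metadata.json"),
    }


paths = expected_export_paths("exported_model")
print(paths["tflite"])
print(paths["metadata"])
```

Only the path under "tflite" is needed for deployment; the metadata file is there for inspection.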
export_tflite
export_tflite(
export_dir: str,
tflite_filename: str = 'model.tflite',
quantization_config: Optional[mediapipe_model_maker.quantization.QuantizationConfig] = None,
preprocess: Optional[Callable[..., bool]] = None
)
Converts the model to requested formats.
Args | |
---|---|
export_dir | The directory to save exported files. |
tflite_filename | File name to save TFLite model. The full export path is {export_dir}/{tflite_filename}. |
quantization_config | The configuration for model quantization. |
preprocess | A callable to preprocess the representative dataset for quantization. The callable takes three arguments in order: feature, label, and is_training. |
load_bert_classifier
@classmethod
load_bert_classifier(
options: mediapipe_model_maker.text_classifier.TextClassifierOptions,
saved_model_path: str,
label_names: Sequence[str]
) -> 'TextClassifier'
save_model
save_model(
model_name: str = 'saved_model'
)
Saves the model in SavedModel format.
For more information, see https://www.tensorflow.org/guide/saved_model
Args | |
---|---|
model_name | Name of the saved model. |
summary
summary()
Prints a summary of the model.
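The methods on this page are typically used in sequence. The sketch below strings together create, evaluate, export_model, and export_labels using only the signatures documented here. `train_evaluate_export` is a hypothetical wrapper, and the classifier class is passed in as an argument so the flow can be exercised without the library installed. How the datasets and options are built is outside this page and is left to the caller.

```python
from typing import Any, Tuple


def train_evaluate_export(
    classifier_cls: Any,
    train_data: Any,
    validation_data: Any,
    options: Any,
    labels_dir: str,
) -> Tuple[Any, Any]:
    """Trains a classifier, evaluates it, and exports the model and labels.

    Hypothetical wrapper. classifier_cls is expected to expose the methods
    documented on this page, for example TextClassifier.
    """
    # create() both builds and trains the model.
    model = classifier_cls.create(
        train_data=train_data,
        validation_data=validation_data,
        options=options,
    )
    # Evaluating on the validation split keeps the sketch short; a held-out
    # test split would normally be used here.
    metrics = model.evaluate(validation_data, batch_size=32)
    model.export_model(model_name="model.tflite")
    model.export_labels(export_dir=labels_dir)
    return model, metrics


# With the library installed, the call would look like:
#   from mediapipe_model_maker import text_classifier
#   model, metrics = train_evaluate_export(
#       text_classifier.TextClassifier,
#       train_data, validation_data, options, "exported_model")
```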