The hyperparameters for training image classifiers.
Inherits From: BaseHParams
mediapipe_model_maker.image_classifier.HParams(
    learning_rate: float = 0.001,
    batch_size: int = 2,
    epochs: int = 10,
    steps_per_epoch: Optional[int] = None,
    class_weights: Optional[Mapping[int, float]] = None,
    shuffle: bool = False,
    repeat: bool = False,
    export_dir: str = tempfile.mkdtemp(),
    distribution_strategy: str = 'off',
    num_gpus: int = 0,
    tpu: str = '',
    do_fine_tuning: bool = False,
    l1_regularizer: float = 0.0,
    l2_regularizer: float = 0.0001,
    label_smoothing: float = 0.1,
    do_data_augmentation: bool = True,
    decay_samples: int = (10000 * 256),
    warmup_epochs: int = 2,
    checkpoint_frequency: int = 1,
    one_hot: bool = True,
    multi_labels: bool = False
)
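HParams is a dataclass, so overriding a few defaults means passing keyword arguments and leaving the rest alone. The sketch below uses a hypothetical stand-in class, SketchHParams, that copies a subset of the fields and defaults from the signature above so it runs without MediaPipe installed; with the real package the same keyword arguments would go to mediapipe_model_maker.image_classifier.HParams.

```python
import dataclasses
from typing import Optional


# Hypothetical stand-in mirroring a few HParams fields with the same defaults.
# It exists only so this sketch runs without MediaPipe installed.
@dataclasses.dataclass
class SketchHParams:
    learning_rate: float = 0.001
    batch_size: int = 2
    epochs: int = 10
    steps_per_epoch: Optional[int] = None
    do_fine_tuning: bool = False
    label_smoothing: float = 0.1


# Override only the fields you care about; the others keep their defaults.
hparams = SketchHParams(epochs=20, batch_size=32)

# dataclasses.replace builds a modified copy and leaves the original untouched.
fine_tune = dataclasses.replace(hparams, do_fine_tuning=True, learning_rate=0.0005)

print(hparams.do_fine_tuning, fine_tune.do_fine_tuning)  # False True
print(fine_tune.epochs, fine_tune.label_smoothing)        # 20 0.1
```

Using dataclasses.replace keeps a baseline configuration intact while deriving variants from it, which is convenient when comparing a frozen-backbone run with a fine-tuning run.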
Attributes
learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
do_fine_tuning: If true, the base module is trained together with the classification layer on top.
l1_regularizer: Coefficient of the L1 regularization penalty.
l2_regularizer: Coefficient of the L2 regularization penalty.
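The two regularizer values act as coefficients on weight penalties added to the training loss. The sketch below assumes the convention of tf.keras.regularizers.L1L2, where the penalty is l1 * sum(|w|) + l2 * sum(w**2); which layers MediaPipe attaches these penalties to is not shown here, and the function name l1_l2_penalty is hypothetical.

```python
from typing import Sequence


def l1_l2_penalty(weights: Sequence[float], l1: float = 0.0, l2: float = 0.0001) -> float:
    """Return l1 * sum(|w|) + l2 * sum(w**2) for a flat list of weights.

    The defaults match the HParams defaults (l1_regularizer=0.0, l2_regularizer=0.0001).
    """
    l1_term = l1 * sum(abs(w) for w in weights)
    l2_term = l2 * sum(w * w for w in weights)
    return l1_term + l2_term


weights = [0.5, -1.0, 2.0]
# With the defaults only the L2 term contributes: 0.0001 * (0.25 + 1.0 + 4.0).
print(l1_l2_penalty(weights))
```

The L1 term pushes individual weights toward exactly zero, while the L2 term shrinks all weights smoothly; the defaults enable only the L2 term.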
label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for more details.
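Label smoothing softens the hard target so the model is not pushed toward fully confident predictions. The sketch below assumes the formula used by tf.keras.losses.CategoricalCrossentropy, y * (1 - label_smoothing) + label_smoothing / num_classes; smooth_labels is a hypothetical helper name.

```python
from typing import List


def smooth_labels(one_hot: List[float], label_smoothing: float = 0.1) -> List[float]:
    """Blend a one-hot target with a uniform distribution over the classes.

    Uses y * (1 - label_smoothing) + label_smoothing / num_classes for each entry.
    """
    num_classes = len(one_hot)
    return [y * (1.0 - label_smoothing) + label_smoothing / num_classes for y in one_hot]


# Four classes, true class at index 2, default smoothing of 0.1.
print(smooth_labels([0.0, 0.0, 1.0, 0.0]))  # approximately [0.025, 0.025, 0.925, 0.025]
```

The smoothed target still sums to 1, so it remains a valid probability distribution for a cross-entropy loss.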
do_data_augmentation: A boolean controlling whether the training dataset is augmented by randomly distorting input images, including random cropping, flipping, etc. See the utils.image_preprocessing documentation for details.
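Random cropping and flipping produce a slightly different view of each image every epoch. The sketch below illustrates only those two distortions on a toy grayscale image stored as nested lists; it does not reproduce what utils.image_preprocessing actually does, and the helper names and image format are hypothetical.

```python
import random
from typing import List

Image = List[List[int]]  # rows of grayscale pixel values (hypothetical toy format)


def random_crop(image: Image, crop_h: int, crop_w: int, rng: random.Random) -> Image:
    """Cut a crop_h x crop_w window at a random position inside the image."""
    top = rng.randint(0, len(image) - crop_h)
    left = rng.randint(0, len(image[0]) - crop_w)
    return [row[left:left + crop_w] for row in image[top:top + crop_h]]


def random_flip(image: Image, rng: random.Random) -> Image:
    """Mirror the image left-to-right with probability 0.5."""
    if rng.random() < 0.5:
        return [list(reversed(row)) for row in image]
    return [list(row) for row in image]


def augment(image: Image, crop_h: int, crop_w: int, rng: random.Random) -> Image:
    """Apply a random crop followed by a random horizontal flip."""
    return random_flip(random_crop(image, crop_h, crop_w, rng), rng)


rng = random.Random(0)
image = [[r * 4 + c for c in range(4)] for r in range(4)]
print(augment(image, 3, 3, rng))
```

Augmentation like this applies only to the training split; evaluation images are normally left undistorted.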
decay_samples: Number of training samples used to calculate the decay steps and create the training optimizer.
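The description does not spell out how a sample count becomes a number of decay steps. One plausible reading, assumed here rather than taken from the MediaPipe source, is that decay_samples is divided by batch_size to get an optimizer step count for the learning-rate decay schedule. The helper name decay_steps_from_samples is hypothetical.

```python
def decay_steps_from_samples(decay_samples: int, batch_size: int) -> int:
    """Convert a sample budget into optimizer steps (assumed: integer division)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return decay_samples // batch_size


# HParams defaults: decay_samples = 10000 * 256, batch_size = 2.
print(decay_steps_from_samples(10000 * 256, 2))    # 1280000
print(decay_steps_from_samples(10000 * 256, 256))  # 10000
```

Expressing the budget in samples rather than steps keeps the decay length roughly constant in terms of data seen when batch_size changes.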
warmup_steps: Number of warmup steps for a linearly increasing warmup schedule on the learning rate. Used to set up the warmup schedule by model_util.WarmUp.
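A linear warmup ramps the learning rate from zero to its peak over the first steps and then hands off to a decay schedule. The sketch below assumes warmup steps are derived as warmup_epochs * steps_per_epoch and uses cosine decay as the post-warmup schedule; both are illustrative choices, not necessarily what model_util.WarmUp is paired with, and the function names are hypothetical.

```python
import math


def cosine_decay(initial_lr: float, step: int, decay_steps: int) -> float:
    """Cosine decay from initial_lr to 0 over decay_steps (one possible post-warmup schedule)."""
    progress = min(step, decay_steps) / decay_steps
    return initial_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


def warmup_then_decay(step: int, initial_lr: float, warmup_steps: int, decay_steps: int) -> float:
    """Ramp linearly from 0 to initial_lr over warmup_steps, then defer to the decay schedule."""
    if step < warmup_steps:
        return initial_lr * step / warmup_steps
    return cosine_decay(initial_lr, step, decay_steps)


# Hypothetical numbers: warmup_epochs=2 with 100 steps per epoch gives 200 warmup steps.
warmup_steps = 2 * 100
for step in (0, 100, 200, 1000):
    print(step, warmup_then_decay(step, initial_lr=0.001, warmup_steps=warmup_steps, decay_steps=1000))
```

The decay function receives the global step, so right after warmup the rate is already slightly below the peak; this matches the common pattern of wrapping an existing decay schedule with a warmup phase.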
checkpoint_frequency: How often to save a training checkpoint.
one_hot: Whether the label data is one-hot encoded; if False, labels are treated as score inputs.
multi_labels: Whether the model predicts multiple labels per image.
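The difference between single-label and multi-label classification shows up both in the target vectors and in how predictions are read. The sketch below builds one-hot and multi-hot targets and contrasts argmax prediction with per-class thresholding; the 0.5 threshold and all function names are hypothetical, not MediaPipe defaults.

```python
from typing import List, Sequence


def one_hot(label: int, num_classes: int) -> List[float]:
    """Single-label target: exactly one 1.0 at the class index."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def multi_hot(labels: Sequence[int], num_classes: int) -> List[float]:
    """Multi-label target: 1.0 at every class the image belongs to."""
    vec = [0.0] * num_classes
    for label in labels:
        vec[label] = 1.0
    return vec


def predict_single(scores: Sequence[float]) -> int:
    """Single-label prediction: the index of the highest-scoring class."""
    return max(range(len(scores)), key=lambda i: scores[i])


def predict_multi(scores: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Multi-label prediction: every class whose score reaches the threshold."""
    return [i for i, s in enumerate(scores) if s >= threshold]


print(one_hot(2, 4))                    # [0.0, 0.0, 1.0, 0.0]
print(multi_hot([0, 2], 4))             # [1.0, 0.0, 1.0, 0.0]
print(predict_single([0.1, 0.7, 0.2]))  # 1
print(predict_multi([0.9, 0.3, 0.6]))   # [0, 2]
```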
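steps_per_epoch: Optional number of training steps per epoch. If None, the training pipeline derives it from the training dataset size and batch_size.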
class_weights: Optional mapping from class index to a weight applied to that class's examples in the training loss.
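Class weights are typically used to counter class imbalance by making examples of rare classes count more in the loss. The sketch below assumes the Keras convention of multiplying each example's loss by the weight of its class and then averaging over the batch; giving unlisted classes a weight of 1.0 is an assumption of this sketch, and weighted_mean_loss is a hypothetical name.

```python
from typing import Mapping, Sequence


def weighted_mean_loss(
    per_example_loss: Sequence[float],
    labels: Sequence[int],
    class_weights: Mapping[int, float],
) -> float:
    """Scale each example's loss by its class weight, then average over the batch.

    Classes missing from class_weights get weight 1.0 (an assumption of this sketch).
    """
    weighted = [loss * class_weights.get(label, 1.0) for loss, label in zip(per_example_loss, labels)]
    return sum(weighted) / len(weighted)


# Class 1 is rare, so its examples count three times as much.
losses = [0.2, 0.4, 1.0]
labels = [0, 0, 1]
print(weighted_mean_loss(losses, labels, {0: 1.0, 1: 3.0}))  # about 1.2, i.e. (0.2 + 0.4 + 3.0) / 3
```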
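shuffle: Whether the training dataset is shuffled before training.
repeat: Whether the training dataset is repeated indefinitely, so training does not depend on knowing the dataset size.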
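export_dir: Location of the model checkpoint files. The default, tempfile.mkdtemp(), is evaluated once when the class is defined, so every instance that relies on the default shares the same temporary directory.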
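Because the default for export_dir is written as a call in the signature, Python evaluates it once when the class body runs, not once per instance. The sketch below demonstrates this with a hypothetical stand-in class, SketchHParams, that copies only that field.

```python
import dataclasses
import tempfile


# Hypothetical stand-in that copies the export_dir default from the HParams signature.
@dataclasses.dataclass
class SketchHParams:
    export_dir: str = tempfile.mkdtemp()  # evaluated once, when the class body runs


a = SketchHParams()
b = SketchHParams()
print(a.export_dir == b.export_dir)  # True: both share the one directory

# Passing export_dir explicitly gives each run its own location.
c = SketchHParams(export_dir=tempfile.mkdtemp())
print(a.export_dir == c.export_dir)  # False
```

If separate training runs in one process should not overwrite each other's checkpoints, passing an explicit export_dir per run avoids the shared default. In a dataclass you control, dataclasses.field(default_factory=tempfile.mkdtemp) would instead create a fresh directory per instance.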
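distribution_strategy: Name of the tf.distribute strategy to use for training. The default, 'off', trains without a distribution strategy.
num_gpus: Number of GPUs to use per worker when a distribution strategy is enabled.
tpu: The TPU resource to use for training, for example a Cloud TPU name or a grpc:// address. The default is an empty string.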
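warmup_epochs: Number of epochs covered by the learning-rate warmup. The constructor signature above has no warmup_steps parameter, so this is the field that sets the warmup length.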
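Methods
get_strategy
get_strategy()
Returns the tf.distribute strategy selected by distribution_strategy, num_gpus, and tpu.
__eq__
__eq__(
    other
)
Returns True when other is an instance of the same class and all fields are equal (the equality method generated for dataclasses).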
Class Variables
batch_size = 2
checkpoint_frequency = 1
class_weights = None
decay_samples = 2560000 (10000 * 256)
distribution_strategy = 'off'
do_data_augmentation = True
do_fine_tuning = False
epochs = 10
export_dir = '/tmpfs/tmp/tmpnt_h4p9w' (created by tempfile.mkdtemp() when this page was generated; the value differs per process)
l1_regularizer = 0.0
l2_regularizer = 0.0001
label_smoothing = 0.1
learning_rate = 0.001
multi_labels = False
num_gpus = 0
one_hot = True
repeat = False
shuffle = False
steps_per_epoch = None
tpu = ''
warmup_epochs = 2