A class for training and validation of a Keras model.
@abc.abstractmethod
train(
    model_fn: Callable[[], tf.keras.Model],
    train_ds_provider: DatasetProvider,
    *,
    epochs: int = 1,
    valid_ds_provider: Optional[DatasetProvider] = None
) -> tf.keras.Model
Trains a tf.keras.Model with optional validation.
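The contract of train() can be illustrated in plain Python. This is a minimal sketch, not TF-GNN code: `ListDatasetProvider`, `ToyModel`, and `LoopTrainer` are hypothetical stand-ins (the real DatasetProvider supplies tf.data.Dataset objects and the real model is a tf.keras.Model), and the local `Trainer` only mirrors the documented signature with simplified types. It shows what the signature fixes: the model is built by calling the zero-argument model_fn, epochs and valid_ds_provider are keyword-only, and validation runs only when a validation provider is passed.

```python
import abc
from typing import Callable, Iterable, List, Optional, Tuple

# Hypothetical stand-in for one (graph_tensor, label) batch: here just (x, y).
Batch = Tuple[float, float]


class ListDatasetProvider:
    """Hypothetical stand-in for DatasetProvider; not the real interface."""

    def __init__(self, batches: List[Batch]):
        self._batches = batches

    def get_dataset(self) -> Iterable[Batch]:
        return iter(self._batches)


class ToyModel:
    """Hypothetical stand-in for tf.keras.Model: fits y = w * x by SGD."""

    def __init__(self):
        self.w = 0.0

    def train_step(self, x: float, y: float, lr: float = 0.1) -> None:
        err = self.w * x - y
        self.w -= lr * 2 * err * x  # gradient of err**2 with respect to w

    def eval_step(self, x: float, y: float) -> float:
        return (self.w * x - y) ** 2


class Trainer(abc.ABC):
    """Mirrors the documented abstract interface (simplified types)."""

    @abc.abstractmethod
    def train(self, model_fn: Callable[[], ToyModel],
              train_ds_provider: ListDatasetProvider, *,
              epochs: int = 1,
              valid_ds_provider: Optional[ListDatasetProvider] = None
              ) -> ToyModel:
        raise NotImplementedError


class LoopTrainer(Trainer):
    """A hypothetical concrete Trainer built on a plain training loop."""

    def __init__(self):
        self.valid_losses: List[float] = []

    def train(self, model_fn, train_ds_provider, *, epochs=1,
              valid_ds_provider=None):
        model = model_fn()  # the trainer decides when the model is built
        for _ in range(epochs):
            for x, y in train_ds_provider.get_dataset():
                model.train_step(x, y)
            if valid_ds_provider is not None:  # validation is optional
                batches = list(valid_ds_provider.get_dataset())
                loss = sum(model.eval_step(x, y) for x, y in batches)
                self.valid_losses.append(loss / len(batches))
        return model


# Usage: epochs and valid_ds_provider must be passed by keyword.
trainer = LoopTrainer()
model = trainer.train(
    ToyModel,  # the class itself serves as a zero-argument model_fn
    ListDatasetProvider([(1.0, 2.0), (2.0, 4.0)]),
    epochs=20,
    valid_ds_provider=ListDatasetProvider([(3.0, 6.0)]))
```

Passing a factory rather than a ready-made model presumably lets a concrete trainer build the model in a context it controls, for example inside a tf.distribute strategy scope; the sketch above only shows the calling convention.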
Args:
  model_fn: A zero-argument callable that returns a tf.keras.Model for use in training and validation.
  train_ds_provider: A DatasetProvider for training. The items of its tf.data.Dataset are pairs (graph_tensor, label), each representing one batch of per-replica training inputs after GraphTensor.merge_batch_to_components() has been applied.
  epochs: The number of epochs to train for. Defaults to 1.
  valid_ds_provider: An optional DatasetProvider for validation. The items of its tf.data.Dataset are pairs (graph_tensor, label), each representing one batch of per-replica validation inputs after GraphTensor.merge_batch_to_components() has been applied. If None (the default), no validation is performed.
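Both providers yield batches that have already gone through GraphTensor.merge_batch_to_components(), which turns a batch of separate graphs into a single graph whose components are the original graphs. The pure-Python sketch below illustrates that idea on a hypothetical edge-list representation (`SmallGraph`, `MergedGraph`, and `merge_batch` are illustrative names, not TF-GNN APIs, and this is not how GraphTensor implements it). The key step is offsetting each graph's node indices so its edges still point at its own nodes, while per-component size lists record where each original graph begins and ends.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class SmallGraph:
    """Hypothetical stand-in for one graph: node features plus an edge list."""
    node_features: List[float]
    edges: List[Tuple[int, int]]  # (source, target), indices local to this graph


@dataclass
class MergedGraph:
    node_features: List[float]
    edges: List[Tuple[int, int]]  # indices into the merged node list
    node_sizes: List[int]         # number of nodes in each component
    edge_sizes: List[int]         # number of edges in each component


def merge_batch(batch: List[SmallGraph]) -> MergedGraph:
    """Concatenates a batch of graphs into one graph, one component per input."""
    features, edges, node_sizes, edge_sizes = [], [], [], []
    offset = 0
    for g in batch:
        features.extend(g.node_features)
        # Shift local node indices by the number of nodes merged so far.
        edges.extend((s + offset, t + offset) for s, t in g.edges)
        node_sizes.append(len(g.node_features))
        edge_sizes.append(len(g.edges))
        offset += len(g.node_features)
    return MergedGraph(features, edges, node_sizes, edge_sizes)


batch = [SmallGraph([1.0, 2.0], [(0, 1)]),
         SmallGraph([3.0, 4.0, 5.0], [(0, 2), (2, 1)])]
merged = merge_batch(batch)
print(merged.edges)       # [(0, 1), (2, 4), (4, 3)]
print(merged.node_sizes)  # [2, 3]
```

Because each original graph survives as one component, per-graph labels in the (graph_tensor, label) pair can still be matched to their graphs after merging.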
Returns:
  A trained tf.keras.Model.