runner.Trainer

A class for training and validation of a Keras model.

Attributes

model_dir
strategy
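
The train method listed under Methods is declared with @abc.abstractmethod, so concrete trainers are expected to subclass Trainer and supply their own implementation. The pure-Python sketch below mirrors that interface shape. It assumes, without confirmation from this page, that model_dir and strategy are abstract read-only properties. TrainerSketch and LocalTrainer are hypothetical names, and Any stands in for the TensorFlow types (tf.keras.Model, tf.distribute.Strategy, DatasetProvider).

```python
import abc
from typing import Any, Callable, Optional


class TrainerSketch(abc.ABC):
    """Hypothetical pure-Python stand-in for the runner.Trainer interface."""

    @property
    @abc.abstractmethod
    def model_dir(self) -> str:
        """Assumed abstract property; its meaning is not documented on this page."""

    @property
    @abc.abstractmethod
    def strategy(self) -> Any:
        """Assumed abstract property; Any stands in for a TF distribution strategy."""

    @abc.abstractmethod
    def train(
        self,
        model_fn: Callable[[], Any],
        train_ds_provider: Any,
        *,
        epochs: int = 1,
        valid_ds_provider: Optional[Any] = None,
    ) -> Any:
        """Trains and returns the model built by model_fn."""


class LocalTrainer(TrainerSketch):
    """Hypothetical subclass that fills in every abstract member."""

    def __init__(self, model_dir: str):
        self._model_dir = model_dir

    @property
    def model_dir(self) -> str:
        return self._model_dir

    @property
    def strategy(self) -> Any:
        return None  # No distribution in this sketch.

    def train(self, model_fn, train_ds_provider, *, epochs=1, valid_ds_provider=None):
        # Placeholder: builds the model but performs no training.
        return model_fn()
```

Instantiating TrainerSketch directly raises TypeError because of its abstract members; LocalTrainer can be instantiated once all of them are overridden.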

Methods

train

@abc.abstractmethod
train(
    model_fn: Callable[[], tf.keras.Model],
    train_ds_provider: DatasetProvider,
    *,
    epochs: int = 1,
    valid_ds_provider: Optional[DatasetProvider] = None
) -> tf.keras.Model

Trains a tf.keras.Model with optional validation.

Args
model_fn: A callable that takes no arguments and returns a tf.keras.Model for use in training and validation.
train_ds_provider: A DatasetProvider for training. The items of its tf.data.Dataset are pairs (graph_tensor, label) that represent one batch of per-replica training inputs after GraphTensor.merge_batch_to_components() has been applied.
epochs: The number of epochs to train.
valid_ds_provider: An optional DatasetProvider for validation. The items of its tf.data.Dataset are pairs (graph_tensor, label) that represent one batch of per-replica validation inputs after GraphTensor.merge_batch_to_components() has been applied.
Returns
A trained tf.keras.Model.
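
The contract of train (build a model from model_fn, iterate over batches of (inputs, label) pairs for the given number of epochs, optionally validate, return the model) can be illustrated with a TensorFlow-free toy. This is a minimal sketch, not TF-GNN's implementation: ListDatasetProvider, its argument-free get_dataset(), and MeanModel with its train_step/predict methods are hypothetical stand-ins; plain lists replace tf.data.Dataset; and no distribution strategy or merge_batch_to_components() step is involved. Validation runs after each epoch, as Keras Model.fit does by default when given validation data.

```python
from typing import Any, Callable, List, Optional, Tuple


class ListDatasetProvider:
    """Hypothetical stand-in for a DatasetProvider yielding (inputs, label) pairs."""

    def __init__(self, batches: List[Tuple[Any, float]]):
        self._batches = batches

    def get_dataset(self) -> List[Tuple[Any, float]]:
        # The real provider yields a tf.data.Dataset; a list suffices here.
        return list(self._batches)


class MeanModel:
    """Toy 'model' that predicts the mean of all labels seen during training."""

    def __init__(self):
        self.total = 0.0
        self.count = 0
        self.valid_history: List[float] = []  # Mean absolute error per epoch.

    def train_step(self, inputs: Any, label: float) -> None:
        self.total += label
        self.count += 1

    def predict(self, inputs: Any) -> float:
        return self.total / self.count if self.count else 0.0


def train(
    model_fn: Callable[[], MeanModel],
    train_ds_provider: ListDatasetProvider,
    *,  # epochs and valid_ds_provider are keyword-only, as in the signature above.
    epochs: int = 1,
    valid_ds_provider: Optional[ListDatasetProvider] = None,
) -> MeanModel:
    model = model_fn()  # Build a fresh model from the factory.
    for _ in range(epochs):
        for inputs, label in train_ds_provider.get_dataset():
            model.train_step(inputs, label)
        if valid_ds_provider is not None:
            # Validation only measures the model; it never updates it.
            errors = [abs(model.predict(x) - y) for x, y in valid_ds_provider.get_dataset()]
            model.valid_history.append(sum(errors) / len(errors))
    return model
```

For example, train(MeanModel, ListDatasetProvider([(None, 1.0), (None, 3.0)]), epochs=2) returns a model that predicts 2.0. Because of the bare `*` in the signature, passing epochs positionally raises TypeError.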