
Commit b712215

Merge pull request tensorflow#2042 from 18jeffreyma:master
PiperOrigin-RevId: 319106358
2 parents d16afaa + 8a87f4b commit b712215

18 files changed (+152, -23 lines)


RELEASE.md

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@
 * Added the ConcatPlaceholder to tfx.dsl.component.experimental.placeholders.
 * Changed Span information as a property of ExampleGen's output artifact.
   Deprecated ExampleGen input (external) artifact.
+* Added ModelRun artifact for Trainer for storing training related files,
+  e.g., Tensorboard logs.
 
 ## Bug fixes and other changes
 * Added Tuner component, which is still work in progress.

docs/tutorials/tfx/components_keras.ipynb

Lines changed: 1 addition & 1 deletion
@@ -1284,7 +1284,7 @@
 "        for i in range(num_dnn_layers)\n",
 "    ])\n",
 "\n",
-"  # This log path might change in the future.\n",
+"  # TODO(b/158106209): use ModelRun instead of Model artifact for logging.\n",
 "  log_dir = os.path.join(os.path.dirname(fn_args.serving_model_dir), 'logs')\n",
 "  tensorboard_callback = tf.keras.callbacks.TensorBoard(\n",
 "      log_dir=log_dir, update_freq='batch')\n",

tfx/components/testdata/module_file/trainer_module.py

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import os
 import absl
 import tensorflow as tf
 import tensorflow_model_analysis as tfma
@@ -340,4 +341,8 @@ def run_fn(fn_args: executor.TrainerFnArgs):
       export_dir_base=fn_args.eval_model_dir,
       eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])
 
+  # Simulate writing a log to the path given by fn_args
+  io_utils.write_string_file(
+      os.path.join(fn_args.model_run_dir, 'fake_log.txt'), '')
+
   absl.logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir)

tfx/components/trainer/component.py

Lines changed: 9 additions & 1 deletion
@@ -119,6 +119,7 @@ def __init__(
       custom_config: Optional[Dict[Text, Any]] = None,
       custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
       output: Optional[types.Channel] = None,
+      model_run: Optional[types.Channel] = None,
       transform_output: Optional[types.Channel] = None,
       instance_name: Optional[Text] = None):
     """Construct a Trainer component.
@@ -179,6 +180,8 @@ def trainer_fn(trainer.executor.TrainerFnArgs,
         that will be passed into user module.
       custom_executor_spec: Optional custom executor spec.
       output: Optional `Model` channel for result of exported models.
+      model_run: Optional `ModelRun` channel, as the working dir of models,
+        can be used to output non-model related output (e.g., TensorBoard logs).
       transform_output: Backwards compatibility alias for the 'transform_graph'
         argument.
       instance_name: Optional unique instance name. Necessary iff multiple
@@ -214,6 +217,9 @@ def trainer_fn(trainer.executor.TrainerFnArgs,
     examples = examples or transformed_examples
     output = output or types.Channel(
         type=standard_artifacts.Model, artifacts=[standard_artifacts.Model()])
+    model_run = model_run or types.Channel(
+        type=standard_artifacts.ModelRun,
+        artifacts=[standard_artifacts.ModelRun()])
     spec = TrainerSpec(
         examples=examples,
         transform_graph=transform_graph,
@@ -226,7 +232,9 @@ def trainer_fn(trainer.executor.TrainerFnArgs,
         run_fn=run_fn,
         trainer_fn=trainer_fn,
         custom_config=json_utils.dumps(custom_config),
-        model=output)
+        model=output,
+        # TODO(b/158106209): change the model_run as optional output artifact
+        model_run=model_run)
     super(Trainer, self).__init__(
         spec=spec,
         custom_executor_spec=custom_executor_spec,
tfx/components/trainer/component_test.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ def setUp(self):
   def _verify_outputs(self, trainer):
     self.assertEqual(standard_artifacts.Model.TYPE_NAME,
                      trainer.outputs['model'].type_name)
+    self.assertEqual(standard_artifacts.ModelRun.TYPE_NAME,
+                     trainer.outputs['model_run'].type_name)
 
   def testConstructFromModuleFile(self):
     module_file = '/path/to/module/file'

tfx/components/trainer/constants.py

Lines changed: 3 additions & 1 deletion
@@ -37,7 +37,9 @@
 CUSTOM_CONFIG_KEY = 'custom_config'
 
 # Key for output model in executor output_dict.
-OUTPUT_MODEL_KEY = 'model'
+MODEL_KEY = 'model'
+# Key for log output in executor output_dict
+MODEL_RUN_KEY = 'model_run'
 
 # The name of environment variable to indicate distributed training cluster.
 TF_CONFIG_ENV = 'TF_CONFIG'

tfx/components/trainer/executor.py

Lines changed: 24 additions & 8 deletions
@@ -130,10 +130,13 @@ def _GetFnArgs(self, input_dict: Dict[Text, List[types.Artifact]],
       hyperparameters_config = None
 
     output_path = artifact_utils.get_single_uri(
-        output_dict[constants.OUTPUT_MODEL_KEY])
+        output_dict[constants.MODEL_KEY])
     serving_model_dir = path_utils.serving_model_dir(output_path)
     eval_model_dir = path_utils.eval_model_dir(output_path)
 
+    model_run_dir = artifact_utils.get_single_uri(
+        output_dict[constants.MODEL_RUN_KEY])
+
     # TODO(b/126242806) Use PipelineInputs when it is available in third_party.
     return TrainerFnArgs(
         # A list of uris for train files.
@@ -148,6 +151,8 @@ def _GetFnArgs(self, input_dict: Dict[Text, List[types.Artifact]],
         eval_model_dir=eval_model_dir,
         # A list of uris for eval files.
         eval_files=fn_args.eval_files,
+        # A single uri for the output directory of model training related files.
+        model_run_dir=model_run_dir,
         # A single uri for schema file.
         schema_file=fn_args.schema_path,
         # Number of train steps.
@@ -168,7 +173,8 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
 
     The Trainer Executor invokes a run_fn callback function provided by
     the user via the module_file parameter. In this function, user defines the
-    model and train it, then save the model to the provided location.
+    model and trains it, then saves the model and training related files
+    (e.g, Tensorboard logs) to the provided locations.
 
     Args:
       input_dict: Input dict from input key to a list of ML-Metadata Artifacts.
@@ -177,7 +183,8 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         - transform_output: Optional input transform graph.
        - schema: Schema of the data.
      output_dict: Output dict from output key to a list of Artifacts.
-        - output: Exported model.
+        - model: Exported model.
+        - model_run: Model training related outputs (e.g., Tensorboard logs)
      exec_properties: A dict of execution properties.
        - train_args: JSON string of trainer_pb2.TrainArgs instance, providing
          args for training.
@@ -211,8 +218,10 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
     # module's responsibility to export the model only once.
     if not tf.io.gfile.exists(fn_args.serving_model_dir):
       raise RuntimeError('run_fn failed to generate model.')
-    absl.logging.info('Training complete. Model written to %s',
-                      fn_args.serving_model_dir)
+
+    absl.logging.info(
+        'Training complete. Model written to %s. ModelRun written to %s',
+        fn_args.serving_model_dir, fn_args.model_run_dir)
 
 
 class Executor(GenericExecutor):
@@ -244,7 +253,8 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
        - transform_output: Optional input transform graph.
        - schema: Schema of the data.
      output_dict: Output dict from output key to a list of Artifacts.
-        - output: Exported model.
+        - model: Exported model.
+        - model_run: Model training related outputs (e.g., Tensorboard logs)
      exec_properties: A dict of execution properties.
        - train_args: JSON string of trainer_pb2.TrainArgs instance, providing
          args for training.
@@ -278,8 +288,10 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
     tf.estimator.train_and_evaluate(training_spec['estimator'],
                                     training_spec['train_spec'],
                                     training_spec['eval_spec'])
-    absl.logging.info('Training complete. Model written to %s',
-                      fn_args.serving_model_dir)
+
+    absl.logging.info(
+        'Training complete. Model written to %s. ModelRun written to %s',
+        fn_args.serving_model_dir, fn_args.model_run_dir)
 
     # Export an eval savedmodel for TFMA. If distributed training, it must only
     # be written by the chief worker, as would be done for serving savedmodel.
@@ -290,6 +302,10 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           export_dir_base=fn_args.eval_model_dir,
           eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])
 
+      # TODO(b/158106209): refactor serving_model_dir to only contain model.
+      # Copy model run information to ModelRun artifact
+      io_utils.copy_dir(fn_args.serving_model_dir, fn_args.model_run_dir)
+
       absl.logging.info('Exported eval_savedmodel to %s.',
                         fn_args.eval_model_dir)
     else:
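To summarize the executor contract introduced above: the orchestrator supplies a ModelRun artifact under constants.MODEL_RUN_KEY, the executor resolves its uri and hands it to user code as fn_args.model_run_dir, and the user module writes training-related files (e.g., TensorBoard logs) there. The sketch below is not part of the commit; it simply mirrors what _GetFnArgs does, and the uri is hypothetical.

import os

from tfx.components.trainer import constants
from tfx.types import artifact_utils, standard_artifacts

# Output artifact as the orchestrator would provide it (uri is hypothetical).
model_run = standard_artifacts.ModelRun()
model_run.uri = '/tmp/pipeline_root/Trainer/model_run/1'
output_dict = {constants.MODEL_RUN_KEY: [model_run]}

# Same resolution _GetFnArgs performs before building TrainerFnArgs.
model_run_dir = artifact_utils.get_single_uri(output_dict[constants.MODEL_RUN_KEY])

# A user run_fn would then write logs and similar files under this directory.
log_dir = os.path.join(model_run_dir, 'logs')  # layout is up to the user module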

tfx/components/trainer/executor_test.py

Lines changed: 17 additions & 1 deletion
@@ -70,7 +70,13 @@ def setUp(self):
     self._model_exports = standard_artifacts.Model()
     self._model_exports.uri = os.path.join(self._output_data_dir,
                                            'model_export_path')
-    self._output_dict = {constants.OUTPUT_MODEL_KEY: [self._model_exports]}
+    self._model_run_exports = standard_artifacts.ModelRun()
+    self._model_run_exports.uri = os.path.join(self._output_data_dir,
+                                               'model_run_path')
+    self._output_dict = {
+        constants.MODEL_KEY: [self._model_exports],
+        constants.MODEL_RUN_KEY: [self._model_run_exports]
+    }
 
     # Create exec properties skeleton.
     self._exec_properties = {
@@ -106,6 +112,10 @@ def _verify_no_eval_model_exports(self):
     self.assertFalse(
         tf.io.gfile.exists(path_utils.eval_model_dir(self._model_exports.uri)))
 
+  def _verify_model_run_exports(self):
+    self.assertTrue(
+        tf.io.gfile.exists(os.path.dirname(self._model_run_exports.uri)))
+
   def _do(self, test_executor):
     test_executor.Do(
         input_dict=self._input_dict,
@@ -116,30 +126,35 @@ def testGenericExecutor(self):
     self._exec_properties['module_file'] = self._module_file
     self._do(self._generic_trainer_executor)
     self._verify_model_exports()
+    self._verify_model_run_exports()
 
   @mock.patch('tfx.components.trainer.executor._is_chief')
   def testDoChief(self, mock_is_chief):
     mock_is_chief.return_value = True
     self._exec_properties['module_file'] = self._module_file
     self._do(self._trainer_executor)
     self._verify_model_exports()
+    self._verify_model_run_exports()
 
   @mock.patch('tfx.components.trainer.executor._is_chief')
   def testDoNonChief(self, mock_is_chief):
     mock_is_chief.return_value = False
     self._exec_properties['module_file'] = self._module_file
     self._do(self._trainer_executor)
     self._verify_no_eval_model_exports()
+    self._verify_model_run_exports()
 
   def testDoWithModuleFile(self):
     self._exec_properties['module_file'] = self._module_file
     self._do(self._trainer_executor)
     self._verify_model_exports()
+    self._verify_model_run_exports()
 
   def testDoWithTrainerFn(self):
     self._exec_properties['trainer_fn'] = self._trainer_fn
     self._do(self._trainer_executor)
     self._verify_model_exports()
+    self._verify_model_run_exports()
 
   def testDoWithNoTrainerFn(self):
     with self.assertRaises(ValueError):
@@ -169,6 +184,7 @@ def testDoWithHyperParameters(self):
     self._exec_properties['module_file'] = self._module_file
     self._do(self._trainer_executor)
     self._verify_model_exports()
+    self._verify_model_run_exports()
 
 
 if __name__ == '__main__':

tfx/dsl/compiler/testdata/iris_pipeline_ir.pbtxt

Lines changed: 10 additions & 0 deletions
@@ -617,6 +617,16 @@ nodes {
           }
         }
       }
+      outputs {
+        key: "model_run"
+        value {
+          artifact_spec {
+            type {
+              name: "ModelRun"
+            }
+          }
+        }
+      }
     }
     parameters {
       parameters {

tfx/examples/airflow_workshop/setup/dags/taxi_utils.py

Lines changed: 7 additions & 2 deletions
@@ -322,8 +322,13 @@
 #         for i in range(num_dnn_layers)
 #     ])
 #
-#   # TODO(b/158106209): This log path might change in the future.
-#   log_dir = os.path.join(os.path.dirname(fn_args.serving_model_dir), 'logs')
+#   try:
+#     log_dir = fn_args.model_run_dir
+#   except KeyError:
+#     # TODO(b/158106209): use ModelRun instead of Model artifact for logging.
+#     log_dir = os.path.join(os.path.dirname(fn_args.serving_model_dir), 'logs')
+#
+#   # Write logs to path
 #   tensorboard_callback = tf.keras.callbacks.TensorBoard(
 #       log_dir=log_dir, update_freq='batch')
 #
