 DEFAULT_FILE_NAME = 'data_tfrecord'
 
 
-def _ExamplePartitionKey(record: tf.train.Example,
-                         split_config: example_gen_pb2.SplitConfig) -> bytes:
-  """Generates key for partition for tf.train.Example."""
+def _GeneratePartitionKey(record: Union[tf.train.Example,
+                                        tf.train.SequenceExample, bytes],
+                          split_config: example_gen_pb2.SplitConfig) -> bytes:
+  """Generates key for partition."""
 
   if not split_config.HasField('partition_feature_name'):
+    if isinstance(record, bytes):
+      return record
     return record.SerializeToString(deterministic=True)
 
+  if isinstance(record, tf.train.Example):
+    features = record.features.feature  # pytype: disable=attribute-error
+  elif isinstance(record, tf.train.SequenceExample):
+    features = record.context.feature  # pytype: disable=attribute-error
+  else:
+    raise RuntimeError('Split by `partition_feature_name` is only supported '
+                       'for FORMAT_TF_EXAMPLE and FORMAT_TF_SEQUENCE_EXAMPLE '
+                       'payload format.')
+
   # Use a feature for partitioning the examples.
   feature_name = split_config.partition_feature_name
-  if feature_name not in record.features.feature:
+  if feature_name not in features:
     raise RuntimeError('Feature name `{}` does not exist.'.format(feature_name))
-  feature = record.features.feature[feature_name]
+  feature = features[feature_name]
   if not feature.HasField('kind'):
     raise RuntimeError('Partition feature does not contain any value.')
   if (not feature.HasField('bytes_list') and
@@ -62,23 +74,15 @@ def _ExamplePartitionKey(record: tf.train.Example,
 
 
 def _PartitionFn(
-    record: Union[tf.train.Example, bytes],
+    record: Union[tf.train.Example, tf.train.SequenceExample, bytes],
     num_partitions: int,
     buckets: List[int],
     split_config: example_gen_pb2.SplitConfig,
 ) -> int:
   """Partition function for the ExampleGen's output splits."""
   assert num_partitions == len(
       buckets), 'Partitions do not match bucket number.'
-
-  if isinstance(record, tf.train.Example):
-    partition_str = _ExamplePartitionKey(record, split_config)
-  elif split_config.HasField('partition_feature_name'):
-    raise RuntimeError('Split by `partition_feature_name` is only supported '
-                       'for FORMAT_TF_EXAMPLE payload format.')
-  else:
-    partition_str = record
-
+  partition_str = _GeneratePartitionKey(record, split_config)
   bucket = int(hashlib.sha256(partition_str).hexdigest(), 16) % buckets[-1]
   # For example, if buckets is [10,50,80], there will be 3 splits:
   #   bucket >=0 && < 10, returns 0
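
A minimal standalone sketch of the bucket-to-split mapping above, assuming the
comparison loop hidden by the hunk boundary implements the range checks the
comment describes; the partition key bytes here are made up:

  import bisect
  import hashlib

  buckets = [10, 50, 80]  # cumulative bounds: 3 splits of 10, 40, 30 buckets
  partition_str = b'some-partition-key'  # hypothetical key
  bucket = int(hashlib.sha256(partition_str).hexdigest(), 16) % buckets[-1]
  split_index = bisect.bisect_right(buckets, bucket)  # returns 0, 1, or 2
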
@@ -88,14 +92,17 @@ def _PartitionFn(
 
 
 @beam.ptransform_fn
-@beam.typehints.with_input_types(Union[tf.train.Example, bytes])
+@beam.typehints.with_input_types(Union[tf.train.Example,
+                                       tf.train.SequenceExample, bytes])
 @beam.typehints.with_output_types(beam.pvalue.PDone)
 def _WriteSplit(example_split: beam.pvalue.PCollection,
                 output_split_path: Text) -> beam.pvalue.PDone:
   """Shuffles and writes output split as serialized records in TFRecord."""
 
   def _MaybeSerialize(x):
-    return x.SerializeToString() if isinstance(x, tf.train.Example) else x
+    if isinstance(x, (tf.train.Example, tf.train.SequenceExample)):
+      return x.SerializeToString()
+    return x
 
   return (example_split
           # TODO(jyzhao): make shuffle optional.
@@ -107,28 +114,13 @@ def _MaybeSerialize(x):
               file_name_suffix='.gz'))
 
 
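Since `_WriteSplit` emits gzipped TFRecord files, a written split can be read
back with standard TensorFlow APIs. A minimal sketch, assuming a hypothetical
output path and a SequenceExample payload:

  import tensorflow as tf

  # Glob the gzipped TFRecord shards of one split (path is hypothetical).
  files = tf.io.gfile.glob('/tmp/examples/split-train/*.gz')
  dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')
  for raw in dataset.take(1):
    # Parse back into the payload proto that the split serialized.
    record = tf.train.SequenceExample.FromString(raw.numpy())
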
-@beam.ptransform_fn
-@beam.typehints.with_input_types(beam.Pipeline)
-@beam.typehints.with_output_types(Union[tf.train.Example, bytes])
-def _InputToExampleOrBytes(
-    pipeline: beam.Pipeline,
-    input_to_example: beam.PTransform,
-    exec_properties: Dict[Text, Any],
-    split_pattern: Text,
-) -> beam.pvalue.PCollection:
-  """Converts input into a tf.train.Example, or a bytes (serialized proto)."""
-  return (pipeline
-          | 'InputSourceToExampleOrBytes' >> input_to_example(
-              exec_properties, split_pattern))
-
-
 class BaseExampleGenExecutor(
     with_metaclass(abc.ABCMeta, base_executor.BaseExecutor)):
   """Generic TFX example gen base executor.
 
   The base ExampleGen executor takes a configuration and converts external data
-  sources to TensorFlow Examples (tf.train.Example), or any other protocol
-  buffer as subclass defines.
+  sources to TensorFlow Examples (tf.train.Example, tf.train.SequenceExample),
+  or any other protocol buffer the subclass defines.
 
   The common configuration (defined in
   https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto#L44.)
@@ -137,12 +129,14 @@ class BaseExampleGenExecutor(
 
   The conversion is done in `GenerateExamplesByBeam` as a Beam pipeline, which
   validates the configuration, reads the external data sources, converts the
-  record in the input source to tf.Example if needed, and splits the examples if
-  the output split config is given. Then the executor's `Do` writes the results
-  in splits to the output path.
+  record in the input source to any supported output payload format
+  (e.g., tf.Example or tf.SequenceExample) if needed, and splits the examples
+  if the output split config is given. Then the executor's `Do` writes the
+  results in splits to the output path.
 
   For simple custom ExampleGens, the details of transforming input data
-  record(s) to a tf.Example is expected to be given in
+  record(s) to a specific output payload format (e.g., tf.Example or
+  tf.SequenceExample) are expected to be given in
   `GetInputSourceToExamplePTransform`, which returns a Beam PTransform with the
   actual implementation. For complex use cases, such as joining multiple data
   sources and different interpretations of the configurations, the custom
@@ -163,7 +157,9 @@ def GetInputSourceToExamplePTransform(self) -> beam.PTransform:
     Here is an example PTransform:
       @beam.ptransform_fn
       @beam.typehints.with_input_types(beam.Pipeline)
-      @beam.typehints.with_output_types(Union[tf.train.Example, bytes])
+      @beam.typehints.with_output_types(Union[tf.train.Example,
+                                              tf.train.SequenceExample,
+                                              bytes])
       def ExamplePTransform(
           pipeline: beam.Pipeline,
           exec_properties: Dict[Text, Any],
@@ -176,15 +172,15 @@ def GenerateExamplesByBeam(
       pipeline: beam.Pipeline,
       exec_properties: Dict[Text, Any],
   ) -> Dict[Text, beam.pvalue.PCollection]:
-    """Converts input source to TF example splits based on configs.
+    """Converts input source to serialized record splits based on configs.
 
     Custom ExampleGen executor should provide GetInputSourceToExamplePTransform
-    for converting input split to TF Examples. Override this
+    for converting input split to serialized records. Override this
     'GenerateExamplesByBeam' method instead if complex logic is needed, e.g.,
     custom splitting logic.
 
     Args:
-      pipeline: beam pipeline.
+      pipeline: Beam pipeline.
       exec_properties: A dict of execution properties. Depends on detailed
         example gen implementation.
         - input_base: an external directory containing the data files.
@@ -197,7 +193,7 @@ def GenerateExamplesByBeam(
 
     Returns:
       Dict of beam PCollection with split name as key, each PCollection is a
-      single output split that contains serialized TF Examples.
+      single output split that contains serialized records.
     """
     # Get input split information.
     input_config = example_gen_pb2.Input()
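
The lines hidden by the next hunk boundary presumably materialize the configs
from the JSON exec properties; a hedged sketch using the standard protobuf
API (the exact parsing call in this file may differ):

  from google.protobuf import json_format

  input_config = example_gen_pb2.Input()
  json_format.Parse(exec_properties['input_config'], input_config)
  output_config = example_gen_pb2.Output()
  json_format.Parse(exec_properties['output_config'], output_config)
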
@@ -214,7 +210,7 @@ def GenerateExamplesByBeam(
     exec_properties['_beam_pipeline_args'] = self._beam_pipeline_args or []
 
     example_splits = []
-    input_to_example = self.GetInputSourceToExamplePTransform()
+    input_to_record = self.GetInputSourceToExamplePTransform()
     if output_config.split_config.splits:
       # Use output splits, input must have only one split.
       assert len(
@@ -228,21 +224,19 @@ def GenerateExamplesByBeam(
         buckets.append(total_buckets)
       example_splits = (
           pipeline
-          | 'InputToExampleOrBytes' >>
+          | 'InputToRecord' >>
           # pylint: disable=no-value-for-parameter
-          _InputToExampleOrBytes(input_to_example, exec_properties,
-                                 input_config.splits[0].pattern)
+          input_to_record(exec_properties, input_config.splits[0].pattern)
           | 'SplitData' >> beam.Partition(_PartitionFn, len(buckets), buckets,
                                           output_config.split_config))
     else:
       # Use input splits.
       for split in input_config.splits:
         examples = (
             pipeline
-            | 'InputToExampleOrBytes[{}]'.format(split.name) >>
+            | 'InputToRecord[{}]'.format(split.name) >>
             # pylint: disable=no-value-for-parameter
-            _InputToExampleOrBytes(input_to_example, exec_properties,
-                                   split.pattern))
+            input_to_record(exec_properties, split.pattern))
         example_splits.append(examples)
 
     result = {}
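
For orientation, a sketch of how the cumulative `buckets` list is derived from
an output split config; the proto values are invented for illustration:

  from tfx.proto import example_gen_pb2

  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
      ]))
  # Accumulating hash_buckets (as `buckets.append(total_buckets)` above does)
  # yields buckets == [2, 3]: hashed buckets 0-1 go to 'train', 2 to 'eval'.
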
@@ -258,22 +252,23 @@ def Do(
   ) -> None:
     """Takes input data source and generates serialized data splits.
 
-    The output is intended to be serialized tf.train.Examples protocol buffer
-    in gzipped TFRecord format, but subclasses can choose to override to write
-    to any serialized records payload into gzipped TFRecord as specified,
-    so long as downstream component can consume it. The format of payload is
-    added to `payload_format` custom property of the output Example artifact.
+    The output is intended to be serialized tf.train.Examples or
+    tf.train.SequenceExamples protocol buffers in gzipped TFRecord format,
+    but subclasses can choose to override to write any serialized records
+    payload into gzipped TFRecord as specified, so long as the downstream
+    component can consume it. The payload format is recorded in the
+    `payload_format` custom property of the output Example artifact.
 
     Args:
       input_dict: Input dict from input key to a list of Artifacts. Depends on
         detailed example gen implementation.
       output_dict: Output dict from output key to a list of Artifacts.
-        - examples: splits of tf examples.
+        - examples: splits of serialized records.
       exec_properties: A dict of execution properties. Depends on detailed
         example gen implementation.
         - input_base: an external directory containing the data files.
-        - input_config: JSON string of example_gen_pb2.Input instance, providing
-          input configuration.
+        - input_config: JSON string of example_gen_pb2.Input instance,
+          providing input configuration.
         - output_config: JSON string of example_gen_pb2.Output instance,
           providing output configuration.
         - output_data_format: Payload format of generated data in output