This repository was archived by the owner on Apr 3, 2025. It is now read-only.

Commit 15af2af

Merge pull request #59 from google/fix_local_runner
Fix DirectRunner
2 parents e00937d + 02135f4

File tree: 4 files changed, +94 -13 lines

megalista_dataflow/main.py
Lines changed: 4 additions & 1 deletion

@@ -16,6 +16,7 @@
 import warnings
 
 import apache_beam as beam
+from apache_beam import coders
 from apache_beam.options.pipeline_options import PipelineOptions
 from mappers.ads_user_list_pii_hashing_mapper import \
     AdsUserListPIIHashingMapper
@@ -24,7 +25,7 @@
 from models.oauth_credentials import OAuthCredentials
 from models.options import DataflowOptions
 from models.sheets_config import SheetsConfig
-from sources.batches_from_executions import BatchesFromExecutions
+from sources.batches_from_executions import BatchesFromExecutions, ExecutionCoder
 from sources.primary_execution_source import PrimaryExecutionSource
 from uploaders.big_query.transactional_events_results_writer import TransactionalEventsResultsWriter
 from uploaders.campaign_manager.campaign_manager_conversion_uploader import CampaignManagerConversionUploaderDoFn
@@ -272,6 +273,8 @@ def run(argv=None):
 
     params = MegalistaStepParams(oauth_credentials, dataflow_options)
 
+    coders.registry.register_coder(Execution, ExecutionCoder)
+
     with beam.Pipeline(options=pipeline_options) as pipeline:
         executions = pipeline | "Load executions" >> beam.io.Read(execution_source)
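Note: registering a coder in Beam's registry is what lets a plain Python class (here, Execution) be encoded wherever Beam needs a key coder, and GroupByKey requires a deterministic key coder. Below is a minimal, hedged sketch of the same pattern with hypothetical stand-in names (SimpleKey and SimpleKeyCoder are not part of this repository):

import json

from apache_beam import coders


class SimpleKey:
    """Hypothetical stand-in for Megalista's Execution class."""

    def __init__(self, name):
        self.name = name

    def to_dict(self):
        return {'name': self.name}

    @staticmethod
    def from_dict(d):
        return SimpleKey(d['name'])


class SimpleKeyCoder(coders.Coder):
    """JSON coder following the same shape as ExecutionCoder."""

    def encode(self, o):
        return json.dumps(o.to_dict()).encode('utf-8')

    def decode(self, s):
        return SimpleKey.from_dict(json.loads(s.decode('utf-8')))

    def is_deterministic(self):
        # GroupByKey needs deterministic key coders, hence this override.
        return True


# Same registry call as the one added to main.py, for the stand-in type.
coders.registry.register_coder(SimpleKey, SimpleKeyCoder)

# Beam now resolves the custom coder for SimpleKey values.
coder = coders.registry.get_coder(SimpleKey)
assert isinstance(coder, SimpleKeyCoder)
assert coder.decode(coder.encode(SimpleKey('a'))).name == 'a'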

megalista_dataflow/models/execution.py
Lines changed: 64 additions & 0 deletions

@@ -80,6 +80,25 @@ def campaign_manager_account_id(self) -> str:
     def app_id(self) -> str:
         return self._app_id
 
+    def to_dict(self):
+        return {
+            'google_ads_account_id' : self.google_ads_account_id,
+            'mcc': self.mcc,
+            'google_analytics_account_id': self.google_analytics_account_id,
+            'campaign_manager_account_id': self.campaign_manager_account_id,
+            'app_id': self.app_id,
+        }
+
+    @staticmethod
+    def from_dict(dict_account_config):
+        return AccountConfig(
+            dict_account_config['google_ads_account_id'],
+            dict_account_config['mcc'],
+            dict_account_config['google_analytics_account_id'],
+            dict_account_config['campaign_manager_account_id'],
+            dict_account_config['app_id'],
+        )
+
     def __str__(self) -> str:
         return (
             f"\n[Account Config]\n\t"
@@ -129,6 +148,21 @@ def source_type(self) -> SourceType:
     def source_metadata(self) -> List[str]:
         return self._source_metadata
 
+    def to_dict(self):
+        return {
+            'source_name': self.source_name,
+            'source_type' : self.source_type.name,
+            'source_metadata': self.source_metadata,
+        }
+
+    @staticmethod
+    def from_dict(dict_source):
+        return Source(
+            dict_source['source_name'],
+            SourceType[dict_source['source_type']],
+            dict_source['source_metadata']
+        )
+
     def __eq__(self, other):
         return (
             self.source_name == other.source_name
@@ -170,6 +204,21 @@ def destination_type(self) -> DestinationType:
     def destination_metadata(self) -> List[str]:
         return self._destination_metadata
 
+    def to_dict(self):
+        return {
+            'destination_name': self.destination_name,
+            'destination_type': self.destination_type.name,
+            'destination_metadata': self.destination_metadata,
+        }
+
+    @staticmethod
+    def from_dict(dict_destination):
+        return Destination(
+            dict_destination['destination_name'],
+            DestinationType[dict_destination['destination_type']],
+            dict_destination['destination_metadata'],
+        )
+
     def __eq__(self, other) -> bool:
         return bool(
             self.destination_name == other.destination_name
@@ -206,6 +255,21 @@ def destination(self) -> Destination:
     def account_config(self) -> AccountConfig:
         return self._account_config
 
+    def to_dict(self):
+        return {
+            'account_config': self.account_config.to_dict(),
+            'source': self.source.to_dict(),
+            'destination': self.destination.to_dict()
+        }
+
+    @staticmethod
+    def from_dict(dict_execution):
+        return Execution(
+            AccountConfig.from_dict(dict_execution['account_config']),
+            Source.from_dict(dict_execution['source']),
+            Destination.from_dict(dict_execution['destination']),
+        )
+
     def __str__(self):
         return f"Origin name: {self.source.source_name}. Action: {self.destination.destination_type}. Destination name: {self.destination.destination_name}"
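Note: the to_dict/from_dict pairs added above are the contract ExecutionCoder relies on; an Execution has to survive a JSON round trip so it can be encoded as a GroupByKey key. A hedged sketch of that round trip (roundtrip is an illustrative helper, not repository code, and execution is assumed to be an Execution built elsewhere):

import json

from models.execution import Execution


def roundtrip(execution):
    """Encode and decode an Execution the way ExecutionCoder does."""
    payload = json.dumps(execution.to_dict()).encode('utf-8')
    restored = Execution.from_dict(json.loads(payload.decode('utf-8')))
    # Spot-check a few fields survived the trip.
    assert restored.source.source_name == execution.source.source_name
    assert restored.destination.destination_type == execution.destination.destination_type
    assert restored.account_config.mcc == execution.account_config.mcc
    return restored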

megalista_dataflow/requirements.txt
Lines changed: 2 additions & 2 deletions

@@ -4,8 +4,8 @@ protobuf==3.13.0
 google-api-python-client==1.12.8
 google-cloud-core==1.4.1
 google-cloud-bigquery==1.27.2
-apache-beam[gcp]==2.34.0
-apache-beam==2.34.0
+apache-beam[gcp]==2.36.0
+apache-beam==2.36.0
 google-cloud-datastore==1.13.1
 google-apitools==0.5.31
 pytest==5.4.3
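If you are reproducing the fix locally, a quick sanity check that the installed SDK matches the new pin (a small sketch; it only assumes apache-beam was installed from this requirements file):

import apache_beam as beam

# 2.36.0 is the version this requirements.txt now installs.
assert beam.__version__ == '2.36.0', f'unexpected Beam version {beam.__version__}'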

megalista_dataflow/sources/batches_from_executions.py
Lines changed: 24 additions & 10 deletions

@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List, Iterable
-
 import apache_beam as beam
 import logging
+import json
 
+from apache_beam.coders import coders
 from apache_beam.options.value_provider import ValueProvider
 from google.cloud import bigquery
-from apache_beam.io.gcp.bigquery import ReadFromBigQueryRequest
-
 from models.execution import DestinationType, Execution, Batch
+from typing import Any, List, Iterable, Tuple, Dict
+
 
 _BIGQUERY_PAGE_SIZE = 20000
 
@@ -35,29 +35,43 @@ def _convert_row_to_dict(row):
     return dict
 
 
+class ExecutionCoder(coders.Coder):
+    """A custom coder for the Execution class."""
+
+    def encode(self, o):
+        return json.dumps(o.to_dict()).encode('utf-8')
+
+    def decode(self, s):
+        return Execution.from_dict(json.loads(s.decode('utf-8')))
+
+    def is_deterministic(self):
+        return True
+
+
 class BatchesFromExecutions(beam.PTransform):
     """
     Filter the received executions by the received action,
     load the data using the received source and group by that batch size and Execution.
     """
 
     class _ExecutionIntoBigQueryRequest(beam.DoFn):
-        def process(self, execution: Execution) -> Iterable[ReadFromBigQueryRequest]:
+
+        def process(self, execution: Execution) -> Iterable[Tuple[Execution, Dict[str, Any]]]:
             client = bigquery.Client()
             table_name = execution.source.source_metadata[0] + '.' + execution.source.source_metadata[1]
             table_name = table_name.replace('`', '')
             query = f"SELECT data.* FROM `{table_name}` AS data"
             logging.getLogger(_LOGGER_NAME).info(f'Reading from table {table_name} for Execution {execution}')
             rows_iterator = client.query(query).result(page_size=_BIGQUERY_PAGE_SIZE)
             for row in rows_iterator:
-                yield {'execution': execution, 'row': _convert_row_to_dict(row)}
+                yield execution, _convert_row_to_dict(row)
 
     class _ExecutionIntoBigQueryRequestTransactional(beam.DoFn):
 
         def __init__(self, bq_ops_dataset):
             self._bq_ops_dataset = bq_ops_dataset
 
-        def process(self, execution: Execution) -> Iterable[ReadFromBigQueryRequest]:
+        def process(self, execution: Execution) -> Iterable[Tuple[Execution, Dict[str, Any]]]:
             table_name = execution.source.source_metadata[0] + \
                 '.' + execution.source.source_metadata[1]
             table_name = table_name.replace('`', '')
@@ -86,7 +100,7 @@ def process(self, execution: Execution) -> Iterable[ReadFromBigQueryRequest]:
                 f'Reading from table {table_name} for Execution {execution}')
             rows_iterator = client.query(query).result(page_size=_BIGQUERY_PAGE_SIZE)
             for row in rows_iterator:
-                yield {'execution': execution, 'row': _convert_row_to_dict(row)}
+                yield execution, _convert_row_to_dict(row)
 
 
     class _BatchElements(beam.DoFn):
@@ -102,7 +116,7 @@ def process(self, grouped_elements):
                 if i != 0 and i % self._batch_size == 0:
                     yield Batch(execution, batch)
                     batch = []
-                batch.append(element['row'])
+                batch.append(element)
             yield Batch(execution, batch)
 
     def __init__(
@@ -131,6 +145,6 @@ def expand(self, executions):
             executions
             | beam.Filter(lambda execution: execution.destination.destination_type == self._destination_type)
             | beam.ParDo(self._get_bq_request_class())
-            | beam.GroupBy(lambda x: x['execution'])
+            | beam.GroupByKey()
             | beam.ParDo(self._BatchElements(self._batch_size))
         )
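Taken together, the DoFns now emit plain (execution, row) tuples and the transform groups them with GroupByKey instead of beam.GroupBy, which is what the registered ExecutionCoder makes safe on the DirectRunner. A minimal sketch of that grouping pattern with hypothetical stand-in types (Key, KeyCoder, and the sample rows are illustrative, not Megalista code):

import json
import typing

import apache_beam as beam
from apache_beam import coders


class Key:
    """Stand-in for Execution: a structured object used as a grouping key."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        return isinstance(other, Key) and self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def __repr__(self):
        return f'Key({self.name!r})'


class KeyCoder(coders.Coder):
    """JSON coder for Key, shaped like ExecutionCoder."""

    def encode(self, o):
        return json.dumps({'name': o.name}).encode('utf-8')

    def decode(self, s):
        return Key(json.loads(s.decode('utf-8'))['name'])

    def is_deterministic(self):
        return True


coders.registry.register_coder(Key, KeyCoder)

# (key, row) tuples, mirroring the (execution, row) pairs yielded above.
rows = [
    (Key('execution-a'), {'value': 1}),
    (Key('execution-a'), {'value': 2}),
    (Key('execution-b'), {'value': 3}),
]

# No runner specified, so this runs on the DirectRunner.
with beam.Pipeline() as pipeline:
    (
        pipeline
        | 'Create rows' >> beam.Create(rows).with_output_types(
            typing.Tuple[Key, typing.Dict[str, int]])
        | 'Group by key' >> beam.GroupByKey()
        | 'Print groups' >> beam.Map(print)
    )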
