1 | 1 | import logging |
2 | 2 | import os |
3 | 3 |
4 | | -from typing import Any, Optional |
| 4 | +from typing import Any, List, Optional |
5 | 5 |
6 | 6 | from airflow.exceptions import AirflowSkipException |
7 | 7 | from airflow.models import BaseOperator |
@@ -172,9 +172,6 @@ def execute(self, context): |
172 | 172 | # Force potential string columns into lists for zipping in execute. |
173 | 173 | if isinstance(self.resource, str): |
174 | 174 | raise ValueError("Bulk operators require lists of resources to be passed.") |
175 | | - |
176 | | - if isinstance(self.table_name, str): |
177 | | - self.table_name = [self.table_name] * len(self.resource) |
178 | 175 |
179 | 176 | ### Optionally set destination key by concatenating separate args for dir and filename |
180 | 177 | if not self.s3_destination_key: |
@@ -204,16 +201,83 @@ def execute(self, context): |
204 | 201 | # Build and run the SQL queries to Snowflake. Delete first if EdFi2 or a full-refresh. |
205 | 202 | xcom_returns = [] |
206 | 203 |
207 | | - for idx, (resource, table, s3_destination_key) in enumerate(zip(self.resource, self.table_name, self.s3_destination_key), start=1): |
208 | | - logging.info(f"[ENDPOINT {idx} / {len(self.resource)}]") |
209 | | - self.run_sql_queries( |
210 | | - name=resource, table=table, |
211 | | - s3_key=s3_destination_key, full_refresh=airflow_util.is_full_refresh(context) |
| 204 | + # If all data is sent to the same table, run a single bulk COPY statement over the whole S3 directory. |
| 205 | + if isinstance(self.table_name, str): |
| 206 | + logging.info("Running bulk statements on a single table.") |
| 207 | + self.run_bulk_sql_queries( |
| 208 | + names=self.resource, table=self.table_name, |
| 209 | + s3_dir=self.s3_destination_dir or os.path.dirname(self.s3_destination_key[0]), # Infer directory if not specified. |
| 210 | + full_refresh=airflow_util.is_full_refresh(context) |
212 | 211 | ) |
| 212 | + |
| 213 | + # Otherwise, loop over each S3 destination and copy in sequence. |
| 214 | + else: |
| 215 | + for idx, (resource, table, s3_destination_key) in enumerate(zip(self.resource, self.table_name, self.s3_destination_key), start=1): |
| 216 | + logging.info(f"[ENDPOINT {idx} / {len(self.resource)}]") |
| 217 | + self.run_sql_queries( |
| 218 | + name=resource, table=table, |
| 219 | + s3_key=s3_destination_key, full_refresh=airflow_util.is_full_refresh(context) |
| 220 | + ) |
213 | 221 |
214 | 222 | # Send the prebuilt-output if specified; otherwise, send the compiled list created above. |
215 | 223 | # This only exists to maintain backwards-compatibility with original S3ToSnowflakeOperator. |
216 | 224 | if self.xcom_return: |
217 | 225 | return self.xcom_return |
218 | 226 | else: |
219 | 227 | return xcom_returns |
| 228 | + |
| 229 | + def run_bulk_sql_queries(self, names: List[str], table: str, s3_dir: str, full_refresh: bool = False): |
| 230 | + """ |
| 231 | + Alternative delete and copy queries to be run when all data is sent to the same table in Snowflake. |
| 232 | + |
| 233 | + S3 Path Structure: |
| 234 | + /{tenant_code}/{api_year}/{ds_nodash}/{ts_nodash}/{taskgroup_type}/{name}.jsonl |
| 235 | +
| 236 | + Use regex to capture name: ".+/(\\w+).jsonl?" |
| 237 | + Note the optional args passed to REGEXP_SUBSTR(): position (1), occurrence (1), regex_parameters ('c'), group_num (1) |
| 238 | + """ |
| 239 | + ### Retrieve the database and schema from the Snowflake hook. |
| 240 | + snowflake_hook = SnowflakeHook(snowflake_conn_id=self.snowflake_conn_id) |
| 241 | + database, schema = airflow_util.get_snowflake_params_from_conn(self.snowflake_conn_id) |
| 242 | + |
| 243 | + ### Build the SQL queries to be passed into `Hook.run()`. |
| 244 | + # Curly braces in the regex patterns conflict with f-string formatting, so define them outside the query string. |
| 245 | + date_regex = "\\\\d{8}" |
| 246 | + ts_regex = "\\\\d{8}T\\\\d{6}" |
| 247 | + |
| 248 | + qry_copy_into = f""" |
| 249 | + COPY INTO {database}.{schema}.{table} |
| 250 | + (tenant_code, api_year, pull_date, pull_timestamp, file_row_number, filename, name, ods_version, data_model_version, v) |
| 251 | + FROM ( |
| 252 | + SELECT |
| 253 | + '{self.tenant_code}' AS tenant_code, |
| 254 | + '{self.api_year}' AS api_year, |
| 255 | + TO_DATE(REGEXP_SUBSTR(metadata$filename, '{date_regex}'), 'YYYYMMDD') AS pull_date, |
| 256 | + TO_TIMESTAMP(REGEXP_SUBSTR(metadata$filename, '{ts_regex}'), 'YYYYMMDDTHH24MISS') AS pull_timestamp, |
| 257 | + metadata$file_row_number AS file_row_number, |
| 258 | + metadata$filename AS filename, |
| 259 | + REGEXP_SUBSTR(filename, '.+/(\\\\w+).jsonl?', 1, 1, 'c', 1) AS name, |
| 260 | + '{self.ods_version}' AS ods_version, |
| 261 | + '{self.data_model_version}' AS data_model_version, |
| 262 | + t.$1 AS v |
| 263 | + FROM '@{database}.util.airflow_stage/{s3_dir}' |
| 264 | + (file_format => 'json_default') t |
| 265 | + ) |
| 266 | + FORCE = TRUE; |
| 267 | + """ |
| 268 | + |
| 269 | + ### Commit the update queries to Snowflake. |
| 270 | + # Incremental runs are only available in EdFi 3+. |
| 271 | + if self.full_refresh or full_refresh: |
| 272 | + names_string = "', '".join(names) |
| 273 | + |
| 274 | + qry_delete = f""" |
| 275 | + DELETE FROM {database}.{schema}.{table} |
| 276 | + WHERE tenant_code = '{self.tenant_code}' |
| 277 | + AND api_year = '{self.api_year}' |
| 278 | + AND name IN ('{names_string}') |
| 279 | + """ |
| 280 | + snowflake_hook.run(sql=[qry_delete, qry_copy_into], autocommit=False) |
| 281 | + |
| 282 | + else: |
| 283 | + snowflake_hook.run(sql=qry_copy_into) |
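
For context, here is a minimal sketch (not part of the commit) of how the filename parsing in run_bulk_sql_queries is expected to behave, using Python's re module as a stand-in for Snowflake's REGEXP_SUBSTR. The sample S3 key and the tenant/taskgroup values in it are hypothetical, following the path structure described in the docstring:

    import re

    # Hypothetical staged key: {tenant_code}/{api_year}/{ds_nodash}/{ts_nodash}/{taskgroup_type}/{name}.jsonl
    sample_key = "acme_isd/2024/20240115/20240115T060000/resources/students.jsonl"

    # Same capture pattern the COPY INTO query passes to REGEXP_SUBSTR; group 1 is the resource name.
    name = re.search(r".+/(\w+).jsonl?", sample_key).group(1)
    assert name == "students"

    # The date and timestamp patterns mirror date_regex and ts_regex in the operator.
    pull_date = re.search(r"\d{8}", sample_key).group(0)             # "20240115" -> TO_DATE(..., 'YYYYMMDD')
    pull_timestamp = re.search(r"\d{8}T\d{6}", sample_key).group(0)  # "20240115T060000" -> TO_TIMESTAMP(..., 'YYYYMMDDTHH24MISS')

If a staged filename does not follow this layout, re.search returns None here, and REGEXP_SUBSTR returns NULL in the COPY INTO query, so the loaded row's name column would be NULL.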