[SPARK-54392][SS] Optimize JVM-Python communication for TWS initial state #53122
@@ -1466,7 +1466,9 @@ def __init__(
         self.result_state_pdf_arrow_type = to_arrow_type(
             self.result_state_df_type, prefers_large_types=prefers_large_var_types
         )
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )

     def load_stream(self, stream):
         """
|
|
@@ -1821,13 +1823,29 @@ def __init__(
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             arrow_cast=True,
         )
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
         self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
         self.key_offsets = None
         self.average_arrow_row_size = 0
         self.total_bytes = 0
         self.total_rows = 0

+    def _update_batch_size_stats(self, batch):
+        """
+        Update batch size statistics for adaptive batching.
+        """
+        # Short circuit batch size calculation if the batch size is
+        # unlimited as computing batch size is computationally expensive.
+        if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
+            batch_bytes = sum(
+                buf.size for col in batch.columns for buf in col.buffers() if buf is not None
+            )
+            self.total_bytes += batch_bytes
+            self.total_rows += batch.num_rows
+            self.average_arrow_row_size = self.total_bytes / self.total_rows
+
     def load_stream(self, stream):
         """
         Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunk, and
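The byte estimate in `_update_batch_size_stats` is cheap because it only sums the sizes of the Arrow buffers already in memory; nothing is re-serialized. A self-contained sketch of the same calculation on a toy batch (toy column names, assuming a recent pyarrow):

```python
import pyarrow as pa

# Toy batch standing in for one deserialized Arrow record batch.
batch = pa.RecordBatch.from_pydict(
    {"id": [1, 2, 3, None], "name": ["a", "bb", None, "dddd"]}
)

# Same shape as the serializer's estimate: sum the sizes of every non-null
# buffer (validity bitmaps, offsets, and data) across all columns.
batch_bytes = sum(
    buf.size for col in batch.columns for buf in col.buffers() if buf is not None
)

total_bytes, total_rows = batch_bytes, batch.num_rows
average_arrow_row_size = total_bytes / total_rows
print(batch_bytes, average_arrow_row_size)
```

Because the running average is only used to decide when to cut a chunk, the short circuit for the unlimited case skips this per-batch work entirely.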
|
|
@@ -1855,18 +1873,7 @@ def generate_data_batches(batches):

             def row_stream():
                 for batch in batches:
-                    # Short circuit batch size calculation if the batch size is
-                    # unlimited as computing batch size is computationally expensive.
-                    if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
-                        batch_bytes = sum(
-                            buf.size
-                            for col in batch.columns
-                            for buf in col.buffers()
-                            if buf is not None
-                        )
-                        self.total_bytes += batch_bytes
-                        self.total_rows += batch.num_rows
-                        self.average_arrow_row_size = self.total_bytes / self.total_rows
+                    self._update_batch_size_stats(batch)
                     data_pandas = [
                         self.arrow_to_pandas(c, i)
                         for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
|
@@ -1946,6 +1953,7 @@ def __init__(

     def load_stream(self, stream):
         import pyarrow as pa
+        import pandas as pd
         from pyspark.sql.streaming.stateful_processor_util import (
             TransformWithStateInPandasFuncMode,
         )
|
|
@@ -1964,6 +1972,12 @@ def generate_data_batches(batches):

             def flatten_columns(cur_batch, col_name):
                 state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+
+                # Check if the entire column is null
+                if state_column.null_count == len(state_column):
+                    # Return empty table with no columns
+                    return pa.Table.from_arrays([], names=[])
+
                 state_field_names = [
                     state_column.type[i].name for i in range(state_column.type.num_fields)
                 ]
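The new all-null check matters because, after this change, a batch carries either `inputData` or `initState`, so the unused struct column arrives entirely null. A toy illustration of that guard (illustrative field names, assuming a recent pyarrow):

```python
import pyarrow as pa

# A struct column where every entry is null, as happens for the unused side
# ("inputData" or "initState") of a TWS batch.
struct_type = pa.struct([("key", pa.string()), ("value", pa.int64())])
all_null = pa.array([None, None, None], type=struct_type)

if all_null.null_count == len(all_null):
    # Mirrors the patch: return an empty table instead of flattening nulls.
    flattened = pa.Table.from_arrays([], names=[])
else:
    field_names = [all_null.type[i].name for i in range(all_null.type.num_fields)]
    flattened = pa.Table.from_arrays(
        [all_null.field(i) for i in range(all_null.type.num_fields)], names=field_names
    )

print(flattened.num_rows, flattened.num_columns)  # 0 0
```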
|
|
@@ -1981,30 +1995,67 @@ def flatten_columns(cur_batch, col_name):
                 .add("inputData", dataSchema)
                 .add("initState", initStateSchema)
             We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
-            data generator. All rows in the same batch have the same grouping key.
+            data generator. Rows in the same batch may have different grouping keys,
+            but each batch will have either init_data or input_data, not mix.
             """
-            for batch in batches:
-                flatten_state_table = flatten_columns(batch, "inputData")
-                data_pandas = [
-                    self.arrow_to_pandas(c, i)
-                    for i, c in enumerate(flatten_state_table.itercolumns())
-                ]
-
-                flatten_init_table = flatten_columns(batch, "initState")
-                init_data_pandas = [
-                    self.arrow_to_pandas(c, i)
-                    for i, c in enumerate(flatten_init_table.itercolumns())
-                ]
-                key_series = [data_pandas[o] for o in self.key_offsets]
-                init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]
-
-                if any(s.empty for s in key_series):
-                    # If any row is empty, assign batch_key using init_key_series
-                    batch_key = tuple(s[0] for s in init_key_series)
-                else:
-                    # If all rows are non-empty, create batch_key from key_series
-                    batch_key = tuple(s[0] for s in key_series)
-                yield (batch_key, data_pandas, init_data_pandas)
+            def row_stream():
+                for batch in batches:
+                    self._update_batch_size_stats(batch)
+
+                    flatten_state_table = flatten_columns(batch, "inputData")
+                    data_pandas = [
+                        self.arrow_to_pandas(c, i)
+                        for i, c in enumerate(flatten_state_table.itercolumns())
+                    ]
+
+                    if bool(data_pandas):
+                        for row in pd.concat(data_pandas, axis=1).itertuples(index=False):
+                            batch_key = tuple(row[s] for s in self.key_offsets)
+                            yield (batch_key, row, None)
+                    else:
+                        flatten_init_table = flatten_columns(batch, "initState")
+                        init_data_pandas = [
+                            self.arrow_to_pandas(c, i)
+                            for i, c in enumerate(flatten_init_table.itercolumns())
+                        ]
+                        if bool(init_data_pandas):
+                            for row in pd.concat(init_data_pandas, axis=1).itertuples(index=False):
+                                batch_key = tuple(row[s] for s in self.init_key_offsets)
+                                yield (batch_key, None, row)
+
+            EMPTY_DATAFRAME = pd.DataFrame()
+            for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
+                rows = []
+                init_state_rows = []
+                for _, row, init_state_row in group_rows:
+                    if row is not None:
+                        rows.append(row)
+                    if init_state_row is not None:
+                        init_state_rows.append(init_state_row)
+
+                    total_len = len(rows) + len(init_state_rows)
+                    if (
+                        total_len >= self.arrow_max_records_per_batch
+                        or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
+                    ):
+                        yield (
+                            batch_key,
+                            pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
+                            pd.DataFrame(init_state_rows)
+                            if len(init_state_rows) > 0
+                            else EMPTY_DATAFRAME.copy(),
+                        )
+                        rows = []
+                        init_state_rows = []
+
+                if rows or init_state_rows:
+                    yield (
+                        batch_key,
+                        pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
+                        pd.DataFrame(init_state_rows)
+                        if len(init_state_rows) > 0
+                        else EMPTY_DATAFRAME.copy(),
+                    )

             _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
             data_batches = generate_data_batches(_batches)
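The refactor above boils down to: stream `(key, row, init_row)` tuples, group consecutive tuples by key with `itertools.groupby`, and flush a pandas DataFrame chunk whenever the record-count or estimated byte threshold is crossed, plus one final flush per key. A simplified standalone sketch of that chunking pattern (hypothetical names, a constant row-size estimate in place of the running Arrow average, and a single row buffer instead of separate input/init-state buffers):

```python
from itertools import groupby

import pandas as pd

MAX_RECORDS_PER_CHUNK = 2     # stand-in for arrow_max_records_per_batch
MAX_BYTES_PER_CHUNK = 10**6   # stand-in for arrow_max_bytes_per_batch
AVG_ROW_SIZE = 64             # stand-in for the measured average Arrow row size


def chunked_groups(rows):
    """rows: iterable of (key, row_dict), ordered so equal keys are adjacent."""
    for key, group in groupby(rows, key=lambda kv: kv[0]):
        buffered = []
        for _, row in group:
            buffered.append(row)
            too_many = len(buffered) >= MAX_RECORDS_PER_CHUNK
            too_big = len(buffered) * AVG_ROW_SIZE >= MAX_BYTES_PER_CHUNK
            if too_many or too_big:
                yield key, pd.DataFrame(buffered)  # flush a full chunk
                buffered = []
        if buffered:
            yield key, pd.DataFrame(buffered)      # final partial chunk for this key


stream = [("a", {"v": 1}), ("a", {"v": 2}), ("a", {"v": 3}), ("b", {"v": 4})]
for key, df in chunked_groups(stream):
    print(key, len(df))
# a 2 / a 1 / b 1
```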
|
|
@@ -2030,7 +2081,9 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):

     def __init__(self, arrow_max_records_per_batch):
         super(TransformWithStateInPySparkRowSerializer, self).__init__()
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
         self.key_offsets = None

     def load_stream(self, stream):
|
|
@@ -2122,13 +2175,13 @@ def __init__(self, arrow_max_records_per_batch):
         self.init_key_offsets = None

     def load_stream(self, stream):
-        import itertools
         import pyarrow as pa
         from pyspark.sql.streaming.stateful_processor_util import (
             TransformWithStateInPandasFuncMode,
         )
+        from typing import Iterator, Any, Optional, Tuple

-        def generate_data_batches(batches):
+        def generate_data_batches(batches) -> Iterator[Tuple[Any, Optional[Any], Optional[Any]]]:
             """
             Deserialize ArrowRecordBatches and return a generator of Row.
             The deserialization logic assumes that Arrow RecordBatches contain the data with the
|
|
@@ -2139,8 +2192,15 @@ def generate_data_batches(batches):
             into the data generator.
             """

-            def extract_rows(cur_batch, col_name, key_offsets):
+            def extract_rows(
+                cur_batch, col_name, key_offsets
+            ) -> Optional[Iterator[Tuple[Any, Any]]]:
                 data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+
+                # Check if the entire column is null
+                if data_column.null_count == len(data_column):
+                    return None
+
                 data_field_names = [
                     data_column.type[i].name for i in range(data_column.type.num_fields)
                 ]
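The reworked `extract_rows` returns `None` when the column carries no data and otherwise a lazy generator of `(key, row)` pairs, so the caller can branch on which side of the batch is populated without materializing rows up front. A toy sketch of that pattern, with a namedtuple standing in for pyspark's internal `DataRow` and a simple row-count check instead of the all-null struct check:

```python
from collections import namedtuple
from typing import Any, Iterator, Optional, Tuple

import pyarrow as pa

Row = namedtuple("Row", ["user", "clicks"])  # stand-in for pyspark's DataRow


def extract_rows(table: pa.Table, key_offsets) -> Optional[Iterator[Tuple[Any, Any]]]:
    if table.num_rows == 0:
        return None

    def row_iterator():
        for row_idx in range(table.num_rows):
            key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets)
            row = Row(*(table.column(i)[row_idx].as_py() for i in range(table.num_columns)))
            yield key, row

    return row_iterator()


table = pa.table({"user": ["u1", "u1", "u2"], "clicks": [3, 5, 2]})
result = extract_rows(table, key_offsets=[0])
if result is not None:
    for key, row in result:
        print(key, row)
```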
|
|
@@ -2153,68 +2213,62 @@ def extract_rows(cur_batch, col_name, key_offsets):
                 table = pa.Table.from_arrays(data_field_arrays, names=data_field_names)

                 if table.num_rows == 0:
-                    return (None, iter([]))
-                else:
-                    batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets)
+                    return None

-                    rows = []
+                def row_iterator():
                     for row_idx in range(table.num_rows):
+                        key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets)
                         row = DataRow(
                             *(table.column(i)[row_idx].as_py() for i in range(table.num_columns))
                         )
-                        rows.append(row)
+                        yield (key, row)

-                    return (batch_key, iter(rows))
+                return row_iterator()

             """
             The arrow batch is written in the schema:
             schema: StructType = new StructType()
                 .add("inputData", dataSchema)
                 .add("initState", initStateSchema)
             We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
-            data generator. All rows in the same batch have the same grouping key.
+            data generator. Each batch will have either init_data or input_data, not mix.
             """
             for batch in batches:
-                (input_batch_key, input_data_iter) = extract_rows(
-                    batch, "inputData", self.key_offsets
-                )
-                (init_batch_key, init_state_iter) = extract_rows(
-                    batch, "initState", self.init_key_offsets
-                )
+                # Detect which column has data - each batch contains only one type
+                input_result = extract_rows(batch, "inputData", self.key_offsets)

-                if input_batch_key is None:
-                    batch_key = init_batch_key
+                if input_result is not None:
+                    for key, input_data_row in input_result:
+                        yield (key, input_data_row, None)
                 else:
-                    batch_key = input_batch_key
-
-                for init_state_row in init_state_iter:
-                    yield (batch_key, None, init_state_row)
-
-                for input_data_row in input_data_iter:
-                    yield (batch_key, input_data_row, None)
+                    init_result = extract_rows(batch, "initState", self.init_key_offsets)
+                    if init_result is not None:
+                        for key, init_state_row in init_result:
+                            yield (key, None, init_state_row)

         _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
         data_batches = generate_data_batches(_batches)

         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            # g: list(batch_key, input_data_iter, init_state_iter)
-
-            # they are sharing the iterator, hence need to copy
-            input_values_iter, init_state_iter = itertools.tee(g, 2)
-
-            chained_input_values = itertools.chain(map(lambda x: x[1], input_values_iter))
-            chained_init_state_values = itertools.chain(map(lambda x: x[2], init_state_iter))
-
-            chained_input_values_without_none = filter(
-                lambda x: x is not None, chained_input_values
-            )
-            chained_init_state_values_without_none = filter(
-                lambda x: x is not None, chained_init_state_values
-            )
-
-            ret_tuple = (chained_input_values_without_none, chained_init_state_values_without_none)
-
-            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
+            input_rows = []
+            init_rows = []
+
+            for batch_key, input_row, init_row in g:
+                if input_row is not None:
+                    input_rows.append(input_row)
+                if init_row is not None:
+                    init_rows.append(init_row)
+
+                total_len = len(input_rows) + len(init_rows)
+                if total_len >= self.arrow_max_records_per_batch:
|
Contributor:
The SQLConf config param says that if it is set to zero or a negative number there is no limit; with this check, a zero or negative setting would instead make us output a fresh batch per row. Let's change the behaviour and add a test covering this.

Contributor (Author):
Oh, right.

Contributor (Author):
Done
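For reference, the requested test could assert that a non-positive limit produces a single output chunk rather than one chunk per row. The sketch below is only a self-contained model of the flush rule with hypothetical names, not the test added in the PR:

```python
import unittest


def chunk_count(num_rows, configured_limit):
    """Count output chunks for one key, mirroring the serializer's flush rule."""
    # Normalization under test: a non-positive limit means "unlimited".
    limit = configured_limit if configured_limit > 0 else 2**31 - 1
    chunks, buffered = 0, 0
    for _ in range(num_rows):
        buffered += 1
        if buffered >= limit:
            chunks += 1
            buffered = 0
    if buffered:
        chunks += 1
    return chunks


class UnlimitedBatchSizeTest(unittest.TestCase):
    def test_non_positive_limit_yields_single_chunk(self):
        # Without normalization, a limit of 0 would flush a fresh chunk per row.
        self.assertEqual(chunk_count(1000, 0), 1)
        self.assertEqual(chunk_count(1000, -1), 1)

    def test_positive_limit_chunks_output(self):
        self.assertEqual(chunk_count(1000, 100), 10)


if __name__ == "__main__":
    unittest.main()
```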
+                    ret_tuple = (iter(input_rows), iter(init_rows))
+                    yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
+                    input_rows = []
+                    init_rows = []
+
+            if input_rows or init_rows:
+                ret_tuple = (iter(input_rows), iter(init_rows))
+                yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)

         yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
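One design point that applies to both serializers: `itertools.groupby` only merges adjacent items, so the per-key chunking above relies on rows for the same grouping key arriving contiguously from the JVM side. A two-line illustration:

```python
from itertools import groupby

# groupby only merges *adjacent* equal keys; the same key split by another
# key produces separate groups.
adjacent = [("k1", 1), ("k1", 2), ("k2", 3)]
non_adjacent = [("k1", 1), ("k2", 3), ("k1", 2)]

print([k for k, _ in groupby(adjacent, key=lambda x: x[0])])      # ['k1', 'k2']
print([k for k, _ in groupby(non_adjacent, key=lambda x: x[0])])  # ['k1', 'k2', 'k1']
```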
|
|
Contributor:
Given we've changed the implicit type signature of the function, let's maybe add a type annotation on generate_data_batches for readability.

Contributor (Author):
Done