212 changes: 133 additions & 79 deletions python/pyspark/sql/pandas/serializers.py
@@ -1466,7 +1466,9 @@ def __init__(
self.result_state_pdf_arrow_type = to_arrow_type(
self.result_state_df_type, prefers_large_types=prefers_large_var_types
)
self.arrow_max_records_per_batch = arrow_max_records_per_batch
self.arrow_max_records_per_batch = (
arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
)
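As a minimal standalone sketch (the helper and constant names here are hypothetical, not from the patch), this is the normalization the constructor now applies, assuming the config's documented convention that zero or a negative value means no limit:

_UNLIMITED = 2**31 - 1  # max 32-bit signed int used as the "no limit" sentinel

def _effective_limit(configured: int) -> int:
    # A non-positive configured value is treated as unlimited.
    return configured if configured > 0 else _UNLIMITED

assert _effective_limit(0) == _UNLIMITED
assert _effective_limit(1000) == 1000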

def load_stream(self, stream):
"""
@@ -1821,13 +1823,29 @@ def __init__(
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
arrow_cast=True,
)
self.arrow_max_records_per_batch = arrow_max_records_per_batch
self.arrow_max_records_per_batch = (
arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
)
self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
self.key_offsets = None
self.average_arrow_row_size = 0
self.total_bytes = 0
self.total_rows = 0

def _update_batch_size_stats(self, batch):
"""
Update batch size statistics for adaptive batching.
"""
# Short-circuit the batch size calculation if the byte limit is
# unlimited, as computing the batch size is computationally expensive.
if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
batch_bytes = sum(
buf.size for col in batch.columns for buf in col.buffers() if buf is not None
)
self.total_bytes += batch_bytes
self.total_rows += batch.num_rows
self.average_arrow_row_size = self.total_bytes / self.total_rows
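The running average computed here feeds the adaptive flush decision used later in load_stream; a rough sketch of that decision (function and parameter names are illustrative, mirroring the attributes above):

def _should_flush(num_rows, avg_row_size, max_records, max_bytes):
    # Flush a chunk once either the record cap or the estimated byte cap is hit.
    return num_rows >= max_records or num_rows * avg_row_size >= max_bytes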

def load_stream(self, stream):
"""
Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunks, and
@@ -1855,18 +1873,7 @@ def generate_data_batches(batches):

def row_stream():
for batch in batches:
# Short circuit batch size calculation if the batch size is
# unlimited as computing batch size is computationally expensive.
if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
batch_bytes = sum(
buf.size
for col in batch.columns
for buf in col.buffers()
if buf is not None
)
self.total_bytes += batch_bytes
self.total_rows += batch.num_rows
self.average_arrow_row_size = self.total_bytes / self.total_rows
self._update_batch_size_stats(batch)
data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
Expand Down Expand Up @@ -1946,6 +1953,7 @@ def __init__(

def load_stream(self, stream):
import pyarrow as pa
import pandas as pd
from pyspark.sql.streaming.stateful_processor_util import (
TransformWithStateInPandasFuncMode,
)
@@ -1964,6 +1972,12 @@ def generate_data_batches(batches):

def flatten_columns(cur_batch, col_name):
state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))

# Check if the entire column is null
if state_column.null_count == len(state_column):
# Return empty table with no columns
return pa.Table.from_arrays([], names=[])

state_field_names = [
state_column.type[i].name for i in range(state_column.type.num_fields)
]
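For intuition on the all-null check added above, a small pyarrow sketch (the struct fields and sizes are made up) showing a fully null struct column and the empty table that flatten_columns returns for it:

import pyarrow as pa

struct_type = pa.struct([("id", pa.int64()), ("value", pa.string())])
all_null = pa.array([None, None, None], type=struct_type)
assert all_null.null_count == len(all_null)  # the whole column is null

empty = pa.Table.from_arrays([], names=[])  # no columns, zero rows
assert empty.num_columns == 0 and empty.num_rows == 0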
@@ -1981,30 +1995,67 @@ def flatten_columns(cur_batch, col_name):
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
data generator. All rows in the same batch have the same grouping key.
data generator. Rows in the same batch may have different grouping keys,
but each batch will contain either init_data or input_data, never a mix.
"""
for batch in batches:
flatten_state_table = flatten_columns(batch, "inputData")
data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(flatten_state_table.itercolumns())
]

flatten_init_table = flatten_columns(batch, "initState")
init_data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(flatten_init_table.itercolumns())
]
key_series = [data_pandas[o] for o in self.key_offsets]
init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]
def row_stream():
for batch in batches:
self._update_batch_size_stats(batch)

if any(s.empty for s in key_series):
# If any row is empty, assign batch_key using init_key_series
batch_key = tuple(s[0] for s in init_key_series)
else:
# If all rows are non-empty, create batch_key from key_series
batch_key = tuple(s[0] for s in key_series)
yield (batch_key, data_pandas, init_data_pandas)
flatten_state_table = flatten_columns(batch, "inputData")
data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(flatten_state_table.itercolumns())
]

if bool(data_pandas):
for row in pd.concat(data_pandas, axis=1).itertuples(index=False):
batch_key = tuple(row[s] for s in self.key_offsets)
yield (batch_key, row, None)
else:
flatten_init_table = flatten_columns(batch, "initState")
init_data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(flatten_init_table.itercolumns())
]
if bool(init_data_pandas):
for row in pd.concat(init_data_pandas, axis=1).itertuples(index=False):
batch_key = tuple(row[s] for s in self.init_key_offsets)
yield (batch_key, None, row)

EMPTY_DATAFRAME = pd.DataFrame()
for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
rows = []
init_state_rows = []
for _, row, init_state_row in group_rows:
if row is not None:
rows.append(row)
if init_state_row is not None:
init_state_rows.append(init_state_row)

total_len = len(rows) + len(init_state_rows)
if (
total_len >= self.arrow_max_records_per_batch
or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
):
yield (
batch_key,
pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
pd.DataFrame(init_state_rows)
if len(init_state_rows) > 0
else EMPTY_DATAFRAME.copy(),
)
rows = []
init_state_rows = []
if rows or init_state_rows:
yield (
batch_key,
pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
pd.DataFrame(init_state_rows)
if len(init_state_rows) > 0
else EMPTY_DATAFRAME.copy(),
)

_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)
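One subtlety in the groupby-based chunking above: itertools.groupby only merges consecutive items with equal keys, so the logic relies on the incoming stream already delivering rows of the same grouping key contiguously. A tiny illustration with made-up values:

from itertools import groupby

pairs = [("a", 1), ("a", 2), ("b", 3), ("a", 4)]
grouped = [(k, [v for _, v in g]) for k, g in groupby(pairs, key=lambda x: x[0])]
# grouped == [('a', [1, 2]), ('b', [3]), ('a', [4])]; the trailing 'a' starts a new group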
@@ -2030,7 +2081,9 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):

def __init__(self, arrow_max_records_per_batch):
super(TransformWithStateInPySparkRowSerializer, self).__init__()
self.arrow_max_records_per_batch = arrow_max_records_per_batch
self.arrow_max_records_per_batch = (
arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
)
self.key_offsets = None

def load_stream(self, stream):
@@ -2122,13 +2175,13 @@ def __init__(self, arrow_max_records_per_batch):
self.init_key_offsets = None

def load_stream(self, stream):
import itertools
import pyarrow as pa
from pyspark.sql.streaming.stateful_processor_util import (
TransformWithStateInPandasFuncMode,
)
from typing import Iterator, Any, Optional, Tuple

def generate_data_batches(batches):
def generate_data_batches(batches) -> Iterator[Tuple[Any, Optional[Any], Optional[Any]]]:
"""
Deserialize ArrowRecordBatches and return a generator of Row.
The deserialization logic assumes that Arrow RecordBatches contain the data with the
@@ -2139,8 +2192,15 @@ def generate_data_batches(batches):
into the data generator.
"""

def extract_rows(cur_batch, col_name, key_offsets):
def extract_rows(
cur_batch, col_name, key_offsets
) -> Optional[Iterator[Tuple[Any, Any]]]:
data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))

# Check if the entire column is null
if data_column.null_count == len(data_column):
return None
[Review comment — Contributor] Given we've changed the implicit type signature of the function, let's maybe add a type annotation on generate_data_batches for readability.
[Reply — Contributor Author] Done

data_field_names = [
data_column.type[i].name for i in range(data_column.type.num_fields)
]
@@ -2153,68 +2213,62 @@ def extract_rows(cur_batch, col_name, key_offsets):
table = pa.Table.from_arrays(data_field_arrays, names=data_field_names)

if table.num_rows == 0:
return (None, iter([]))
else:
batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets)
return None

rows = []
def row_iterator():
for row_idx in range(table.num_rows):
key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets)
row = DataRow(
*(table.column(i)[row_idx].as_py() for i in range(table.num_columns))
)
rows.append(row)
yield (key, row)

return (batch_key, iter(rows))
return row_iterator()

"""
The arrow batch is written in the schema:
schema: StructType = new StructType()
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
data generator. All rows in the same batch have the same grouping key.
data generator. Each batch will contain either init_data or input_data, never a mix.
"""
for batch in batches:
(input_batch_key, input_data_iter) = extract_rows(
batch, "inputData", self.key_offsets
)
(init_batch_key, init_state_iter) = extract_rows(
batch, "initState", self.init_key_offsets
)
# Detect which column has data - each batch contains only one type
input_result = extract_rows(batch, "inputData", self.key_offsets)

if input_batch_key is None:
batch_key = init_batch_key
if input_result is not None:
for key, input_data_row in input_result:
yield (key, input_data_row, None)
else:
batch_key = input_batch_key

for init_state_row in init_state_iter:
yield (batch_key, None, init_state_row)

for input_data_row in input_data_iter:
yield (batch_key, input_data_row, None)
init_result = extract_rows(batch, "initState", self.init_key_offsets)
if init_result is not None:
for key, init_state_row in init_result:
yield (key, None, init_state_row)

_batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)

for k, g in groupby(data_batches, key=lambda x: x[0]):
# g: list(batch_key, input_data_iter, init_state_iter)

# they are sharing the iterator, hence need to copy
input_values_iter, init_state_iter = itertools.tee(g, 2)

chained_input_values = itertools.chain(map(lambda x: x[1], input_values_iter))
chained_init_state_values = itertools.chain(map(lambda x: x[2], init_state_iter))

chained_input_values_without_none = filter(
lambda x: x is not None, chained_input_values
)
chained_init_state_values_without_none = filter(
lambda x: x is not None, chained_init_state_values
)

ret_tuple = (chained_input_values_without_none, chained_init_state_values_without_none)

yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
input_rows = []
init_rows = []

for batch_key, input_row, init_row in g:
if input_row is not None:
input_rows.append(input_row)
if init_row is not None:
init_rows.append(init_row)

total_len = len(input_rows) + len(init_rows)
if total_len >= self.arrow_max_records_per_batch:
[Review comment — Contributor] The SQLConf config param says that if it is set to zero or a negative number there is no limit; with such a value this check would always output a fresh batch per row. Let's change the behaviour and add a test covering this.
[Reply — Contributor Author, @nyaapa, Nov 20, 2025] Oh, right; copied that from the non-init-state handling 🫠 nice catch!
[Reply — Contributor Author] Done

ret_tuple = (iter(input_rows), iter(init_rows))
yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
input_rows = []
init_rows = []

if input_rows or init_rows:
ret_tuple = (iter(input_rows), iter(init_rows))
yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)

yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
