Commit fcde4f9

treat non-positive arrowMaxRecordsPerBatch as unlimited

1 parent 64dd204 commit fcde4f9

File tree: 4 files changed, +154 -26 lines changed

  python/pyspark/sql/pandas/serializers.py
  python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py
  sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
  sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
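The gist of the change: a non-positive spark.sql.execution.arrow.maxRecordsPerBatch now means "no record-count limit". On the Python side this is done by normalizing the configured value to 2**31 - 1. A minimal standalone sketch of that normalization; the helper name and asserts are illustrative only, the real change inlines the conditional in the serializer constructors shown below:

    # Hypothetical helper mirroring the inlined normalization below; sketch only.
    _UNLIMITED = 2**31 - 1  # sentinel meaning "no record-count limit"

    def normalize_max_records_per_batch(arrow_max_records_per_batch: int) -> int:
        # Any value <= 0 is treated as unlimited rather than as an immediate flush.
        return arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else _UNLIMITED

    assert normalize_max_records_per_batch(10000) == 10000
    assert normalize_max_records_per_batch(0) == _UNLIMITED
    assert normalize_max_records_per_batch(-1) == _UNLIMITED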

python/pyspark/sql/pandas/serializers.py

Lines changed: 22 additions & 25 deletions
@@ -1466,7 +1466,7 @@ def __init__(
         self.result_state_pdf_arrow_type = to_arrow_type(
             self.result_state_df_type, prefers_large_types=prefers_large_var_types
         )
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1

     def load_stream(self, stream):
         """

@@ -1821,13 +1821,30 @@ def __init__(
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             arrow_cast=True,
         )
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
         self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
         self.key_offsets = None
         self.average_arrow_row_size = 0
         self.total_bytes = 0
         self.total_rows = 0

+    def _update_batch_size_stats(self, batch):
+        """
+        Update batch size statistics for adaptive batching.
+        """
+        # Short circuit batch size calculation if the batch size is
+        # unlimited as computing batch size is computationally expensive.
+        if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
+            batch_bytes = sum(
+                buf.size
+                for col in batch.columns
+                for buf in col.buffers()
+                if buf is not None
+            )
+            self.total_bytes += batch_bytes
+            self.total_rows += batch.num_rows
+            self.average_arrow_row_size = self.total_bytes / self.total_rows
+
     def load_stream(self, stream):
         """
         Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunk, and

@@ -1855,18 +1872,7 @@ def generate_data_batches(batches):

             def row_stream():
                 for batch in batches:
-                    # Short circuit batch size calculation if the batch size is
-                    # unlimited as computing batch size is computationally expensive.
-                    if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
-                        batch_bytes = sum(
-                            buf.size
-                            for col in batch.columns
-                            for buf in col.buffers()
-                            if buf is not None
-                        )
-                        self.total_bytes += batch_bytes
-                        self.total_rows += batch.num_rows
-                        self.average_arrow_row_size = self.total_bytes / self.total_rows
+                    self._update_batch_size_stats(batch)
                     data_pandas = [
                         self.arrow_to_pandas(c, i)
                         for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())

@@ -1993,16 +1999,7 @@ def flatten_columns(cur_batch, col_name):
             """
             def row_stream():
                 for batch in batches:
-                    if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
-                        batch_bytes = sum(
-                            buf.size
-                            for col in batch.columns
-                            for buf in col.buffers()
-                            if buf is not None
-                        )
-                        self.total_bytes += batch_bytes
-                        self.total_rows += batch.num_rows
-                        self.average_arrow_row_size = self.total_bytes / self.total_rows
+                    self._update_batch_size_stats(batch)

                     flatten_state_table = flatten_columns(batch, "inputData")
                     data_pandas = [

@@ -2078,7 +2075,7 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):

     def __init__(self, arrow_max_records_per_batch):
         super(TransformWithStateInPySparkRowSerializer, self).__init__()
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
         self.key_offsets = None

     def load_stream(self, stream):
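For context on what the new _update_batch_size_stats helper measures: an Arrow record batch's size in bytes is obtained by summing the sizes of its columns' underlying buffers. A small self-contained sketch of the same buffer walk against a throwaway pyarrow batch; the sample data is illustrative only, the serializer receives batches deserialized from the stream:

    import pyarrow as pa

    # Illustrative batch built in place of a deserialized streaming batch.
    batch = pa.RecordBatch.from_pydict({"id": ["0", "1"], "value": [789, 987]})

    # Same computation as _update_batch_size_stats: sum every non-None buffer size.
    batch_bytes = sum(
        buf.size
        for col in batch.columns
        for buf in col.buffers()
        if buf is not None
    )
    average_row_size = batch_bytes / batch.num_rows
    print(batch_bytes, average_row_size)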

python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py

Lines changed: 92 additions & 0 deletions
@@ -1483,6 +1483,98 @@ def check_results(batch_df, batch_id):
             ),
         )

+    def test_transform_with_state_with_records_limit(self):
+        if not self.use_pandas():
+            return
+
+        def make_check_results(expected_per_batch):
+            def check_results(batch_df, batch_id):
+                batch_df.collect()
+                if batch_id == 0:
+                    assert set(batch_df.sort("id").collect()) == expected_per_batch[0]
+                else:
+                    assert set(batch_df.sort("id").collect()) == expected_per_batch[1]
+
+            return check_results
+
+        result_with_small_limit = [
+            {
+                Row(id="0", chunkCount=2),
+                Row(id="1", chunkCount=2),
+            },
+            {
+                Row(id="0", chunkCount=3),
+                Row(id="1", chunkCount=2),
+            },
+        ]
+
+        result_with_large_limit = [
+            {
+                Row(id="0", chunkCount=1),
+                Row(id="1", chunkCount=1),
+            },
+            {
+                Row(id="0", chunkCount=1),
+                Row(id="1", chunkCount=1),
+            },
+        ]
+
+        data = [("0", 789), ("3", 987)]
+        initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id")
+
+        with self.sql_conf(
+            # Set it to a very small number so that every row would be a separate pandas df
+            {"spark.sql.execution.arrow.maxRecordsPerBatch": "1"}
+        ):
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorFactory(),
+                make_check_results(result_with_small_limit),
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorWithInitialStateFactory(),
+                make_check_results(result_with_small_limit),
+                initial_state=initial_state,
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+        with self.sql_conf(
+            {"spark.sql.execution.arrow.maxRecordsPerBatch": "-1"}
+        ):
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorFactory(),
+                make_check_results(result_with_large_limit),
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorWithInitialStateFactory(),
+                make_check_results(result_with_large_limit),
+                initial_state=initial_state,
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
     # test all state types (value, list, map) with large values (512 KB)
     def test_transform_with_state_large_values(self):
         def check_results(batch_df, batch_id):
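The test drives both behaviors purely through spark.sql.execution.arrow.maxRecordsPerBatch: "1" forces a flush after every record, while "-1" is now treated as unlimited. Outside the test harness the same toggle can be set on a session; a minimal sketch assuming an existing SparkSession:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Non-positive values are now treated as "no record-count limit"
    # by the streaming Arrow writer.
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "-1")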

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ class BaseStreamingArrowWriter(

   protected def isBatchSizeLimitReached: Boolean = {
     // If we have either reached the records or bytes limit
-    totalNumRowsForBatch >= arrowMaxRecordsPerBatch ||
+    (arrowMaxRecordsPerBatch > 0 && totalNumRowsForBatch >= arrowMaxRecordsPerBatch) ||
       // Short circuit batch size calculation if the batch size is unlimited as computing batch
       // size is computationally expensive.
       ((arrowMaxBytesPerBatch != Int.MaxValue)
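On the Scala side, rather than rewriting the configured value, the guard simply makes the record-count check unreachable when the limit is non-positive, leaving the (separately short-circuited) bytes check as the only flush trigger. A rough Python rendering of just the record-count half of the predicate, with the bytes condition abstracted away; names are illustrative:

    def records_limit_reached(total_rows_for_batch: int, arrow_max_records_per_batch: int) -> bool:
        # Mirrors the new guard: a non-positive limit can never be reached.
        return arrow_max_records_per_batch > 0 and total_rows_for_batch >= arrow_max_records_per_batch

    assert records_limit_reached(10, 10)
    assert not records_limit_reached(1_000_000, 0)   # unlimited
    assert not records_limit_reached(1_000_000, -1)  # unlimited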

sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala

Lines changed: 39 additions & 0 deletions
@@ -95,4 +95,43 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEach
     verify(writer, times(2)).writeBatch()
     verify(arrowWriter, times(2)).reset()
   }
+
+  test("test negative or zero arrowMaxRecordsPerBatch is unlimited") {
+    val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot])
+    val dataRow = mock(classOf[InternalRow])
+
+    // Test with negative value
+    transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
+      root, writer, -1, arrowMaxBytesPerBatch, arrowWriter)
+
+    // Write many rows (more than typical batch size)
+    for (_ <- 1 to 10) {
+      transformWithStateInPySparkWriter.writeRow(dataRow)
+    }
+
+    // Verify all rows were written but batch was not finalized
+    verify(arrowWriter, times(10)).write(dataRow)
+    verify(writer, never()).writeBatch()
+
+    // Only finalize when explicitly called
+    transformWithStateInPySparkWriter.finalizeCurrentArrowBatch()
+    verify(writer).writeBatch()
+
+    // Test with zero value
+    transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
+      root, writer, 0, arrowMaxBytesPerBatch, arrowWriter)
+
+    // Write many rows again
+    for (_ <- 1 to 10) {
+      transformWithStateInPySparkWriter.writeRow(dataRow)
+    }
+
+    // Verify rows were written but batch was not finalized
+    verify(arrowWriter, times(20)).write(dataRow)
+    verify(writer).writeBatch() // still 1 from before
+
+    // Only finalize when explicitly called
+    transformWithStateInPySparkWriter.finalizeCurrentArrowBatch()
+    verify(writer, times(2)).writeBatch()
+  }
 }
