
Commit 621c23d

format
1 parent fcde4f9 commit 621c23d

3 files changed: +26 −19 lines

python/pyspark/sql/pandas/serializers.py

Lines changed: 25 additions & 15 deletions
@@ -1466,7 +1466,9 @@ def __init__(
         self.result_state_pdf_arrow_type = to_arrow_type(
             self.result_state_df_type, prefers_large_types=prefers_large_var_types
         )
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
 
     def load_stream(self, stream):
         """
@@ -1821,7 +1823,9 @@ def __init__(
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             arrow_cast=True,
         )
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
         self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
         self.key_offsets = None
         self.average_arrow_row_size = 0
@@ -1836,10 +1840,7 @@ def _update_batch_size_stats(self, batch):
         # unlimited as computing batch size is computationally expensive.
         if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
             batch_bytes = sum(
-                buf.size
-                for col in batch.columns
-                for buf in col.buffers()
-                if buf is not None
+                buf.size for col in batch.columns for buf in col.buffers() if buf is not None
             )
             self.total_bytes += batch_bytes
             self.total_rows += batch.num_rows
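For context, a minimal sketch (not part of this commit, assuming pyarrow is installed) of what the reflowed comprehension computes: it sums the sizes of all non-null Arrow buffers backing the batch's columns.

import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3], "name": ["a", None, "c"]})
# Sum the sizes of every non-null buffer behind each column, as in
# _update_batch_size_stats above.
batch_bytes = sum(
    buf.size for col in batch.columns for buf in col.buffers() if buf is not None
)
print(batch_bytes, batch.num_rows)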
@@ -1997,6 +1998,7 @@ def flatten_columns(cur_batch, col_name):
         data generator. Rows in the same batch may have different grouping keys,
         but each batch will have either init_data or input_data, not mix.
         """
+
         def row_stream():
             for batch in batches:
                 self._update_batch_size_stats(batch)
@@ -2034,21 +2036,25 @@ def row_stream():
 
                 total_len = len(rows) + len(init_state_rows)
                 if (
-                    total_len >= self.arrow_max_records_per_batch
-                    or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
+                    total_len >= self.arrow_max_records_per_batch
+                    or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
                 ):
                     yield (
                         batch_key,
                         pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
-                        pd.DataFrame(init_state_rows) if len(init_state_rows) > 0 else EMPTY_DATAFRAME.copy()
+                        pd.DataFrame(init_state_rows)
+                        if len(init_state_rows) > 0
+                        else EMPTY_DATAFRAME.copy(),
                     )
                     rows = []
                     init_state_rows = []
             if rows or init_state_rows:
                 yield (
                     batch_key,
                     pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
-                    pd.DataFrame(init_state_rows) if len(init_state_rows) > 0 else EMPTY_DATAFRAME.copy()
+                    pd.DataFrame(init_state_rows)
+                    if len(init_state_rows) > 0
+                    else EMPTY_DATAFRAME.copy(),
                 )
 
         _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
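The reformatted condition above gates when a buffered group is flushed: either the record count or the estimated byte size (buffered rows times the running average Arrow row size) reaches its limit. A hypothetical restatement of that predicate, not part of the commit:

def should_flush(total_len, average_arrow_row_size, max_records, max_bytes):
    # Flush once either the record limit or the estimated byte limit is reached.
    return (
        total_len >= max_records
        or total_len * average_arrow_row_size >= max_bytes
    )

assert should_flush(10, 100, 10, 10_000)        # record limit reached
assert should_flush(5, 3_000, 100, 10_000)      # estimated bytes exceed limit
assert not should_flush(5, 100, 100, 10_000)    # neither limit reached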
@@ -2075,7 +2081,9 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
 
     def __init__(self, arrow_max_records_per_batch):
         super(TransformWithStateInPySparkRowSerializer, self).__init__()
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
         self.key_offsets = None
 
     def load_stream(self, stream):
@@ -2184,7 +2192,9 @@ def generate_data_batches(batches) -> Iterator[Tuple[Any, Optional[Any], Optiona
             into the data generator.
             """
 
-            def extract_rows(cur_batch, col_name, key_offsets) -> Optional[Iterator[Tuple[Any, Any]]]:
+            def extract_rows(
+                cur_batch, col_name, key_offsets
+            ) -> Optional[Iterator[Tuple[Any, Any]]]:
                 data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
 
                 # Check if the entire column is null
@@ -2242,20 +2252,20 @@ def row_iterator():
             for k, g in groupby(data_batches, key=lambda x: x[0]):
                 input_rows = []
                 init_rows = []
-
+
                 for batch_key, input_row, init_row in g:
                     if input_row is not None:
                         input_rows.append(input_row)
                     if init_row is not None:
                         init_rows.append(init_row)
-
+
                 total_len = len(input_rows) + len(init_rows)
                 if total_len >= self.arrow_max_records_per_batch:
                     ret_tuple = (iter(input_rows), iter(init_rows))
                     yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
                     input_rows = []
                     init_rows = []
-
+
                 if input_rows or init_rows:
                     ret_tuple = (iter(input_rows), iter(init_rows))
                     yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)

python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py

Lines changed: 1 addition & 3 deletions
@@ -1549,9 +1549,7 @@ def check_results(batch_df, batch_id):
             ),
         )
 
-        with self.sql_conf(
-            {"spark.sql.execution.arrow.maxRecordsPerBatch": "-1"}
-        ):
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": "-1"}):
            self._test_transform_with_state_basic(
                ChunkCountProcessorFactory(),
                make_check_results(result_with_large_limit),
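Note: setting spark.sql.execution.arrow.maxRecordsPerBatch to "-1" exercises the non-positive path handled in the serializers above; presumably that value feeds the arrow_max_records_per_batch parameter, which then falls back to the 2**31 - 1 sentinel, i.e. effectively no per-batch record limit.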

python/pyspark/worker.py

Lines changed: 0 additions & 1 deletion
@@ -3070,7 +3070,6 @@ def values_gen():
         stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
 
         def mapper(a):
-            import pandas as pd
             mode = a[0]
 
             if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
