@@ -1466,7 +1466,9 @@ def __init__(
         self.result_state_pdf_arrow_type = to_arrow_type(
             self.result_state_df_type, prefers_large_types=prefers_large_var_types
         )
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
 
     def load_stream(self, stream):
         """
@@ -1821,7 +1823,9 @@ def __init__(
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             arrow_cast=True,
         )
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
         self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
         self.key_offsets = None
         self.average_arrow_row_size = 0
@@ -1836,10 +1840,7 @@ def _update_batch_size_stats(self, batch):
         # unlimited as computing batch size is computationally expensive.
         if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
             batch_bytes = sum(
-                buf.size
-                for col in batch.columns
-                for buf in col.buffers()
-                if buf is not None
+                buf.size for col in batch.columns for buf in col.buffers() if buf is not None
             )
             self.total_bytes += batch_bytes
             self.total_rows += batch.num_rows
@@ -1997,6 +1998,7 @@ def flatten_columns(cur_batch, col_name):
             data generator. Rows in the same batch may have different grouping keys,
             but each batch will have either init_data or input_data, not mix.
             """
+
             def row_stream():
                 for batch in batches:
                     self._update_batch_size_stats(batch)
@@ -2034,21 +2036,25 @@ def row_stream():
 
                     total_len = len(rows) + len(init_state_rows)
                     if (
-                        total_len >= self.arrow_max_records_per_batch
-                        or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
+                        total_len >= self.arrow_max_records_per_batch
+                        or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
                     ):
                         yield (
                             batch_key,
                             pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
-                            pd.DataFrame(init_state_rows) if len(init_state_rows) > 0 else EMPTY_DATAFRAME.copy()
+                            pd.DataFrame(init_state_rows)
+                            if len(init_state_rows) > 0
+                            else EMPTY_DATAFRAME.copy(),
                         )
                         rows = []
                         init_state_rows = []
                 if rows or init_state_rows:
                     yield (
                         batch_key,
                         pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
-                        pd.DataFrame(init_state_rows) if len(init_state_rows) > 0 else EMPTY_DATAFRAME.copy()
+                        pd.DataFrame(init_state_rows)
+                        if len(init_state_rows) > 0
+                        else EMPTY_DATAFRAME.copy(),
                     )
 
         _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
@@ -2075,7 +2081,9 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
 
     def __init__(self, arrow_max_records_per_batch):
         super(TransformWithStateInPySparkRowSerializer, self).__init__()
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
         self.key_offsets = None
 
     def load_stream(self, stream):
@@ -2184,7 +2192,9 @@ def generate_data_batches(batches) -> Iterator[Tuple[Any, Optional[Any], Optiona
             into the data generator.
             """
 
-            def extract_rows(cur_batch, col_name, key_offsets) -> Optional[Iterator[Tuple[Any, Any]]]:
+            def extract_rows(
+                cur_batch, col_name, key_offsets
+            ) -> Optional[Iterator[Tuple[Any, Any]]]:
                 data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
 
                 # Check if the entire column is null
@@ -2242,20 +2252,20 @@ def row_iterator():
         for k, g in groupby(data_batches, key=lambda x: x[0]):
             input_rows = []
             init_rows = []
-
+
             for batch_key, input_row, init_row in g:
                 if input_row is not None:
                     input_rows.append(input_row)
                 if init_row is not None:
                     init_rows.append(init_row)
-
+
                 total_len = len(input_rows) + len(init_rows)
                 if total_len >= self.arrow_max_records_per_batch:
                     ret_tuple = (iter(input_rows), iter(init_rows))
                     yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
                     input_rows = []
                     init_rows = []
-
+
             if input_rows or init_rows:
                 ret_tuple = (iter(input_rows), iter(init_rows))
                 yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
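The hunks above all revolve around the same batching policy: treat 2**31 - 1 as an "unlimited" sentinel for the record and byte caps, estimate a record batch's footprint from its Arrow buffers to keep a running average row size, and flush the rows accumulated for a grouping key once either cap is crossed. The following is a minimal standalone sketch of that policy, not the PySpark serializer API; the names (estimate_batch_bytes, BatchSizeTracker, chunk_rows, UNLIMITED) are illustrative assumptions.

# Minimal, self-contained sketch of the batching policy used in the diff above.
# Assumes pyarrow is installed; names are illustrative, not PySpark APIs.
from typing import Any, Iterable, Iterator, List

import pyarrow as pa

UNLIMITED = 2**31 - 1  # sentinel meaning "no cap", as in the diff


def estimate_batch_bytes(batch: pa.RecordBatch) -> int:
    # Sum the sizes of the non-null Arrow buffers backing each column.
    return sum(
        buf.size for col in batch.columns for buf in col.buffers() if buf is not None
    )


class BatchSizeTracker:
    """Keeps a running average row size across the batches seen so far."""

    def __init__(self) -> None:
        self.total_bytes = 0
        self.total_rows = 0

    def update(self, batch: pa.RecordBatch) -> None:
        if batch.num_rows > 0:
            self.total_bytes += estimate_batch_bytes(batch)
            self.total_rows += batch.num_rows

    @property
    def average_row_size(self) -> float:
        return self.total_bytes / self.total_rows if self.total_rows else 0.0


def chunk_rows(
    rows: Iterable[Any], max_records: int, max_bytes: int, avg_row_size: float
) -> Iterator[List[Any]]:
    # Flush once either the record cap or the estimated byte cap is reached.
    pending: List[Any] = []
    for row in rows:
        pending.append(row)
        if len(pending) >= max_records or len(pending) * avg_row_size >= max_bytes:
            yield pending
            pending = []
    if pending:
        yield pending


if __name__ == "__main__":
    tracker = BatchSizeTracker()
    tracker.update(pa.RecordBatch.from_pydict({"id": [1, 2, 3], "v": [1.0, 2.0, 3.0]}))
    # With max_bytes left at the UNLIMITED sentinel, only the record cap can trigger,
    # so the ten rows below are emitted as chunks of 4, 4, and 2.
    for chunk in chunk_rows(
        range(10), max_records=4, max_bytes=UNLIMITED, avg_row_size=tracker.average_row_size
    ):
        print(chunk)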