
Commit 551b922

[SPARK-54330][PYTHON] Optimize Py4J calls in spark.createDataFrame
### What changes were proposed in this pull request?

Optimize Py4J calls in `spark.createDataFrame`.

### Why are the changes needed?

`spark.createDataFrame` reads multiple configs, and they are fetched one by one. We should minimize the number of Py4J calls on the Python side.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53031 from zhengruifeng/py4j_create_df.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 9f1bd47 commit 551b922
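The gist: each `self._jconf.<...>()` accessor is one Python-to-JVM round trip, so reading six configs individually costs six round trips, while the batched `getConfs` call in this diff costs one. A minimal sketch of the pattern, with a hypothetical `FakeJConf` stub standing in for the Py4J-backed conf proxy (only the `getConfs` name mirrors the call actually used in the diff):

```python
# Sketch only: FakeJConf is a hypothetical stand-in for the Py4J-backed
# conf proxy, where every method call costs one Python -> JVM round trip.
class FakeJConf:
    def __init__(self, values):
        self._values = values
        self.round_trips = 0  # simulated Py4J call counter

    def get(self, key):
        self.round_trips += 1  # one round trip per config
        return self._values[key]

    def getConfs(self, keys):
        self.round_trips += 1  # one round trip for the whole batch
        return [self._values[k] for k in keys]


keys = [
    "spark.sql.timestampType",
    "spark.sql.session.timeZone",
    "spark.sql.execution.arrow.pyspark.enabled",
    "spark.sql.execution.arrow.useLargeVarTypes",
    "spark.sql.execution.arrow.pyspark.fallback.enabled",
    "spark.sql.execution.arrow.maxRecordsPerBatch",
]
conf = FakeJConf(dict.fromkeys(keys, "true"))

values = [conf.get(k) for k in keys]  # before: one call per config
print(conf.round_trips)  # 6

conf.round_trips = 0
values = conf.getConfs(keys)  # after: a single batched call
print(conf.round_trips)  # 1
```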


python/pyspark/sql/pandas/conversion.py

Lines changed: 53 additions & 16 deletions
```diff
@@ -41,7 +41,6 @@
     _create_row,
     StringType,
 )
-from pyspark.sql.utils import is_timestamp_ntz_preferred
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.errors import PySparkTypeError, PySparkValueError
 
```
```diff
@@ -400,7 +399,28 @@ def createDataFrame(  # type: ignore[misc]
 
         assert isinstance(self, SparkSession)
 
-        timezone = self._jconf.sessionLocalTimeZone()
+        (
+            timestampType,
+            sessionLocalTimeZone,
+            arrowPySparkEnabled,
+            arrowUseLargeVarTypes,
+            arrowPySparkFallbackEnabled,
+            arrowMaxRecordsPerBatch,
+        ) = self._jconf.getConfs(
+            [
+                "spark.sql.timestampType",
+                "spark.sql.session.timeZone",
+                "spark.sql.execution.arrow.pyspark.enabled",
+                "spark.sql.execution.arrow.useLargeVarTypes",
+                "spark.sql.execution.arrow.pyspark.fallback.enabled",
+                "spark.sql.execution.arrow.maxRecordsPerBatch",
+            ]
+        )
+
+        prefer_timestamp_ntz = timestampType == "TIMESTAMP_NTZ"
+        prefers_large_var_types = arrowUseLargeVarTypes == "true"
+        timezone = sessionLocalTimeZone
+        arrow_batch_size = int(arrowMaxRecordsPerBatch)
 
         if type(data).__name__ == "Table":
             # `data` is a PyArrow Table
```
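As this hunk shows, the batched `getConfs` hands back plain strings rather than the typed values the removed per-config accessors returned, hence the explicit coercions. Illustrative values only:

```python
# Illustrative values; the batched fetch returns config values as strings.
timestampType = "TIMESTAMP_NTZ"
arrowMaxRecordsPerBatch = "10000"

prefer_timestamp_ntz = timestampType == "TIMESTAMP_NTZ"  # bool via string compare
arrow_batch_size = int(arrowMaxRecordsPerBatch)          # int via explicit parse
```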
```diff
@@ -416,7 +436,7 @@ def createDataFrame(  # type: ignore[misc]
             if schema is None:
                 schema = data.schema.names
 
-            return self._create_from_arrow_table(data, schema, timezone)
+            return self._create_from_arrow_table(data, schema, timezone, prefer_timestamp_ntz)
 
         # `data` is a PandasDataFrameLike object
         from pyspark.sql.pandas.utils import require_minimum_pandas_version
```
```diff
@@ -427,11 +447,18 @@ def createDataFrame(  # type: ignore[misc]
         if schema is None:
             schema = [str(x) if not isinstance(x, str) else x for x in data.columns]
 
-        if self._jconf.arrowPySparkEnabled() and len(data) > 0:
+        if arrowPySparkEnabled == "true" and len(data) > 0:
             try:
-                return self._create_from_pandas_with_arrow(data, schema, timezone)
+                return self._create_from_pandas_with_arrow(
+                    data,
+                    schema,
+                    timezone,
+                    prefer_timestamp_ntz,
+                    prefers_large_var_types,
+                    arrow_batch_size,
+                )
             except Exception as e:
-                if self._jconf.arrowPySparkFallbackEnabled():
+                if arrowPySparkFallbackEnabled == "true":
                     msg = (
                         "createDataFrame attempted Arrow optimization because "
                         "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
```
```diff
@@ -451,11 +478,15 @@ def createDataFrame(  # type: ignore[misc]
                     )
                     warn(msg)
                 raise
-        converted_data = self._convert_from_pandas(data, schema, timezone)
+        converted_data = self._convert_from_pandas(data, schema, timezone, prefer_timestamp_ntz)
         return self._create_dataframe(converted_data, schema, samplingRatio, verifySchema)
 
     def _convert_from_pandas(
-        self, pdf: "PandasDataFrameLike", schema: Union[StructType, str, List[str]], timezone: str
+        self,
+        pdf: "PandasDataFrameLike",
+        schema: Union[StructType, str, List[str]],
+        timezone: str,
+        prefer_timestamp_ntz: bool,
     ) -> List:
         """
         Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
```
```diff
@@ -571,7 +602,7 @@ def convert_timestamp(value: Any) -> Any:
             )
             copied = True
         else:
-            should_localize = not is_timestamp_ntz_preferred()
+            should_localize = not prefer_timestamp_ntz
             for column, series in pdf.items():
                 s = series
                 if (
```
```diff
@@ -648,7 +679,13 @@ def _get_numpy_record_dtype(self, rec: "np.recarray") -> Optional["np.dtype"]:
         return np.dtype(record_type_list) if has_rec_fix else None
 
     def _create_from_pandas_with_arrow(
-        self, pdf: "PandasDataFrameLike", schema: Union[StructType, List[str]], timezone: str
+        self,
+        pdf: "PandasDataFrameLike",
+        schema: Union[StructType, List[str]],
+        timezone: str,
+        prefer_timestamp_ntz: bool,
+        prefers_large_var_types: bool,
+        arrow_batch_size: int,
     ) -> "DataFrame":
         """
         Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
```
```diff
@@ -688,7 +725,6 @@ def _create_from_pandas_with_arrow(
         # Create the Spark schema from list of names passed in with Arrow types
         if isinstance(schema, (list, tuple)):
             arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
-            prefer_timestamp_ntz = is_timestamp_ntz_preferred()
             struct = StructType()
             if infer_pandas_dict_as_map:
                 spark_type: Union[MapType, DataType]
```
```diff
@@ -734,12 +770,11 @@ def _create_from_pandas_with_arrow(
         ]
 
         # Slice the DataFrame to be batched
-        step = self._jconf.arrowMaxRecordsPerBatch()
+        step = arrow_batch_size
         step = step if step > 0 else len(pdf)
         pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))
 
         # Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream
-        prefers_large_var_types = self._jconf.arrowUseLargeVarTypes()
         arrow_data = [
             [
                 (
```
(
@@ -776,7 +811,11 @@ def create_iter_server():
776811
return df
777812

778813
def _create_from_arrow_table(
779-
self, table: "pa.Table", schema: Union[StructType, List[str]], timezone: str
814+
self,
815+
table: "pa.Table",
816+
schema: Union[StructType, List[str]],
817+
timezone: str,
818+
prefer_timestamp_ntz: bool,
780819
) -> "DataFrame":
781820
"""
782821
Create a DataFrame from a given pyarrow.Table by slicing it into partitions then
```diff
@@ -798,8 +837,6 @@ def _create_from_arrow_table(
 
         require_minimum_pyarrow_version()
 
-        prefer_timestamp_ntz = is_timestamp_ntz_preferred()
-
         # Create the Spark schema from list of names passed in with Arrow types
         if isinstance(schema, (list, tuple)):
             table = table.rename_columns(schema)
```
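For orientation, a minimal usage sketch of the path this diff touches; behavior is unchanged, the same call simply fetches its six configs in one Py4J round trip now:

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# pandas -> Spark conversion goes through the rewritten createDataFrame
# path, which now batches its six config reads into a single getConfs call.
pdf = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
df = spark.createDataFrame(pdf)
df.show()
```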
