
Commit a13edac

fix: update Spark plugin for compatibility with Spark 3.x and 4.x (#3350)
Signed-off-by: Kevin Liao <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
1 parent 21ffdc3 commit a13edac
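
Both files apply the same guard: probe for the optional pyspark.sql.classic.dataframe module with importlib.util.find_spec and only define and register the classic-DataFrame handlers when it resolves, so the plugin imports cleanly on Spark versions where that module does not exist. A minimal standalone sketch of the pattern (plain Python, not flytekit code; the print messages are illustrative):

import importlib.util

# Probe for the optional module. find_spec can itself raise if a parent
# package (e.g. pyspark.sql.classic, or pyspark entirely) is missing,
# hence the broad except, mirroring the diffs below.
try:
    spec = importlib.util.find_spec("pyspark.sql.classic.dataframe")
except Exception:
    spec = None

if spec:
    # Only now is it safe to import and wire up classic-DataFrame handlers.
    from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
    print("classic DataFrame available:", ClassicDataFrame)
else:
    print("pyspark.sql.classic.dataframe not found; skipping classic handlers")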

2 files changed: +150 additions, -134 deletions

Lines changed: 93 additions & 87 deletions
@@ -1,3 +1,4 @@
+import importlib.util
 import typing
 from typing import Type
 
@@ -84,84 +85,6 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
         return r.all()
 
 
-classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
-ClassicDataFrame = classic_ps_dataframe.DataFrame
-
-
-class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]):
-    """
-    Implements how Classic SparkDataFrame should be read using the ``open`` method of FlyteSchema
-    """
-
-    def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
-        super().__init__(from_path, cols, fmt)
-
-    def iter(self, **kwargs) -> typing.Generator[T, None, None]:
-        raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks")
-
-    def all(self, **kwargs) -> ClassicDataFrame:
-        if self._fmt == SchemaFormat.PARQUET:
-            ctx = FlyteContext.current_context().user_space_params
-            return ctx.spark_session.read.parquet(self.from_path)
-        raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
-
-
-class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]):
-    """
-    Implements how Classic SparkDataFrame should be written using ``open`` method of FlyteSchema
-    """
-
-    def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
-        super().__init__(to_path, cols, fmt)
-
-    def write(self, *dfs: ClassicDataFrame, **kwargs):
-        if dfs is None or len(dfs) == 0:
-            return
-        if len(dfs) > 1:
-            raise AssertionError("Only a single Classic Spark.DataFrame can be written per variable currently")
-        if self._fmt == SchemaFormat.PARQUET:
-            dfs[0].write.mode("overwrite").parquet(self.to_path)
-            return
-        raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
-
-
-class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]):
-    """
-    Transforms Classic Spark DataFrame's to and from a Schema (typed/untyped)
-    """
-
-    def __init__(self):
-        super().__init__("classic-spark-df-transformer", t=ClassicDataFrame)
-
-    @staticmethod
-    def _get_schema_type() -> SchemaType:
-        return SchemaType(columns=[])
-
-    def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType:
-        return LiteralType(schema=self._get_schema_type())
-
-    def to_literal(
-        self,
-        ctx: FlyteContext,
-        python_val: ClassicDataFrame,
-        python_type: Type[ClassicDataFrame],
-        expected: LiteralType,
-    ) -> Literal:
-        remote_path = ctx.file_access.join(
-            ctx.file_access.raw_output_prefix,
-            ctx.file_access.get_random_string(),
-        )
-        w = ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET)
-        w.write(python_val)
-        return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))
-
-    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T:
-        if not (lv and lv.scalar and lv.scalar.schema):
-            return ClassicDataFrame()
-        r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
-        return r.all()
-
-
 # %%
 # Registers a handle for Spark DataFrame + Flyte Schema type transition
 # This allows open(pyspark.DataFrame) to be an acceptable type
@@ -175,15 +98,98 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
     )
 )
 
-SchemaEngine.register_handler(
-    SchemaHandler(
-        "pyspark.sql.classic.DataFrame-Schema",
-        ClassicDataFrame,
-        ClassicSparkDataFrameSchemaReader,
-        ClassicSparkDataFrameSchemaWriter,
-        handles_remote_io=True,
-    )
-)
 # %%
 # This makes pyspark.DataFrame as a supported output/input type with flytekit.
 TypeEngine.register(SparkDataFrameTransformer())
+
+# Only for classic pyspark which may not be available in Spark 4+
+try:
+    spec = importlib.util.find_spec("pyspark.sql.classic.dataframe")
+except Exception:
+    spec = None
+
+if spec:
+    classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
+    ClassicDataFrame = classic_ps_dataframe.DataFrame
+
+    class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]):
+        """
+        Implements how Classic SparkDataFrame should be read using the ``open`` method of FlyteSchema
+        """
+
+        def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
+            super().__init__(from_path, cols, fmt)
+
+        def iter(self, **kwargs) -> typing.Generator[T, None, None]:
+            raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks")
+
+        def all(self, **kwargs) -> ClassicDataFrame:
+            if self._fmt == SchemaFormat.PARQUET:
+                ctx = FlyteContext.current_context().user_space_params
+                return ctx.spark_session.read.parquet(self.from_path)
+            raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
+
+    class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]):
+        """
+        Implements how Classic SparkDataFrame should be written using ``open`` method of FlyteSchema
+        """
+
+        def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
+            super().__init__(to_path, cols, fmt)
+
+        def write(self, *dfs: ClassicDataFrame, **kwargs):
+            if dfs is None or len(dfs) == 0:
+                return
+            if len(dfs) > 1:
+                raise AssertionError("Only a single Classic Spark.DataFrame can be written per variable currently")
+            if self._fmt == SchemaFormat.PARQUET:
+                dfs[0].write.mode("overwrite").parquet(self.to_path)
+                return
+            raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
+
+    class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]):
+        """
+        Transforms Classic Spark DataFrame's to and from a Schema (typed/untyped)
+        """
+
+        def __init__(self):
+            super().__init__("classic-spark-df-transformer", t=ClassicDataFrame)
+
+        @staticmethod
+        def _get_schema_type() -> SchemaType:
+            return SchemaType(columns=[])
+
+        def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType:
+            return LiteralType(schema=self._get_schema_type())
+
+        def to_literal(
+            self,
+            ctx: FlyteContext,
+            python_val: ClassicDataFrame,
+            python_type: Type[ClassicDataFrame],
+            expected: LiteralType,
+        ) -> Literal:
+            remote_path = ctx.file_access.join(
+                ctx.file_access.raw_output_prefix,
+                ctx.file_access.get_random_string(),
+            )
+            w = ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET)
+            w.write(python_val)
+            return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))
+
+        def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T:
+            if not (lv and lv.scalar and lv.scalar.schema):
+                return ClassicDataFrame()
+            r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
+            return r.all()
+
+    SchemaEngine.register_handler(
+        SchemaHandler(
+            "pyspark.sql.classic.DataFrame-Schema",
+            ClassicDataFrame,
+            ClassicSparkDataFrameSchemaReader,
+            ClassicSparkDataFrameSchemaWriter,
+            handles_remote_io=True,
+        )
+    )
+    TypeEngine.register(ClassicSparkDataFrameTransformer())

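For context, a hypothetical Spark task (not part of this commit) showing what the registration above enables: outputs stay annotated with the base pyspark.sql.DataFrame type, which the plugin handles on both Spark 3.x and 4.x. Spark(spark_conf=...), @task, and flytekit.current_context().spark_session are existing flytekit-spark APIs; the task body and conf values are made up.

import flytekit
import pyspark.sql
from flytekit import task
from flytekitplugins.spark import Spark


@task(task_config=Spark(spark_conf={"spark.executor.instances": "1"}))
def make_df(n: int) -> pyspark.sql.DataFrame:
    # The session comes from the Spark task context; a handler registered
    # above materializes the returned DataFrame as Parquet for downstream tasks.
    sess = flytekit.current_context().spark_session
    return sess.range(n)
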
plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py

Lines changed: 57 additions & 47 deletions
@@ -1,3 +1,4 @@
+import importlib.util
 import typing
 
 from flytekit import FlyteContext, lazy_module
@@ -14,6 +15,8 @@
 
 pd = lazy_module("pandas")
 pyspark = lazy_module("pyspark")
+
+# Base Spark DataFrame (Spark 3.x or Spark 4 parent)
 ps_dataframe = lazy_module("pyspark.sql.dataframe")
 DataFrame = ps_dataframe.DataFrame
 
@@ -47,7 +50,9 @@ def encode(
         df = typing.cast(DataFrame, structured_dataset.dataframe)
         ss = pyspark.sql.SparkSession.builder.getOrCreate()
         # Avoid generating SUCCESS files
+
         ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
+
         df.write.mode("overwrite").parquet(path=path)
         return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
 
@@ -73,51 +78,56 @@ def decode(
 StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler())
 StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer())
 
-classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
-ClassicDataFrame = classic_ps_dataframe.DataFrame
-
-
-class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder):
-    def __init__(self):
-        super().__init__(ClassicDataFrame, None, PARQUET)
-
-    def encode(
-        self,
-        ctx: FlyteContext,
-        structured_dataset: StructuredDataset,
-        structured_dataset_type: StructuredDatasetType,
-    ) -> literals.StructuredDataset:
-        path = typing.cast(str, structured_dataset.uri)
-        if not path:
-            path = ctx.file_access.join(
-                ctx.file_access.raw_output_prefix,
-                ctx.file_access.get_random_string(),
-            )
-        df = typing.cast(ClassicDataFrame, structured_dataset.dataframe)
-        ss = pyspark.sql.SparkSession.builder.getOrCreate()
-        ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
-        df.write.mode("overwrite").parquet(path=path)
-        return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
-
-
-class ParquetToClassicSparkDecodingHandler(StructuredDatasetDecoder):
-    def __init__(self):
-        super().__init__(ClassicDataFrame, None, PARQUET)
-
-    def decode(
-        self,
-        ctx: FlyteContext,
-        flyte_value: literals.StructuredDataset,
-        current_task_metadata: StructuredDatasetMetadata,
-    ) -> ClassicDataFrame:
-        user_ctx = FlyteContext.current_context().user_space_params
-        if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
-            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
-            return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
-        return user_ctx.spark_session.read.parquet(flyte_value.uri)
-
 
-# Register the handlers
-StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
-StructuredDatasetTransformerEngine.register(ParquetToClassicSparkDecodingHandler())
-StructuredDatasetTransformerEngine.register_renderer(ClassicDataFrame, SparkDataFrameRenderer())
+# Only for classic pyspark which may not be available in Spark 4+
+try:
+    spec = importlib.util.find_spec("pyspark.sql.classic.dataframe")
+except Exception:
+    spec = None
+if spec:
+    # Spark 4 "classic" concrete DataFrame, if available
+    classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
+    ClassicDataFrame = classic_ps_dataframe.DataFrame
+
+    class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder):
+        def __init__(self):
+            super().__init__(ClassicDataFrame, None, PARQUET)
+
+        def encode(
+            self,
+            ctx: FlyteContext,
+            structured_dataset: StructuredDataset,
+            structured_dataset_type: StructuredDatasetType,
+        ) -> literals.StructuredDataset:
+            path = typing.cast(str, structured_dataset.uri)
+            if not path:
+                path = ctx.file_access.join(
+                    ctx.file_access.raw_output_prefix,
+                    ctx.file_access.get_random_string(),
+                )
+            df = typing.cast(ClassicDataFrame, structured_dataset.dataframe)
+            ss = pyspark.sql.SparkSession.builder.getOrCreate()
+            ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
+            df.write.mode("overwrite").parquet(path=path)
+            return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
+
+    class ParquetToClassicSparkDecodingHandler(StructuredDatasetDecoder):
+        def __init__(self):
+            super().__init__(ClassicDataFrame, None, PARQUET)
+
+        def decode(
+            self,
+            ctx: FlyteContext,
+            flyte_value: literals.StructuredDataset,
+            current_task_metadata: StructuredDatasetMetadata,
+        ) -> ClassicDataFrame:
+            user_ctx = FlyteContext.current_context().user_space_params
+            if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
+                columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
+                return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
+            return user_ctx.spark_session.read.parquet(flyte_value.uri)
+
+    # Register the handlers
+    StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
+    StructuredDatasetTransformerEngine.register(ParquetToClassicSparkDecodingHandler())
+    StructuredDatasetTransformerEngine.register_renderer(ClassicDataFrame, SparkDataFrameRenderer())

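On the read side, a hypothetical consumer (again not from this commit): annotating a StructuredDataset input with a column subset is what populates current_task_metadata.structured_dataset_type.columns and triggers the decoder's select(*columns) branch; an unannotated input falls through to reading the full Parquet dataset. kwtypes, StructuredDataset, and sd.open(...).all() are existing flytekit APIs; the column name, conf values, and task body are made up.

from typing import Annotated

import pyspark.sql
from flytekit import StructuredDataset, kwtypes, task
from flytekitplugins.spark import Spark


@task(task_config=Spark(spark_conf={"spark.executor.instances": "1"}))
def count_values(sd: Annotated[StructuredDataset, kwtypes(value=int)]) -> int:
    # Only the "value" column is requested, so the decoder reads the Parquet
    # data and selects that column before handing back a Spark DataFrame.
    df = sd.open(pyspark.sql.DataFrame).all()
    return df.count()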