@@ -1,3 +1,4 @@
+import importlib.util
 import typing
 from typing import Type
 
@@ -84,84 +85,6 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
         return r.all()
 
 
-classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
-ClassicDataFrame = classic_ps_dataframe.DataFrame
-
-
-class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]):
-    """
-    Implements how Classic SparkDataFrame should be read using the ``open`` method of FlyteSchema
-    """
-
-    def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
-        super().__init__(from_path, cols, fmt)
-
-    def iter(self, **kwargs) -> typing.Generator[T, None, None]:
-        raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks")
-
-    def all(self, **kwargs) -> ClassicDataFrame:
-        if self._fmt == SchemaFormat.PARQUET:
-            ctx = FlyteContext.current_context().user_space_params
-            return ctx.spark_session.read.parquet(self.from_path)
-        raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
-
-
-class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]):
-    """
-    Implements how Classic SparkDataFrame should be written using the ``open`` method of FlyteSchema
-    """
-
-    def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
-        super().__init__(to_path, cols, fmt)
-
-    def write(self, *dfs: ClassicDataFrame, **kwargs):
-        if dfs is None or len(dfs) == 0:
-            return
-        if len(dfs) > 1:
-            raise AssertionError("Only a single Classic Spark.DataFrame can be written per variable currently")
-        if self._fmt == SchemaFormat.PARQUET:
-            dfs[0].write.mode("overwrite").parquet(self.to_path)
-            return
-        raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
-
-
-class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]):
-    """
-    Transforms Classic Spark DataFrames to and from a Schema (typed/untyped)
-    """
-
-    def __init__(self):
-        super().__init__("classic-spark-df-transformer", t=ClassicDataFrame)
-
-    @staticmethod
-    def _get_schema_type() -> SchemaType:
-        return SchemaType(columns=[])
-
-    def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType:
-        return LiteralType(schema=self._get_schema_type())
-
-    def to_literal(
-        self,
-        ctx: FlyteContext,
-        python_val: ClassicDataFrame,
-        python_type: Type[ClassicDataFrame],
-        expected: LiteralType,
-    ) -> Literal:
-        remote_path = ctx.file_access.join(
-            ctx.file_access.raw_output_prefix,
-            ctx.file_access.get_random_string(),
-        )
-        w = ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET)
-        w.write(python_val)
-        return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))
-
-    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T:
-        if not (lv and lv.scalar and lv.scalar.schema):
-            return ClassicDataFrame()
-        r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
-        return r.all()
-
-
 # %%
 # Registers a handler for the Spark DataFrame + Flyte Schema type transition
 # This allows open(pyspark.DataFrame) to be an acceptable type
@@ -175,15 +98,98 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
     )
 )
 
-SchemaEngine.register_handler(
-    SchemaHandler(
-        "pyspark.sql.classic.DataFrame-Schema",
-        ClassicDataFrame,
-        ClassicSparkDataFrameSchemaReader,
-        ClassicSparkDataFrameSchemaWriter,
-        handles_remote_io=True,
-    )
-)
 # %%
 # This makes pyspark.DataFrame a supported output/input type with flytekit.
 TypeEngine.register(SparkDataFrameTransformer())
+
+# Only for classic pyspark, which may not be available in Spark 4+
+try:
+    spec = importlib.util.find_spec("pyspark.sql.classic.dataframe")
+except Exception:
+    spec = None
+
+if spec:
+    classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
+    ClassicDataFrame = classic_ps_dataframe.DataFrame
+
+    class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]):
+        """
+        Implements how Classic SparkDataFrame should be read using the ``open`` method of FlyteSchema
+        """
+
+        def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
+            super().__init__(from_path, cols, fmt)
+
+        def iter(self, **kwargs) -> typing.Generator[T, None, None]:
+            raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks")
+
+        def all(self, **kwargs) -> ClassicDataFrame:
+            if self._fmt == SchemaFormat.PARQUET:
+                ctx = FlyteContext.current_context().user_space_params
+                return ctx.spark_session.read.parquet(self.from_path)
+            raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
+
+    class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]):
+        """
+        Implements how Classic SparkDataFrame should be written using the ``open`` method of FlyteSchema
+        """
+
+        def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
+            super().__init__(to_path, cols, fmt)
+
+        def write(self, *dfs: ClassicDataFrame, **kwargs):
+            if dfs is None or len(dfs) == 0:
+                return
+            if len(dfs) > 1:
+                raise AssertionError("Only a single Classic Spark.DataFrame can be written per variable currently")
+            if self._fmt == SchemaFormat.PARQUET:
+                dfs[0].write.mode("overwrite").parquet(self.to_path)
+                return
+            raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
+
+    class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]):
+        """
+        Transforms Classic Spark DataFrames to and from a Schema (typed/untyped)
+        """
+
+        def __init__(self):
+            super().__init__("classic-spark-df-transformer", t=ClassicDataFrame)
+
+        @staticmethod
+        def _get_schema_type() -> SchemaType:
+            return SchemaType(columns=[])
+
+        def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType:
+            return LiteralType(schema=self._get_schema_type())
+
+        def to_literal(
+            self,
+            ctx: FlyteContext,
+            python_val: ClassicDataFrame,
+            python_type: Type[ClassicDataFrame],
+            expected: LiteralType,
+        ) -> Literal:
+            remote_path = ctx.file_access.join(
+                ctx.file_access.raw_output_prefix,
+                ctx.file_access.get_random_string(),
+            )
+            w = ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET)
+            w.write(python_val)
+            return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))
+
+        def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T:
+            if not (lv and lv.scalar and lv.scalar.schema):
+                return ClassicDataFrame()
+            r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
+            return r.all()
+
+    SchemaEngine.register_handler(
+        SchemaHandler(
+            "pyspark.sql.classic.DataFrame-Schema",
+            ClassicDataFrame,
+            ClassicSparkDataFrameSchemaReader,
+            ClassicSparkDataFrameSchemaWriter,
+            handles_remote_io=True,
+        )
+    )
+    TypeEngine.register(ClassicSparkDataFrameTransformer())
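
A note on the guard at the top of the added block: importlib.util.find_spec imports parent packages while probing, so looking up pyspark.sql.classic.dataframe can itself raise when pyspark is absent or only partially installed; the broad try/except is what keeps import of this module safe. A generic sketch of the same optional-registration pattern (the helper name is hypothetical, not part of this change):

import importlib.util


def _optional_module_available(name: str) -> bool:
    # find_spec imports parent packages, so it can raise for missing or
    # broken installs; treat any failure as "module not available".
    try:
        return importlib.util.find_spec(name) is not None
    except Exception:
        return False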
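
For context, a minimal usage sketch of what the registrations above enable (the task names, Spark config values, and example columns are illustrative assumptions, not part of this change): with the transformers registered, a Spark task can return a pyspark.sql.DataFrame directly, and a downstream task can accept one as input.

import flytekit
from flytekit import task, workflow
from flytekitplugins.spark import Spark
from pyspark.sql import DataFrame


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def make_df() -> DataFrame:
    # The registered transformer serializes this return value to Parquet.
    sess = flytekit.current_context().spark_session
    return sess.createDataFrame([("a", 1), ("b", 2)], ["name", "value"])


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def count_rows(df: DataFrame) -> int:
    # On the way in, the transformer reads the Parquet back into a DataFrame.
    return df.count()


@workflow
def wf() -> int:
    return count_rows(df=make_df())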