@@ -21,6 +21,8 @@ import java.io.File
 import java.util
 import java.util.OptionalLong
 
+import scala.jdk.CollectionConverters._
+
 import test.org.apache.spark.sql.connector._
 
 import org.apache.spark.SparkUnsupportedOperationException
@@ -37,7 +39,7 @@ import org.apache.spark.sql.connector.read.Scan.ColumnarSupportMode
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.SortExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation, V2ScanPartitioningAndOrdering}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
 import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -1008,6 +1010,52 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
       "Canonicalized DataSourceV2ScanRelation instances should be equal")
   }
 
+  test("SPARK-54163: scan canonicalization for partitioning and ordering aware data source") {
+    val options = new CaseInsensitiveStringMap(Map(
+      "partitionKeys" -> "i",
+      "orderKeys" -> "i,j"
+    ).asJava)
+    val table = new OrderAndPartitionAwareDataSource().getTable(options)
+
+    def createDsv2ScanRelation(): DataSourceV2ScanRelation = {
+      val relation = DataSourceV2Relation.create(table, None, None, options)
+      val scan = relation.table.asReadable.newScanBuilder(relation.options).build()
+      val scanRelation = DataSourceV2ScanRelation(relation, scan, relation.output)
+      // Attach partitioning and ordering information to the DataSourceV2ScanRelation
+      V2ScanPartitioningAndOrdering.apply(scanRelation).asInstanceOf[DataSourceV2ScanRelation]
+    }
+
+    // Create two DataSourceV2ScanRelation instances representing scans of the same table
+    val scanRelation1 = createDsv2ScanRelation()
+    val scanRelation2 = createDsv2ScanRelation()
+
+    // Both scan relations should expose key-grouped partitioning and ordering
+    assert(scanRelation1.keyGroupedPartitioning.isDefined &&
+      scanRelation1.keyGroupedPartitioning.get.nonEmpty,
+      "DataSourceV2ScanRelation should have key grouped partitioning")
+    assert(scanRelation1.ordering.isDefined && scanRelation1.ordering.get.nonEmpty,
+      "DataSourceV2ScanRelation should have ordering")
+
+    // The two instances should not be equal, as they have different attribute IDs
+    assert(scanRelation1 != scanRelation2,
+      "Two created DataSourceV2ScanRelation instances should not be the same")
+    assert(scanRelation1.output.map(_.exprId).toSet != scanRelation2.output.map(_.exprId).toSet,
+      "Output attributes should have different expression IDs before canonicalization")
+    assert(scanRelation1.relation.output.map(_.exprId).toSet !=
+      scanRelation2.relation.output.map(_.exprId).toSet,
+      "Relation output attributes should have different expression IDs before canonicalization")
+    assert(scanRelation1.keyGroupedPartitioning.get.flatMap(_.references.map(_.exprId)).toSet !=
+      scanRelation2.keyGroupedPartitioning.get.flatMap(_.references.map(_.exprId)).toSet,
+      "Partitioning columns should have different expression IDs before canonicalization")
+    assert(scanRelation1.ordering.get.flatMap(_.references.map(_.exprId)).toSet !=
+      scanRelation2.ordering.get.flatMap(_.references.map(_.exprId)).toSet,
+      "Ordering columns should have different expression IDs before canonicalization")
+
+    // After canonicalization, the two instances should be equal
+    assert(scanRelation1.canonicalized == scanRelation2.canonicalized,
+      "Canonicalized DataSourceV2ScanRelation instances should be equal")
+  }
+
   test("SPARK-53809: check mergeScalarSubqueries is effective for DataSourceV2ScanRelation") {
     val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load()
     df.createOrReplaceTempView("df")
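The new test exercises plan canonicalization: two DataSourceV2ScanRelation instances built over the same table differ only in expression IDs, and canonicalization is expected to erase that difference, including in the key-grouped partitioning and ordering expressions attached by V2ScanPartitioningAndOrdering. The standalone sketch below illustrates the idea; it is a toy model with made-up names (Attr, ScanRel, CanonicalizationSketch), not Spark's implementation.

```scala
// Toy model of plan canonicalization; standalone Scala, not Spark code.
object CanonicalizationSketch extends App {
  case class Attr(name: String, exprId: Long)
  case class ScanRel(scanKey: String, output: Seq[Attr]) {
    // Hypothetical canonicalization: rewrite each exprId to its ordinal position in the output.
    def canonicalized: ScanRel =
      copy(output = output.zipWithIndex.map { case (a, i) => a.copy(exprId = i.toLong) })
  }

  private var counter = 0L
  private def freshAttr(name: String): Attr = { counter += 1; Attr(name, counter) }

  // Two relations over the same source get distinct exprIds, so they are not equal as-is...
  val r1 = ScanRel("order-and-partition-aware-table", Seq(freshAttr("i"), freshAttr("j")))
  val r2 = ScanRel("order-and-partition-aware-table", Seq(freshAttr("i"), freshAttr("j")))
  assert(r1 != r2)
  // ...but normalizing the IDs makes them compare equal, which is what the test above verifies
  // for DataSourceV2ScanRelation, including its partitioning and ordering expressions.
  assert(r1.canonicalized == r2.canonicalized)
}
```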
@@ -1052,6 +1100,64 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
     // Verify the query produces correct results
     checkAnswer(query, Row(9, 0))
   }
+
+  test(
+    "SPARK-54163: check mergeScalarSubqueries is effective for OrderAndPartitionAwareDataSource"
+  ) {
+    withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "true") {
+      val options = Map(
+        "partitionKeys" -> "i",
+        "orderKeys" -> "i,j"
+      )
+
+      // Create a DataFrame backed by OrderAndPartitionAwareDataSource
+      val df = spark.read
+        .format(classOf[OrderAndPartitionAwareDataSource].getName)
+        .options(options)
+        .load()
+      df.createOrReplaceTempView("df")
+
+      val query = sql("select (select max(i) from df) as max_i, (select min(i) from df) as min_i")
+      val optimizedPlan = query.queryExecution.optimizedPlan
+
+      // Check that the optimized plan merged the scalar subqueries into `select max(i), min(i) from df`
+      val sub1 = optimizedPlan.asInstanceOf[Project].projectList.head.collect {
+        case s: ScalarSubquery => s
+      }
+      val sub2 = optimizedPlan.asInstanceOf[Project].projectList(1).collect {
+        case s: ScalarSubquery => s
+      }
+
+      // Both subqueries should reference the same merged plan `select max(i), min(i) from df`
+      assert(sub1.nonEmpty && sub2.nonEmpty, "Both scalar subqueries should exist")
+      assert(sub1.head.plan == sub2.head.plan,
+        "Both subqueries should reference the same merged plan")
+
+      // Extract the aggregate from the merged plan of sub1
+      val agg = sub1.head.plan.collect {
+        case a: Aggregate => a
+      }.head
+
+      // Check that the aggregate contains both max(i) and min(i)
+      val aggFunctionSet = agg.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case ae: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression =>
+            ae.aggregateFunction
+        }
+      }.toSet
+
+      assert(aggFunctionSet.size == 2, "Aggregate should contain exactly two aggregate functions")
+      assert(aggFunctionSet
+        .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Max]),
+        "Aggregate should contain max(i)")
+      assert(aggFunctionSet
+        .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Min]),
+        "Aggregate should contain min(i)")
+
+      // Verify the query produces correct results
+      checkAnswer(query, Row(4, 1))
+    }
+  }
 }
 
 case class RangeInputPartition(start: Int, end: Int) extends InputPartition
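This second new test checks that scalar-subquery merging still fires for the partition- and order-aware source, i.e. that `(select max(i) from df)` and `(select min(i) from df)` end up sharing one merged aggregate over a single scan. The toy model below shows the shape of that rewrite under the assumption that subqueries are grouped by an equality key derived from their canonicalized plans; all names are illustrative, not Spark's MergeScalarSubqueries rule.

```scala
// Toy model of scalar-subquery merging; standalone Scala, not Spark's optimizer rule.
object MergeScalarSubqueriesSketch extends App {
  // A scalar subquery is modeled as an aggregate expression over some (canonicalized) plan key.
  case class ScalarSub(planKey: String, aggExpr: String)

  // Subqueries whose underlying plans compare equal can share one merged aggregate over one scan.
  def merge(subs: Seq[ScalarSub]): Map[String, Seq[String]] =
    subs.groupBy(_.planKey).map { case (key, ss) => key -> ss.map(_.aggExpr).distinct }

  val merged = merge(Seq(ScalarSub("scan-df", "max(i)"), ScalarSub("scan-df", "min(i)")))
  // One shared scan now computes both aggregates, mirroring what the test asserts on the
  // optimized plan: both ScalarSubquery expressions point at the same Aggregate.
  assert(merged == Map("scan-df" -> Seq("max(i)", "min(i)")))
}
```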
@@ -1093,6 +1199,18 @@ abstract class SimpleScanBuilder extends ScanBuilder
   override def readSchema(): StructType = TestingV2Source.schema
 
   override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case s: Scan =>
+        this.readSchema() == s.readSchema()
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = {
+    this.readSchema().hashCode()
+  }
 }
 
 trait TestingV2Source extends TableProvider {
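Moving equals/hashCode from SimpleDataSourceV2's scan builder up to SimpleScanBuilder makes every test scan derived from it comparable by its read schema rather than by object identity. That appears to matter because canonicalized DataSourceV2ScanRelation equality (and hence subquery merging) still compares the embedded Scan. The standalone sketch below, with illustrative names only, shows why default reference equality on the scan would keep otherwise identical plans from matching.

```scala
// Why a value-based Scan equality helps; standalone sketch, names are illustrative only.
object ScanEqualitySketch extends App {
  class RefScan(val schemaDDL: String)   // default equals: reference identity
  case class ValScan(schemaDDL: String)  // case class: structural (value) equality

  // A tiny stand-in for a plan node that embeds a scan plus already-normalized output.
  case class Rel[S](scan: S, normalizedOutput: Seq[Int])

  // With reference equality on the scan, two plans over the same table never compare equal,
  // even after every other field has been canonicalized.
  val a = Rel(new RefScan("i INT, j INT"), Seq(0, 1))
  val b = Rel(new RefScan("i INT, j INT"), Seq(0, 1))
  assert(a != b)

  // With value-based equality (here keyed on the read schema, as in SimpleScanBuilder above),
  // the canonicalized plans can match, which both SPARK-54163 tests rely on.
  val c = Rel(ValScan("i INT, j INT"), Seq(0, 1))
  val d = Rel(ValScan("i INT, j INT"), Seq(0, 1))
  assert(c == d)
}
```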
@@ -1157,18 +1275,6 @@ class SimpleDataSourceV2 extends TestingV2Source {
     override def planInputPartitions(): Array[InputPartition] = {
       Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
     }
-
-    override def equals(obj: Any): Boolean = {
-      obj match {
-        case s: Scan =>
-          this.readSchema() == s.readSchema()
-        case _ => false
-      }
-    }
-
-    override def hashCode(): Int = {
-      this.readSchema().hashCode()
-    }
   }
 
   override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {