Commit 1012a5f

yhuang-db authored and gengliangwang committed
[SPARK-54163][SQL] Scan canonicalization for partitioning and ordering info
### What changes were proposed in this pull request?

This PR extends the existing canonicalization function for DataSourceV2ScanRelation to also normalize the keyGroupedPartitioning and ordering fields, so that canonicalization works for partition/ordering-aware data sources.

### Why are the changes needed?

To make canonicalization apply to partition/ordering-aware data sources.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53105 from yhuang-db/SPARK-54163-canonicalization-normalization.

Authored-by: yhuang-db <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
1 parent 05bc5d4 commit 1012a5f
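
For context (not part of the commit message): plan canonicalization scrubs attribute expression IDs so that two plans differing only in the IDs they minted compare equal. DataSourceV2ScanRelation previously normalized only its own output and its relation's output, so the raw exprIds embedded in keyGroupedPartitioning and ordering kept two otherwise-identical scans of a partition/ordering-aware table from canonicalizing equal. A hedged sketch of the scenario the new tests exercise, using the suite's test-only OrderAndPartitionAwareDataSource (assumed on the classpath) and an active SparkSession `spark`:

```scala
// Two scalar subqueries over the same partition/ordering-aware source. Once the
// partitioning/ordering fields are normalized, both subquery scans canonicalize
// identically, letting MergeScalarSubqueries fuse them into one shared aggregate.
val df = spark.read
  .format(classOf[OrderAndPartitionAwareDataSource].getName)
  .option("partitionKeys", "i")
  .option("orderKeys", "i,j")
  .load()
df.createOrReplaceTempView("df")

val q = spark.sql(
  "select (select max(i) from df) as max_i, (select min(i) from df) as min_i")
```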

File tree

2 files changed: +126 -14 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala

Lines changed: 7 additions & 1 deletion
@@ -184,7 +184,13 @@ case class DataSourceV2ScanRelation(
       relation = this.relation.copy(
         output = this.relation.output.map(QueryPlan.normalizeExpressions(_, this.relation.output))
       ),
-      output = this.output.map(QueryPlan.normalizeExpressions(_, this.output))
+      output = this.output.map(QueryPlan.normalizeExpressions(_, this.output)),
+      keyGroupedPartitioning = keyGroupedPartitioning.map(
+        _.map(QueryPlan.normalizeExpressions(_, output))
+      ),
+      ordering = ordering.map(
+        _.map(o => o.copy(child = QueryPlan.normalizeExpressions(o.child, output)))
+      )
     )
   }
 }
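
For reference (not part of the diff): QueryPlan.normalizeExpressions rewrites every AttributeReference found in the supplied output to carry its ordinal position in place of the minted exprId, which is what makes the fields above position-based rather than ID-based. A minimal sketch of the effect, assuming Spark's catalyst classes on the classpath:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.types.IntegerType

// Two reads of the same table mint distinct exprIds for the same column `i`.
val i1 = AttributeReference("i", IntegerType)()
val i2 = AttributeReference("i", IntegerType)()
assert(i1 != i2) // attribute equality includes the exprId

// Normalizing each attribute against its own plan's output replaces the exprId
// with the attribute's ordinal, so both sides collapse to the same form.
val n1 = QueryPlan.normalizeExpressions(i1, Seq(i1))
val n2 = QueryPlan.normalizeExpressions(i2, Seq(i2))
assert(n1 == n2)
```

The new keyGroupedPartitioning and ordering branches apply this same normalization against the scan's output, mirroring what was already done for output itself; for ordering, only the SortOrder child is rewritten, leaving direction and null ordering intact.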

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala

Lines changed: 119 additions & 13 deletions
@@ -21,6 +21,8 @@ import java.io.File
 import java.util
 import java.util.OptionalLong

+import scala.jdk.CollectionConverters._
+
 import test.org.apache.spark.sql.connector._

 import org.apache.spark.SparkUnsupportedOperationException
@@ -37,7 +39,7 @@ import org.apache.spark.sql.connector.read.Scan.ColumnarSupportMode
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.SortExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation, V2ScanPartitioningAndOrdering}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
 import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -1008,6 +1010,52 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
       "Canonicalized DataSourceV2ScanRelation instances should be equal")
   }

+  test("SPARK-54163: scan canonicalization for partitioning and ordering aware data source") {
+    val options = new CaseInsensitiveStringMap(Map(
+      "partitionKeys" -> "i",
+      "orderKeys" -> "i,j"
+    ).asJava)
+    val table = new OrderAndPartitionAwareDataSource().getTable(options)
+
+    def createDsv2ScanRelation(): DataSourceV2ScanRelation = {
+      val relation = DataSourceV2Relation.create(table, None, None, options)
+      val scan = relation.table.asReadable.newScanBuilder(relation.options).build()
+      val scanRelation = DataSourceV2ScanRelation(relation, scan, relation.output)
+      // Attach partitioning and ordering information to DataSourceV2ScanRelation
+      V2ScanPartitioningAndOrdering.apply(scanRelation).asInstanceOf[DataSourceV2ScanRelation]
+    }
+
+    // Create two DataSourceV2ScanRelation instances, representing the scan of the same table
+    val scanRelation1 = createDsv2ScanRelation()
+    val scanRelation2 = createDsv2ScanRelation()
+
+    // assert scanRelations have partitioning and ordering
+    assert(scanRelation1.keyGroupedPartitioning.isDefined &&
+      scanRelation1.keyGroupedPartitioning.get.nonEmpty,
+      "DataSourceV2ScanRelation should have key grouped partitioning")
+    assert(scanRelation1.ordering.isDefined && scanRelation1.ordering.get.nonEmpty,
+      "DataSourceV2ScanRelation should have ordering")
+
+    // the two instances should not be the same, as they should have different attribute IDs
+    assert(scanRelation1 != scanRelation2,
+      "Two created DataSourceV2ScanRelation instances should not be the same")
+    assert(scanRelation1.output.map(_.exprId).toSet != scanRelation2.output.map(_.exprId).toSet,
+      "Output attributes should have different expression IDs before canonicalization")
+    assert(scanRelation1.relation.output.map(_.exprId).toSet !=
+      scanRelation2.relation.output.map(_.exprId).toSet,
+      "Relation output attributes should have different expression IDs before canonicalization")
+    assert(scanRelation1.keyGroupedPartitioning.get.flatMap(_.references.map(_.exprId)).toSet !=
+      scanRelation2.keyGroupedPartitioning.get.flatMap(_.references.map(_.exprId)).toSet,
+      "Partitioning columns should have different expression IDs before canonicalization")
+    assert(scanRelation1.ordering.get.flatMap(_.references.map(_.exprId)).toSet !=
+      scanRelation2.ordering.get.flatMap(_.references.map(_.exprId)).toSet,
+      "Ordering columns should have different expression IDs before canonicalization")
+
+    // After canonicalization, the two instances should be equal
+    assert(scanRelation1.canonicalized == scanRelation2.canonicalized,
+      "Canonicalized DataSourceV2ScanRelation instances should be equal")
+  }
+
   test("SPARK-53809: check mergeScalarSubqueries is effective for DataSourceV2ScanRelation") {
     val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load()
     df.createOrReplaceTempView("df")
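
Context note (not in the diff): canonicalized equality is exactly what QueryPlan.sameResult checks, and that predicate is what rules such as MergeScalarSubqueries and exchange reuse consult. Reusing the test's scanRelation1/scanRelation2, the final assertion could equivalently be written as:

```scala
// sameResult compares the two plans' canonicalized forms.
assert(scanRelation1.sameResult(scanRelation2))
```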
@@ -1052,6 +1100,64 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
     // Verify the query produces correct results
     checkAnswer(query, Row(9, 0))
   }
+
+  test(
+    "SPARK-54163: check mergeScalarSubqueries is effective for OrderAndPartitionAwareDataSource"
+  ) {
+    withSQLConf(SQLConf.V2_BUCKETING_ENABLED.key -> "true") {
+      val options = Map(
+        "partitionKeys" -> "i",
+        "orderKeys" -> "i,j"
+      )
+
+      // Create the OrderAndPartitionAwareDataSource DataFrame
+      val df = spark.read
+        .format(classOf[OrderAndPartitionAwareDataSource].getName)
+        .options(options)
+        .load()
+      df.createOrReplaceTempView("df")
+
+      val query = sql("select (select max(i) from df) as max_i, (select min(i) from df) as min_i")
+      val optimizedPlan = query.queryExecution.optimizedPlan
+
+      // check optimizedPlan merged scalar subqueries `select max(i), min(i) from df`
+      val sub1 = optimizedPlan.asInstanceOf[Project].projectList.head.collect {
+        case s: ScalarSubquery => s
+      }
+      val sub2 = optimizedPlan.asInstanceOf[Project].projectList(1).collect {
+        case s: ScalarSubquery => s
+      }
+
+      // Both subqueries should reference the same merged plan `select max(i), min(i) from df`
+      assert(sub1.nonEmpty && sub2.nonEmpty, "Both scalar subqueries should exist")
+      assert(sub1.head.plan == sub2.head.plan,
+        "Both subqueries should reference the same merged plan")
+
+      // Extract the aggregate from the merged plan sub1
+      val agg = sub1.head.plan.collect {
+        case a: Aggregate => a
+      }.head
+
+      // Check that the aggregate contains both max(i) and min(i)
+      val aggFunctionSet = agg.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case ae: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression =>
+            ae.aggregateFunction
+        }
+      }.toSet
+
+      assert(aggFunctionSet.size == 2, "Aggregate should contain exactly two aggregate functions")
+      assert(aggFunctionSet
+        .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Max]),
+        "Aggregate should contain max(i)")
+      assert(aggFunctionSet
+        .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Min]),
+        "Aggregate should contain min(i)")
+
+      // Verify the query produces correct results
+      checkAnswer(query, Row(4, 1))
+    }
+  }
 }

 case class RangeInputPartition(start: Int, end: Int) extends InputPartition
@@ -1093,6 +1199,18 @@ abstract class SimpleScanBuilder extends ScanBuilder
   override def readSchema(): StructType = TestingV2Source.schema

   override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case s: Scan =>
+        this.readSchema() == s.readSchema()
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = {
+    this.readSchema().hashCode()
+  }
 }

 trait TestingV2Source extends TableProvider {
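
Design note on the hunk above: DataSourceV2ScanRelation keeps its scan field through canonicalization, so canonicalized plan equality also requires the underlying Scan instances to compare equal. Hoisting the schema-based equals/hashCode from SimpleDataSourceV2's private builder (removed in the next hunk) into the shared SimpleScanBuilder base presumably extends that value semantics to every test scan builder that inherits it, including the one behind OrderAndPartitionAwareDataSource, which the new tests rely on.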
@@ -1157,18 +1275,6 @@ class SimpleDataSourceV2 extends TestingV2Source {
     override def planInputPartitions(): Array[InputPartition] = {
       Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
     }
-
-    override def equals(obj: Any): Boolean = {
-      obj match {
-        case s: Scan =>
-          this.readSchema() == s.readSchema()
-        case _ => false
-      }
-    }
-
-    override def hashCode(): Int = {
-      this.readSchema().hashCode()
-    }
   }

   override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
