Skip to content

Commit 13fea4f

Browse files
chirag-s-db authored and sunchao committed
[SPARK-54383][SQL] Add precomputed schema variant for InternalRowComparableWrapper util
### What changes were proposed in this pull request?

The InternalRowComparableWrapper util is often used on a very hot path for physical planning (most often, to compare partition values for key-grouped partitioned scans). While the current implementation caches the schema and ordering that each instance needs, every new instance of this object still performs a cache lookup, and this cache lookup itself can become a bottleneck for planning when there are large numbers of partitions. This PR adds a new InternalRowComparableWrapper factory for this util that has a precomputed schema and ordering that can be shared across multiple objects, removing the per-instance schema/ordering cache lookup from the hot path for physical planning.

### Why are the changes needed?

Removes a physical planning bottleneck.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This change should not change any behavior (existing tests should suffice). This PR also includes changes to the `InternalRowComparableWrapperBenchmark` to use these new utils.

Results before change:
```
[info] internal row comparable wrapper:          Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] toSet                                                74             76           2          2.7         367.5       1.0X
[info] mergePartitions                                     136            143          11          1.5         680.0       0.5X
[success] Total time: 11 s, completed Nov 17, 2025, 2:29:22 PM
```

Results after change:
```
[info] internal row comparable wrapper:          Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] toSet                                                13             13           1         15.9          62.9       1.0X
[info] mergePartitions                                      17             17           1         11.8          84.7       0.7X
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53097 from chirag-s-db/birc.

Authored-by: Chirag Singh <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
1 parent 1012a5f commit 13fea4f

File tree

5 files changed

+94
-42
lines changed

5 files changed

+94
-42
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -428,8 +428,11 @@ case class KeyGroupedPartitioning(
428428
}
429429

430430
lazy val uniquePartitionValues: Seq[InternalRow] = {
431+
val internalRowComparableFactory =
432+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
433+
expressions.map(_.dataType))
431434
partitionValues
432-
.map(InternalRowComparableWrapper(_, expressions))
435+
.map(internalRowComparableFactory)
433436
.distinct
434437
.map(_.row)
435438
}
@@ -448,11 +451,14 @@ object KeyGroupedPartitioning {
448451
val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
449452
val projectedOriginalPartitionValues =
450453
originalPartitionValues.map(project(expressions, projectionPositions, _))
454+
val internalRowComparableFactory =
455+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
456+
projectedExpressions.map(_.dataType))
451457

452458
val finalPartitionValues = projectedPartitionValues
453-
.map(InternalRowComparableWrapper(_, projectedExpressions))
454-
.distinct
455-
.map(_.row)
459+
.map(internalRowComparableFactory)
460+
.distinct
461+
.map(_.row)
456462

457463
KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
458464
finalPartitionValues, projectedOriginalPartitionValues)
@@ -867,12 +873,14 @@ case class KeyGroupedShuffleSpec(
867873
// transform functions.
868874
// 4. the partition values from both sides are following the same order.
869875
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
876+
lazy val internalRowComparableFactory =
877+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
878+
partitioning.expressions.map(_.dataType))
870879
distribution.clustering.length == otherDistribution.clustering.length &&
871880
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
872881
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
873882
case (left, right) =>
874-
InternalRowComparableWrapper(left, partitioning.expressions)
875-
.equals(InternalRowComparableWrapper(right, partitioning.expressions))
883+
internalRowComparableFactory(left).equals(internalRowComparableFactory(right))
876884
}
877885
case ShuffleSpecCollection(specs) =>
878886
specs.exists(isCompatibleWith)
@@ -957,15 +965,16 @@ case class KeyGroupedShuffleSpec(
957965
object KeyGroupedShuffleSpec {
958966
def reducePartitionValue(
959967
row: InternalRow,
960-
expressions: Seq[Expression],
961-
reducers: Seq[Option[Reducer[_, _]]]):
962-
InternalRowComparableWrapper = {
963-
val partitionVals = row.toSeq(expressions.map(_.dataType))
968+
reducers: Seq[Option[Reducer[_, _]]],
969+
dataTypes: Seq[DataType],
970+
internalRowComparableWrapperFactory: InternalRow => InternalRowComparableWrapper
971+
): InternalRowComparableWrapper = {
972+
val partitionVals = row.toSeq(dataTypes)
964973
val reducedRow = partitionVals.zip(reducers).map{
965974
case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
966975
case (v, _) => v
967976
}.toArray
968-
InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions)
977+
internalRowComparableWrapperFactory(new GenericInternalRow(reducedRow))
969978
}
970979
}
971980

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala

Lines changed: 29 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util
2020
import scala.collection.mutable
2121

2222
import org.apache.spark.sql.catalyst.InternalRow
23-
import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunction, RowOrdering}
23+
import org.apache.spark.sql.catalyst.expressions.{BaseOrdering, Expression, Murmur3HashFunction, RowOrdering}
2424
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition}
2525
import org.apache.spark.sql.types.{DataType, StructField, StructType}
2626
import org.apache.spark.util.NonFateSharingCache
@@ -33,11 +33,23 @@ import org.apache.spark.util.NonFateSharingCache
3333
*
3434
* @param dataTypes the data types for the row
3535
*/
36-
class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[DataType]) {
37-
import InternalRowComparableWrapper._
36+
class InternalRowComparableWrapper private (
37+
val row: InternalRow,
38+
val dataTypes: Seq[DataType],
39+
val structType: StructType,
40+
val ordering: BaseOrdering) {
3841

39-
private val structType = structTypeCache.get(dataTypes)
40-
private val ordering = orderingCache.get(dataTypes)
42+
/**
43+
* Previous constructor for binary compatibility. Prefer using
44+
* `getInternalRowComparableWrapperFactory` for the creation of InternalRowComparableWrapper's in
45+
* hot paths to avoid excessive cache lookups.
46+
*/
47+
@deprecated
48+
def this(row: InternalRow, dataTypes: Seq[DataType]) = this(
49+
row,
50+
dataTypes,
51+
InternalRowComparableWrapper.structTypeCache.get(dataTypes),
52+
InternalRowComparableWrapper.orderingCache.get(dataTypes))
4153

4254
override def hashCode(): Int = Murmur3HashFunction.hash(
4355
row,
@@ -96,12 +108,14 @@ object InternalRowComparableWrapper {
96108
intersect: Boolean = false): Seq[InternalRowComparableWrapper] = {
97109
val partitionDataTypes = partitionExpression.map(_.dataType)
98110
val leftPartitionSet = new mutable.HashSet[InternalRowComparableWrapper]
111+
val internalRowComparableWrapperFactory =
112+
getInternalRowComparableWrapperFactory(partitionDataTypes)
99113
leftPartitioning
100-
.map(new InternalRowComparableWrapper(_, partitionDataTypes))
114+
.map(internalRowComparableWrapperFactory)
101115
.foreach(partition => leftPartitionSet.add(partition))
102116
val rightPartitionSet = new mutable.HashSet[InternalRowComparableWrapper]
103117
rightPartitioning
104-
.map(new InternalRowComparableWrapper(_, partitionDataTypes))
118+
.map(internalRowComparableWrapperFactory)
105119
.foreach(partition => rightPartitionSet.add(partition))
106120

107121
val result = if (intersect) {
@@ -111,4 +125,12 @@ object InternalRowComparableWrapper {
111125
}
112126
result.toSeq
113127
}
128+
129+
/** Creates a shared factory method for a given row schema to avoid excessive cache lookups. */
130+
def getInternalRowComparableWrapperFactory(
131+
dataTypes: Seq[DataType]): InternalRow => InternalRowComparableWrapper = {
132+
val structType = structTypeCache.get(dataTypes)
133+
val ordering = orderingCache.get(dataTypes)
134+
row: InternalRow => new InternalRowComparableWrapper(row, dataTypes, structType, ordering)
135+
}
114136
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,8 +48,11 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase {
4848
val benchmark = new Benchmark("internal row comparable wrapper", partitionNum, output = output)
4949

5050
benchmark.addCase("toSet") { _ =>
51+
val internalRowComparableWrapperFactory =
52+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
53+
Seq(IntegerType, IntegerType))
5154
val distinct = partitions
52-
.map(new InternalRowComparableWrapper(_, Seq(IntegerType, IntegerType)))
55+
.map(internalRowComparableWrapperFactory)
5356
.toSet
5457
assert(distinct.size == bucketNum)
5558
}

sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala

Lines changed: 22 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -49,10 +49,14 @@ trait KeyGroupedPartitionedScan[T] {
4949
}
5050
case None =>
5151
spjParams.joinKeyPositions match {
52-
case Some(projectionPositions) => basePartitioning.partitionValues.map { r =>
52+
case Some(projectionPositions) =>
53+
val internalRowComparableWrapperFactory =
54+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
55+
expressions.map(_.dataType))
56+
basePartitioning.partitionValues.map { r =>
5357
val projectedRow = KeyGroupedPartitioning.project(expressions,
5458
projectionPositions, r)
55-
InternalRowComparableWrapper(projectedRow, expressions)
59+
internalRowComparableWrapperFactory(projectedRow)
5660
}.distinct.map(_.row)
5761
case _ => basePartitioning.partitionValues
5862
}
@@ -83,11 +87,14 @@ trait KeyGroupedPartitionedScan[T] {
8387
val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match {
8488
case Some(projectPositions) =>
8589
val projectedExpressions = projectPositions.map(i => expressions(i))
90+
val internalRowComparableWrapperFactory =
91+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
92+
projectedExpressions.map(_.dataType))
8693
val parts = filteredPartitions.flatten.groupBy(part => {
8794
val row = partitionValueAccessor(part)
8895
val projectedRow = KeyGroupedPartitioning.project(
8996
expressions, projectPositions, row)
90-
InternalRowComparableWrapper(projectedRow, projectedExpressions)
97+
internalRowComparableWrapperFactory(projectedRow)
9198
}).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq
9299
(parts, projectedExpressions)
93100
case _ =>
@@ -99,10 +106,14 @@ trait KeyGroupedPartitionedScan[T] {
99106
}
100107

101108
// Also re-group the partitions if we are reducing compatible partition expressions
109+
val partitionDataTypes = partExpressions.map(_.dataType)
110+
val internalRowComparableWrapperFactory =
111+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(partitionDataTypes)
102112
val finalGroupedPartitions = spjParams.reducers match {
103113
case Some(reducers) =>
104114
val result = groupedPartitions.groupBy { case (row, _) =>
105-
KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
115+
KeyGroupedShuffleSpec.reducePartitionValue(
116+
row, reducers, partitionDataTypes, internalRowComparableWrapperFactory)
106117
}.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq
107118
val rowOrdering = RowOrdering.createNaturalAscendingOrdering(
108119
partExpressions.map(_.dataType))
@@ -118,17 +129,15 @@ trait KeyGroupedPartitionedScan[T] {
118129
// should contain.
119130
val commonPartValuesMap = spjParams.commonPartitionValues
120131
.get
121-
.map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
132+
.map(t => (internalRowComparableWrapperFactory(t._1), t._2))
122133
.toMap
123134
val filteredGroupedPartitions = finalGroupedPartitions.filter {
124135
case (partValues, _) =>
125-
commonPartValuesMap.keySet.contains(
126-
InternalRowComparableWrapper(partValues, partExpressions))
136+
commonPartValuesMap.keySet.contains(internalRowComparableWrapperFactory(partValues))
127137
}
128138
val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) =>
129139
// `commonPartValuesMap` should contain the part value since it's the super set.
130-
val numSplits = commonPartValuesMap
131-
.get(InternalRowComparableWrapper(partValue, partExpressions))
140+
val numSplits = commonPartValuesMap.get(internalRowComparableWrapperFactory(partValue))
132141
assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
133142
"common partition values from Spark plan")
134143

@@ -143,7 +152,7 @@ trait KeyGroupedPartitionedScan[T] {
143152
// sides of a join will have the same number of partitions & splits.
144153
splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
145154
}
146-
(InternalRowComparableWrapper(partValue, partExpressions), newSplits)
155+
(internalRowComparableWrapperFactory(partValue), newSplits)
147156
}
148157

149158
// Now fill missing partition keys with empty partitions
@@ -152,14 +161,14 @@ trait KeyGroupedPartitionedScan[T] {
152161
case (partValue, numSplits) =>
153162
// Use empty partition for those partition values that are not present.
154163
partitionMapping.getOrElse(
155-
InternalRowComparableWrapper(partValue, partExpressions),
164+
internalRowComparableWrapperFactory(partValue),
156165
Seq.fill(numSplits)(Seq.empty))
157166
}
158167
} else {
159168
// either `commonPartitionValues` is not defined, or it is defined but
160169
// `applyPartialClustering` is false.
161170
val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) =>
162-
InternalRowComparableWrapper(partValue, partExpressions) -> splits
171+
internalRowComparableWrapperFactory(partValue) -> splits
163172
}.toMap
164173

165174
// In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
@@ -168,8 +177,7 @@ trait KeyGroupedPartitionedScan[T] {
168177
// partition values here so that grouped partitions won't get duplicated.
169178
p.uniquePartitionValues.map { partValue =>
170179
// Use empty partition for those partition values that are not present
171-
partitionMapping.getOrElse(
172-
InternalRowComparableWrapper(partValue, partExpressions), Seq.empty)
180+
partitionMapping.getOrElse(internalRowComparableWrapperFactory(partValue), Seq.empty)
173181
}
174182
}
175183
}

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -579,15 +579,18 @@ case class EnsureRequirements(
579579
// In partially clustered distribution, we should use un-grouped partition values
580580
val spec = if (replicateLeftSide) rightSpec else leftSpec
581581
val partValues = spec.partitioning.originalPartitionValues
582+
val internalRowComparableWrapperFactory =
583+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
584+
partitionExprs.map(_.dataType))
582585

583586
val numExpectedPartitions = partValues
584-
.map(InternalRowComparableWrapper(_, partitionExprs))
587+
.map(internalRowComparableWrapperFactory)
585588
.groupBy(identity)
586589
.transform((_, v) => v.size)
587590

588591
mergedPartValues = mergedPartValues.map { case (partVal, numParts) =>
589592
(partVal, numExpectedPartitions.getOrElse(
590-
InternalRowComparableWrapper(partVal, partitionExprs), numParts))
593+
internalRowComparableWrapperFactory(partVal), numParts))
591594
}
592595

593596
logInfo(log"After applying partially clustered distribution, there are " +
@@ -679,9 +682,15 @@ case class EnsureRequirements(
679682
expressions: Seq[Expression],
680683
reducers: Option[Seq[Option[Reducer[_, _]]]]) = {
681684
reducers match {
682-
case Some(reducers) => partValues.map { row =>
683-
KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers)
684-
}.distinct.map(_.row)
685+
case Some(reducers) =>
686+
val partitionDataTypes = expressions.map(_.dataType)
687+
val internalRowComparableWrapperFactory =
688+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
689+
partitionDataTypes)
690+
partValues.map { row =>
691+
KeyGroupedShuffleSpec.reducePartitionValue(
692+
row, reducers, partitionDataTypes, internalRowComparableWrapperFactory)
693+
}.distinct.map(_.row)
685694
case _ => partValues
686695
}
687696
}
@@ -737,15 +746,16 @@ case class EnsureRequirements(
737746
rightPartitioning: Seq[InternalRow],
738747
partitionExpression: Seq[Expression],
739748
joinType: JoinType): Seq[InternalRow] = {
749+
val internalRowComparableWrapperFactory =
750+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
751+
partitionExpression.map(_.dataType))
740752

741753
val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) {
742754
joinType match {
743755
case Inner => InternalRowComparableWrapper.mergePartitions(
744756
leftPartitioning, rightPartitioning, partitionExpression, intersect = true)
745-
case LeftOuter => leftPartitioning.map(
746-
InternalRowComparableWrapper(_, partitionExpression))
747-
case RightOuter => rightPartitioning.map(
748-
InternalRowComparableWrapper(_, partitionExpression))
757+
case LeftOuter => leftPartitioning.map(internalRowComparableWrapperFactory)
758+
case RightOuter => rightPartitioning.map(internalRowComparableWrapperFactory)
749759
case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning,
750760
rightPartitioning, partitionExpression)
751761
}

0 commit comments

Comments
 (0)