Commits (31 total; changes shown from 30 commits)
d1fe3da
init
shujingyang-db Aug 27, 2025
7146dd8
ckp
shujingyang-db Aug 28, 2025
b3f2a94
fix
shujingyang-db Aug 28, 2025
fad6256
repartitionById ckp
shujingyang-db Aug 28, 2025
7be523b
Merge remote-tracking branch 'spark/master' into direct-partitionId-p…
shujingyang-db Aug 28, 2025
53ce88a
add more tests
shujingyang-db Aug 29, 2025
84bafd8
add todos
shujingyang-db Aug 29, 2025
6a13e3b
Update PlannerSuite.scala
shujingyang-db Aug 29, 2025
228ca21
Update PlannerSuite.scala
shujingyang-db Aug 29, 2025
599a3d6
Update PlannerSuite.scala
shujingyang-db Aug 29, 2025
643b31a
rm conf
shujingyang-db Aug 29, 2025
695278a
merge conflict
shujingyang-db Aug 29, 2025
e5f4c74
add tests
shujingyang-db Sep 1, 2025
31d4c22
rm todos
shujingyang-db Sep 1, 2025
799549a
clean up
shujingyang-db Sep 1, 2025
4ab3e4b
fix tests
shujingyang-db Sep 1, 2025
ef1fa45
DirectShufflePartitionID
shujingyang-db Sep 1, 2025
97cc15c
Update sql/core/src/test/scala/org/apache/spark/sql/execution/Planner…
shujingyang-db Sep 2, 2025
fddd5da
Update sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.s…
shujingyang-db Sep 2, 2025
e23ea46
Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expr…
shujingyang-db Sep 2, 2025
09ec99b
Update sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.s…
shujingyang-db Sep 2, 2025
5c9f681
int
shujingyang-db Sep 2, 2025
c9b6df5
Merge remote-tracking branch 'origin/direct-partitionId-pass-through'…
shujingyang-db Sep 2, 2025
5d91e0e
Update DirectShufflePartitionID.scala
shujingyang-db Sep 3, 2025
8805934
address comments
shujingyang-db Sep 3, 2025
b0add14
test case for join
shujingyang-db Sep 3, 2025
f678101
Merge remote-tracking branch 'origin/direct-partitionId-pass-through'…
shujingyang-db Sep 3, 2025
13210a1
revert
shujingyang-db Sep 3, 2025
8287ac5
revert
shujingyang-db Sep 3, 2025
1480f86
Apply suggestions from code review
cloud-fan Sep 3, 2025
276a650
Update sql/core/src/test/scala/org/apache/spark/sql/execution/Planner…
cloud-fan Sep 3, 2025
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType}

/**
* Expression that takes a partition ID value and passes it through directly for use in
* shuffle partitioning. This is used with RepartitionByExpression to allow users to
* directly specify target partition IDs.
*
* The child expression must evaluate to an integer type. Null values are routed to
* partition 0, and values outside [0, numPartitions) are wrapped back into range with a
* positive modulus (pmod).
*/
case class DirectShufflePartitionID(child: Expression)
extends UnaryExpression
with ExpectsInputTypes
with Unevaluable {

override def dataType: DataType = child.dataType

override def inputTypes: Seq[AbstractDataType] = IntegerType :: Nil

override def nullable: Boolean = false

override val prettyName: String = "direct_shuffle_partition_id"

override protected def withNewChildInternal(newChild: Expression): DirectShufflePartitionID =
copy(child = newChild)
}
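For orientation, a minimal construction sketch of the new expression (editorial, not part of the patch; the attribute name "pid" is hypothetical). Because the wrapper is Unevaluable, it is never evaluated row-by-row itself; it only marks the repartition expression so that planning produces ShufflePartitionIdPassThrough instead of HashPartitioning, and the child is evaluated at shuffle time to obtain the target partition ID.

// Editorial sketch, not part of the patch.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DirectShufflePartitionID}
import org.apache.spark.sql.types.IntegerType

val pid = AttributeReference("pid", IntegerType, nullable = false)()
val directId = DirectShufflePartitionID(pid)

// The expression reports its child's data type and is never nullable.
assert(directId.dataType == IntegerType && !directId.nullable)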

@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, ShufflePartitionIdPassThrough, SinglePartition}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -1871,19 +1871,29 @@ trait HasPartitionExpressions extends SQLConfHelper {
protected def partitioning: Partitioning = if (partitionExpressions.isEmpty) {
RoundRobinPartitioning(numPartitions)
} else {
val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder])
require(sortOrder.isEmpty || nonSortOrder.isEmpty,
s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of type " +
"`SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`, which " +
"means `HashPartitioning`. In this case we have:" +
s"""
|SortOrder: $sortOrder
|NonSortOrder: $nonSortOrder
""".stripMargin)
if (sortOrder.nonEmpty) {
RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions)
val directShuffleExprs = partitionExpressions.filter(_.isInstanceOf[DirectShufflePartitionID])
if (directShuffleExprs.nonEmpty) {
assert(directShuffleExprs.length == 1 && partitionExpressions.length == 1,
s"DirectShufflePartitionID can only be used as a single partition expression, " +
s"but found ${directShuffleExprs.length} DirectShufflePartitionID expressions " +
s"out of ${partitionExpressions.length} total expressions")
ShufflePartitionIdPassThrough(
partitionExpressions.head.asInstanceOf[DirectShufflePartitionID], numPartitions)
} else {
HashPartitioning(partitionExpressions, numPartitions)
val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder])
require(sortOrder.isEmpty || nonSortOrder.isEmpty,
s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of" +
" type `SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`," +
" which means `HashPartitioning`. In this case we have:" +
s"""
|SortOrder: $sortOrder
|NonSortOrder: $nonSortOrder
""".stripMargin)
if (sortOrder.nonEmpty) {
RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions)
} else {
HashPartitioning(partitionExpressions, numPartitions)
}
}
}
}
@@ -626,6 +626,47 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
* - Creating a partitioning that can be used to re-partition another child, so as to make it
* have a partitioning compatible with this node.
*/

/**
* Represents a partitioning where partition IDs are passed through directly from the
* DirectShufflePartitionID expression. This partitioning scheme is used when users
* want to directly control partition placement rather than using hash-based partitioning.
*
* This partitioning maps directly to the PartitionIdPassthrough RDD partitioner.
*/
case class ShufflePartitionIdPassThrough(
expr: DirectShufflePartitionID,
numPartitions: Int) extends Expression with Partitioning with Unevaluable {

// TODO(SPARK-53401): Support Shuffle Spec in Direct Partition ID Pass Through
def partitionIdExpression: Expression = Pmod(expr.child, Literal(numPartitions))

def expressions: Seq[Expression] = expr :: Nil
override def children: Seq[Expression] = expr :: Nil
override def nullable: Boolean = false
override def dataType: DataType = IntegerType

override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
required match {
// TODO(SPARK-53428): Support Direct Passthrough Partitioning in the Streaming Joins
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
val partitioningExpressions = expr.child :: Nil
if (requireAllClusterKeys) {
c.areAllClusterKeysMatched(partitioningExpressions)
} else {
partitioningExpressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
case _ => false
}
}
}

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): ShufflePartitionIdPassThrough =
copy(expr = newChildren.head.asInstanceOf[DirectShufflePartitionID])
}

trait ShuffleSpec {
/**
* Returns the number of partitions of this shuffle spec
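As a quick illustration of the mapping implied by partitionIdExpression above (Pmod of the child value and numPartitions), here is a plain-Scala sketch with the same arithmetic; it is editorial, not Spark code. Null partition IDs are handled separately at shuffle time: InternalRow#getInt returns 0 for null, so those rows land in partition 0 (see the ShuffleExchangeExec change and the PlannerSuite tests below).

// Editorial sketch of the pmod-based wrapping performed by partitionIdExpression.
def targetPartition(id: Int, numPartitions: Int): Int = {
  val m = id % numPartitions
  if (m < 0) m + numPartitions else m // pmod: result is always in [0, numPartitions)
}

assert(targetPartition(7, 10) == 7)   // in-range IDs pass through unchanged
assert(targetPartition(-5, 10) == 5)  // negative IDs wrap around
assert(targetPartition(17, 10) == 7)  // IDs >= numPartitions wrap instead of failing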
15 changes: 15 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -1544,6 +1544,21 @@ class Dataset[T] private[sql](
}
}

/**
* Repartitions the Dataset into the given number of partitions using the specified
* partition ID expression.
*
* @param numPartitions the number of partitions to use.
* @param partitionIdExpr the expression to be used as the partition ID. Must be an integer type.
*
* @group typedrel
* @since 4.1.0
*/
def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T] = {
val directShufflePartitionIdCol = Column(DirectShufflePartitionID(partitionIdExpr.expr))
repartitionByExpression(Some(numPartitions), Seq(directShufflePartitionIdCol))
}

protected def repartitionByRange(
numPartitions: Option[Int],
partitionExprs: Seq[Column]): Dataset[T] = {
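A usage sketch for the new repartitionById API (editorial, not part of the patch), assuming a SparkSession named spark; it mirrors the behaviour exercised in PlannerSuite further down.

import org.apache.spark.sql.functions.{col, spark_partition_id}

// Each row names its own target partition in the "pid" column.
val df = spark.range(100).withColumn("pid", (col("id") % 10).cast("int"))
val byId = df.repartitionById(10, col("pid"))

// A single shuffle places every row in exactly the partition it asked for.
assert(byId.rdd.getNumPartitions == 10)
assert(byId.withColumn("actual", spark_partition_id())
  .filter(col("actual") =!= col("pid")).count() == 0)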
@@ -344,6 +344,10 @@ object ShuffleExchangeExec {
// For HashPartitioning, the partitioning key is already a valid partition ID, as we use
// `HashPartitioning.partitionIdExpression` to produce the partitioning key.
new PartitionIdPassthrough(n)
case ShufflePartitionIdPassThrough(_, n) =>
// For ShufflePartitionIdPassThrough, the DirectShufflePartitionID expression already
// produces partition IDs, so we use PartitionIdPassthrough to forward them unchanged.
new PartitionIdPassthrough(n)
case RangePartitioning(sortingExpressions, numPartitions) =>
// Extract only fields used for sorting to avoid collecting large fields that do not
// affect the sorting result when deciding partition bounds in RangePartitioner
@@ -399,6 +403,11 @@ object ShuffleExchangeExec {
case SinglePartition => identity
case KeyGroupedPartitioning(expressions, _, _, _) =>
row => bindReferences(expressions, outputAttributes).map(_.eval(row))
case s: ShufflePartitionIdPassThrough =>
// For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID
// If the value is null, `InternalRow#getInt` returns 0.
val projection = UnsafeProjection.create(s.partitionIdExpression :: Nil, outputAttributes)
row => projection(row).getInt(0)
case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
}

@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -28,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
import org.apache.spark.sql.functions._
@@ -1406,6 +1407,182 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
assert(planned.exists(_.isInstanceOf[GlobalLimitExec]))
assert(planned.exists(_.isInstanceOf[LocalLimitExec]))
}

test("SPARK-53401: repartitionById - should partition rows to the specified partition ID") {
val numPartitions = 10
val df = spark.range(100).withColumn("expected_p_id", col("id") % numPartitions)

val repartitioned = df.repartitionById(numPartitions, $"expected_p_id".cast("int"))
val result = repartitioned.withColumn("actual_p_id", spark_partition_id())

assert(result.filter(col("expected_p_id") =!= col("actual_p_id")).count() == 0)

assert(result.rdd.getNumPartitions == numPartitions)
}

test("SPARK-53401: repartitionById should handle negative partition ids correctly with pmod") {
val df = spark.range(10).toDF("id")
val repartitioned = df.repartitionById(10, ($"id" - 5).cast("int"))

// With pmod, negative values should be converted to positive values
// (-5) pmod 10 = 5, (-4) pmod 10 = 6
val result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()

assert(result.forall(row => {
val actualPartitionId = row.getAs[Int]("actual_p_id")
val id = row.getAs[Long]("id")
val expectedPartitionId = {
val mod = (id - 5) % 10
if (mod < 0) mod + 10 else mod
}
actualPartitionId == expectedPartitionId
}))
}

test("SPARK-53401: repartitionById should fail analysis for non-integral types") {
val df = spark.range(5).withColumn("s", lit("a"))
checkError(
exception = intercept[AnalysisException] {
df.repartitionById(5, $"s").collect()
},
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
parameters = Map(
"sqlExpr" -> "\"direct_shuffle_partition_id(s)\"",
"requiredType" -> "\"INT\"",
"actualType" -> "\"STRING\"",
"inputExpr" -> "\"s\""
)
)
}

test("SPARK-53401: repartitionById should send null partition ids to partition 0") {
val df = spark.range(10).toDF("id")
val partitionExpr = when($"id" < 5, $"id").otherwise(lit(null)).cast("int")
val repartitioned = df.repartitionById(10, partitionExpr)

val result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()

val nullRows = result.filter(_.getAs[Long]("id") >= 5)
assert(nullRows.nonEmpty, "Should have rows with null partition expression")
assert(nullRows.forall(_.getAs[Int]("actual_p_id") == 0),
"All null partition id rows should go to partition 0")

val nonNullRows = result.filter(_.getAs[Long]("id") < 5)
nonNullRows.foreach { row =>
val id = row.getAs[Long]("id").toInt
val actualPartitionId = row.getAs[Int]("actual_p_id")
assert(actualPartitionId == id % 10,
s"Row with id=$id should be in partition ${id % 10}, " +
s"but was in partition $actualPartitionId")
}
}

test("SPARK-53401: repartitionById should not" +
" throw an exception for partition id >= numPartitions") {
val numPartitions = 10
val df = spark.range(20).toDF("id")
val repartitioned = df.repartitionById(numPartitions, $"id".cast("int"))

assert(repartitioned.collect().length == 20)
assert(repartitioned.rdd.getNumPartitions == numPartitions)
}

/**
* A helper function to check the number of exchanges (shuffle or broadcast) in a physical plan.
*
* @param df The DataFrame whose physical plan will be examined.
* @param expectedShuffles The expected number of exchanges.
*/
private def checkShuffleCount(df: DataFrame, expectedShuffles: Int): Unit = {
val plan = df.queryExecution.executedPlan
val shuffles = collect(plan) {
case s: ShuffleExchangeLike => s
case s: BroadcastExchangeLike => s
}
assert(
shuffles.size == expectedShuffles,
s"Expected $expectedShuffles shuffle(s), but found ${shuffles.size} in the plan:\n$plan"
)
}

test("SPARK-53401: repartitionById followed by groupBy should only have one shuffle") {
val df = spark.range(100)
.withColumn("id", col("id").cast("int"))
.toDF("id")
val repartitioned = df.repartitionById(10, $"id")
val grouped = repartitioned.groupBy($"id").count()

checkShuffleCount(grouped, 1)
}

test("SPARK-53401: groupBy on a superset of partition keys should reuse the shuffle") {
val df = spark.range(100)
.withColumn("id", col("id").cast("int"))
.select($"id" % 10 as "key1", $"id" as "value")
val grouped = df.repartitionById(10, $"key1").groupBy($"key1", lit(1)).count()
checkShuffleCount(grouped, 1)
}

test("SPARK-53401: shuffle reuse is not affected by spark.sql.shuffle.partitions") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
val df = spark.range(100)
.withColumn("id", col("id").cast("int"))
.select($"id" % 10 as "key", $"id" as "value")
val grouped = df.repartitionById(10, $"key").groupBy($"key").count()

checkShuffleCount(grouped, 1)
assert(grouped.rdd.getNumPartitions == 10)
}
}

test("SPARK-53401: join with id pass-through and hash partitioning requires shuffle") {
val df1 = spark.range(100)
.withColumn("id", col("id").cast("int"))
.select($"id" % 10 as "key", $"id" as "v1")
.repartitionById(10, $"key")

val df2 = spark.range(100)
.withColumn("id", col("id").cast("int"))
.select($"id" % 10 as "key", $"id" as "v2")
.repartition($"key")

val joined1 = df1.join(df2, "key")

val grouped = joined1.groupBy("key").count()

// Total shuffles: one for df1, one broadcast for df2, and one for the groupBy.
// The groupBy reuses the output partitioning produced by DirectShufflePartitionID.
checkShuffleCount(grouped, 3)

val joined2 = df2.join(df1, "key")

val grouped2 = joined2.groupBy("key").count()

checkShuffleCount(grouped2, 3)
}

test("SPARK-53401: shuffle reuse after a join doesn't preserve partitioning") {
val df1 =
spark
.range(100)
.withColumn("id", col("id").cast("int"))
.select($"id" % 10 as "key", $"id" as "v1")
.repartitionById(10, $"key")
val df2 =
spark
.range(100)
.withColumn("id", col("id").cast("int"))
.select($"id" % 10 as "key", $"id" as "v2")
.repartitionById(10, $"key")

val joined = df1.join(df2, "key")

val grouped = joined.groupBy("key").count()

// Total shuffles: one for df1, one for df2, and one for the groupBy.
// The groupBy reuses the output partitioning produced by DirectShufflePartitionID.
checkShuffleCount(grouped, 3)
}
}

// Used for unit-testing EnsureRequirements