Commits (31)
d1fe3da: init (shujingyang-db, Aug 27, 2025)
7146dd8: ckp (shujingyang-db, Aug 28, 2025)
b3f2a94: fix (shujingyang-db, Aug 28, 2025)
fad6256: repartitionById ckp (shujingyang-db, Aug 28, 2025)
7be523b: Merge remote-tracking branch 'spark/master' into direct-partitionId-p… (shujingyang-db, Aug 28, 2025)
53ce88a: add more tests (shujingyang-db, Aug 29, 2025)
84bafd8: add todos (shujingyang-db, Aug 29, 2025)
6a13e3b: Update PlannerSuite.scala (shujingyang-db, Aug 29, 2025)
228ca21: Update PlannerSuite.scala (shujingyang-db, Aug 29, 2025)
599a3d6: Update PlannerSuite.scala (shujingyang-db, Aug 29, 2025)
643b31a: rm conf (shujingyang-db, Aug 29, 2025)
695278a: merge conflict (shujingyang-db, Aug 29, 2025)
e5f4c74: add tests (shujingyang-db, Sep 1, 2025)
31d4c22: rm todos (shujingyang-db, Sep 1, 2025)
799549a: clean up (shujingyang-db, Sep 1, 2025)
4ab3e4b: fix tests (shujingyang-db, Sep 1, 2025)
ef1fa45: DirectShufflePartitionID (shujingyang-db, Sep 1, 2025)
97cc15c: Update sql/core/src/test/scala/org/apache/spark/sql/execution/Planner… (shujingyang-db, Sep 2, 2025)
fddd5da: Update sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.s… (shujingyang-db, Sep 2, 2025)
e23ea46: Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expr… (shujingyang-db, Sep 2, 2025)
09ec99b: Update sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.s… (shujingyang-db, Sep 2, 2025)
5c9f681: int (shujingyang-db, Sep 2, 2025)
c9b6df5: Merge remote-tracking branch 'origin/direct-partitionId-pass-through'… (shujingyang-db, Sep 2, 2025)
5d91e0e: Update DirectShufflePartitionID.scala (shujingyang-db, Sep 3, 2025)
8805934: address comments (shujingyang-db, Sep 3, 2025)
b0add14: test case for join (shujingyang-db, Sep 3, 2025)
f678101: Merge remote-tracking branch 'origin/direct-partitionId-pass-through'… (shujingyang-db, Sep 3, 2025)
13210a1: revert (shujingyang-db, Sep 3, 2025)
8287ac5: revert (shujingyang-db, Sep 3, 2025)
1480f86: Apply suggestions from code review (cloud-fan, Sep 3, 2025)
276a650: Update sql/core/src/test/scala/org/apache/spark/sql/execution/Planner… (cloud-fan, Sep 3, 2025)
13 changes: 13 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2045,6 +2045,19 @@ object functions {
*/
def spark_partition_id(): Column = Column.fn("spark_partition_id")

/**
* Returns the partition ID specified by the given column expression for direct shuffle
* partitioning. The input expression must evaluate to an integral type and must not be null.
Comment (Contributor): Will this partition ID be changed by AQE?

*
* This function is used with DataFrame.repartitionByExpr to allow users to directly specify
* target partition IDs instead of using hash-based partitioning.
*
* @group misc_funcs
* @since 4.1.0
*/
def direct_shuffle_partition_id(e: Column): Column =
Column.fn("direct_shuffle_partition_id", e)

/**
* Computes the square root of the specified float value.
*
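
As a quick illustration of the new function documented above, here is a minimal usage sketch (assuming an active SparkSession named `spark`; the column name `pid` is illustrative):

import org.apache.spark.sql.functions._

// Assign each row an explicit target partition ID in [0, 8).
val df = spark.range(100).withColumn("pid", col("id") % 8)
val byId = df.repartition(8, direct_shuffle_partition_id(col("pid")))

// After the shuffle, spark_partition_id() should equal `pid` for every row.
byId.select(col("pid"), spark_partition_id().as("actual_pid")).show()

This mirrors the behavior exercised by the new DataFrameSuite tests later in this diff.
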
@@ -773,6 +773,7 @@ object FunctionRegistry {
expression[AesEncrypt]("aes_encrypt"),
expression[AesDecrypt]("aes_decrypt"),
expression[SparkPartitionID]("spark_partition_id"),
expression[DirectShufflePartitionID]("direct_shuffle_partition_id"),
expression[InputFileName]("input_file_name"),
expression[InputFileBlockStart]("input_file_block_start"),
expression[InputFileBlockLength]("input_file_block_length"),
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{AbstractDataType, DataType, LongType}

/**
* Expression that takes a partition ID value and passes it through directly for use in
* shuffle partitioning. This is used with RepartitionByExpression to allow users to
* directly specify target partition IDs.
*
* The child expression must evaluate to an integral type and must not be null.
* The resulting partition ID must be in the range [0, numPartitions).
*/
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the partition ID specified by expr for direct shuffle " +
"partitioning.",
arguments = """
Arguments:
* expr - an integral expression that specifies the target partition ID
""",
examples = """
Examples:
> df.repartition(10, direct_shuffle_partition_id($"partition_id"))
> df.repartition(10, expr("direct_shuffle_partition_id(id % 5)"))
""",
since = "4.1.0",
group = "misc_funcs"
)
case class DirectShufflePartitionID(child: Expression)
extends UnaryExpression
with ExpectsInputTypes {

override def dataType: DataType = child.dataType

override def inputTypes: Seq[AbstractDataType] = LongType :: Nil

override def nullable: Boolean = false

override val prettyName: String = "direct_shuffle_partition_id"

override def eval(input: InternalRow): Any = {
val result = child.eval(input)
if (result == null) {
throw new IllegalArgumentException(
Comment (Contributor): Let's add a user-facing error condition for it, or else we can still treat null as 0.

Reply (Contributor Author, shujingyang-db, Sep 2, 2025): Yep, we treat null as 0. I don't think we can create and test a user-facing error condition here, as this should only happen if there is an internal error.

Reply (Contributor Author): Update: removing this, as we now mark DirectShufflePartitionID as Unevaluable.

"The partition ID expression must not be null.")
}
nullSafeEval(result)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
val resultCode =
s"""
|${childGen.code}
|if (${childGen.isNull}) {
| throw new IllegalArgumentException(
| "The partition ID expression must not be null.");
|}
|""".stripMargin

ev.copy(code = code"$resultCode", isNull = FalseLiteral, value = childGen.value)
}

override protected def withNewChildInternal(newChild: Expression): DirectShufflePartitionID =
copy(child = newChild)
}
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, ShufflePartitionIdPassThrough, SinglePartition}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -1871,19 +1871,30 @@ trait HasPartitionExpressions extends SQLConfHelper {
protected def partitioning: Partitioning = if (partitionExpressions.isEmpty) {
RoundRobinPartitioning(numPartitions)
} else {
val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder])
require(sortOrder.isEmpty || nonSortOrder.isEmpty,
s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of type " +
"`SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`, which " +
"means `HashPartitioning`. In this case we have:" +
s"""
|SortOrder: $sortOrder
|NonSortOrder: $nonSortOrder
""".stripMargin)
if (sortOrder.nonEmpty) {
RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions)
// Check if we have DirectShufflePartitionID expressions
val directShuffleExprs = partitionExpressions.filter(_.isInstanceOf[DirectShufflePartitionID])
if (directShuffleExprs.nonEmpty) {
require(directShuffleExprs.length == 1 && partitionExpressions.length == 1,
s"DirectShufflePartitionID can only be used as a single partition expression, " +
s"but found ${directShuffleExprs.length} DirectShufflePartitionID expressions " +
s"out of ${partitionExpressions.length} total expressions")
ShufflePartitionIdPassThrough(
partitionExpressions.head.asInstanceOf[DirectShufflePartitionID], numPartitions)
} else {
HashPartitioning(partitionExpressions, numPartitions)
val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder])
require(sortOrder.isEmpty || nonSortOrder.isEmpty,
s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of " +
"`SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`, which " +
"means `HashPartitioning`. In this case we have:" +
s"""
|SortOrder: $sortOrder
|NonSortOrder: $nonSortOrder
""".stripMargin)
if (sortOrder.nonEmpty) {
RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions)
} else {
HashPartitioning(partitionExpressions, numPartitions)
}
}
}
}
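
To summarize the dispatch implemented above, a small illustrative sketch of which partitioning each repartition variant would select (assuming an active SparkSession `spark`; names and values are illustrative):

import org.apache.spark.sql.functions._

val n = 4
val df = spark.range(20).withColumn("pid", col("id") % n)

df.repartition(n)                                           // RoundRobinPartitioning(n)
df.repartition(n, direct_shuffle_partition_id(col("pid")))  // ShufflePartitionIdPassThrough(_, n)
df.repartitionByRange(n, col("id"))                         // RangePartitioning(_, n)
df.repartition(n, col("id"))                                // HashPartitioning(_, n)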
@@ -946,3 +946,24 @@ case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
specs.head.numPartitions
}
}

/**
* Represents a partitioning where partition IDs are passed through directly from the
* DirectShufflePartitionID expression. This partitioning scheme is used when users
* want to directly control partition placement rather than using hash-based partitioning.
*
* This partitioning maps directly to the PartitionIdPassthrough RDD partitioner.
*/
case class ShufflePartitionIdPassThrough(
Comment: Could creating this on a column with high cardinality lead to a sudden increase in partitions? Will subsequent AQE rules try to act and reduce the number of partitions?

Reply (Member): Nope, it will not reuse or remove shuffles. This is more about replacing the RDD Partitioner API so that people can completely migrate to the DataFrame API. In terms of performance and efficiency, it won't be super useful.

expr: DirectShufflePartitionID,
numPartitions: Int) extends HashPartitioningLike {

// TODO(SPARK-53401): Support Shuffle Spec in Direct Partition ID Pass Through
def partitionIdExpression: Expression = Pmod(expr, Literal(numPartitions))

override def expressions: Seq[Expression] = expr :: Nil

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): ShufflePartitionIdPassThrough =
copy(expr = newChildren.head.asInstanceOf[DirectShufflePartitionID])
}
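
As context for the Member's reply above, a rough sketch of the RDD-era pattern this partitioning is intended to subsume, next to the DataFrame form (illustrative only; `IdPartitioner` is a hypothetical custom partitioner, and `spark` is an active SparkSession):

import org.apache.spark.Partitioner
import org.apache.spark.sql.functions._

// RDD API: placement is controlled by supplying a custom Partitioner.
class IdPartitioner(override val numPartitions: Int) extends Partitioner {
  override def getPartition(key: Any): Int = (key.asInstanceOf[Long] % numPartitions).toInt
}
val pairs = spark.sparkContext.parallelize(0L until 100L).map(id => (id, id))
val byRddPartitioner = pairs.partitionBy(new IdPartitioner(8))

// DataFrame API with this change: no custom Partitioner class is needed.
val byDirectId = spark.range(100).repartition(8, direct_shuffle_partition_id(col("id") % 8))
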
@@ -344,6 +344,10 @@ object ShuffleExchangeExec {
// For HashPartitioning, the partitioning key is already a valid partition ID, as we use
// `HashPartitioning.partitionIdExpression` to produce partitioning key.
new PartitionIdPassthrough(n)
case ShufflePartitionIdPassThrough(_, n) =>
// For ShufflePartitionIdPassThrough, the DirectShufflePartitionID expression directly
// produces partition IDs, so we use PartitionIdPassthrough to pass them through directly.
new PartitionIdPassthrough(n)
case RangePartitioning(sortingExpressions, numPartitions) =>
// Extract only fields used for sorting to avoid collecting large fields that does not
// affect sorting result when deciding partition bounds in RangePartitioner
@@ -399,6 +403,10 @@
case SinglePartition => identity
case KeyGroupedPartitioning(expressions, _, _, _) =>
row => bindReferences(expressions, outputAttributes).map(_.eval(row))
case s: ShufflePartitionIdPassThrough =>
// For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID
val projection = UnsafeProjection.create(s.expressions, outputAttributes)
row => projection(row).getInt(0)
case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
}

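
For reference, `PartitionIdPassthrough` reused above is Spark's trivial partitioner that treats the shuffle key itself as the partition ID; conceptually it behaves like this simplified sketch (not the actual Spark source):

import org.apache.spark.Partitioner

// Simplified sketch: upstream code already produced a valid partition ID as the key,
// so the partitioner returns it unchanged.
class PassThroughSketch(override val numPartitions: Int) extends Partitioner {
  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

This is why the new ShufflePartitionIdPassThrough case simply reuses PartitionIdPassthrough: its key extractor already emits the partition ID itself.
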
45 changes: 45 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2785,6 +2785,51 @@ class DataFrameSuite extends QueryTest
val df1 = df.select("a").orderBy("b").orderBy("all")
checkAnswer(df1, Seq(Row(1), Row(4)))
}

test("SPARK-53401: direct_shuffle_partition_id - should partition rows to the specified " +
"partition ID") {
val numPartitions = 10
val df = spark.range(100).withColumn("p_id", col("id") % numPartitions)

val repartitioned = df.repartition(numPartitions, direct_shuffle_partition_id($"p_id"))
val result = repartitioned.withColumn("actual_p_id", spark_partition_id())

assert(result.filter(col("p_id") =!= col("actual_p_id")).count() == 0)

assert(result.rdd.getNumPartitions == numPartitions)
}

test("SPARK-53401: direct_shuffle_partition_id - should work with expr()") {
val numPartitions = 5
val df = spark.range(50).withColumn("p_id", col("id") % numPartitions)

val repartitioned = df.repartition(numPartitions, expr("direct_shuffle_partition_id(p_id)"))
val result = repartitioned.withColumn("actual_p_id", spark_partition_id())

assert(result.filter(col("p_id") =!= col("actual_p_id")).count() == 0)
}

test("SPARK-53401: direct_shuffle_partition_id - should fail when partition ID is null") {
val df = spark.range(10).withColumn("p_id",
when(col("id") < 5, col("id")).otherwise(lit(null).cast("long"))
)
val repartitioned = df.repartition(5, direct_shuffle_partition_id($"p_id"))

val e = intercept[SparkException] {
repartitioned.collect()
}
assert(e.getCause.isInstanceOf[IllegalArgumentException])
Comment (Contributor): What's the actual error? If the error message is not clear, we should do an explicit null check, or simply treat null as partition ID 0.

assert(e.getCause.getMessage.contains("The partition ID expression must not be null."))
}

test("SPARK-53401: direct_shuffle_partition_id - should fail analysis for non-integral types") {
val df = spark.range(5).withColumn("s", lit("a"))
val e = intercept[AnalysisException] {
df.repartition(5, direct_shuffle_partition_id($"s")).collect()
}
// Should fail with type error from DirectShufflePartitionID expression
assert(e.getMessage.contains("requires an integral type"))
Comment (Contributor): Where do we throw this error now?

Reply (Contributor Author): In Pmod. The full error message is:

org.scalatest.exceptions.TestFailedException: "Job aborted due to stage failure: Task 1 in stage 9.0 failed 1 times, most recent failure: Lost task 1.0 in stage 9.0 (TID 20) (192.168.1.72 executor driver): java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 38, Column 71: Failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 38, Column 71: Binary numeric promotion not possible on types "org.apache.spark.unsafe.types.UTF8String" and "int"
	at com.google.common.util.concurrent.AbstractFuture.getDoneValue(AbstractFuture.java:604)
	at com.google.common.util.concurrent.AbstractFuture.get(AbstractFuture.java:559)
	at com.google.common.util.concurrent.AbstractFuture$TrustedFuture.get(AbstractFuture.java:114)
	at com.google.common.util.concurrent.Uninterruptibles.getUninterruptibly(Uninterruptibles.java:247)
	at com.google.common.cache.LocalCache$Segment.getAndRecordStats(LocalCache.java:2349)
	at com.google.common.cache.LocalCache$Segment.loadSync(LocalCache.java:2317)
	at com.google.common.cache.LocalCache$Segment.lockedGetOrLoad(LocalCache.java:2190)
	at com.google.common.cache.LocalCache$Segment.get(LocalCache.java:2080)
	at com.google.common.cache.LocalCache.get(LocalCache.java:4017)
	at com.google.common.cache.LocalCache.getOrLoad(LocalCache.java:4040)
	at com.google.common.cache.LocalCache$LocalLoadingCache.get(LocalCache.java:4989)
	at org.apache.spark.util.NonFateSharingLoadingCache.$anonfun$get$2(NonFateSharingCache.scala:108)
	at org.apache.spark.util.KeyLock.withLock(KeyLock.scala:64)
	at org.apache.spark.util.NonFateSharingLoadingCache.get(NonFateSharingCache.scala:108)
	at org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator$.compile(CodeGenerator.scala:1490)
	at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.create(GenerateUnsafeProjection.scala:378)
	at org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection$.generate(GenerateUnsafeProjection.scala:327)
	at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.createCodeGeneratedObject(Projection.scala:125)
	at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.createCodeGeneratedObject(Projection.scala:121)
	at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:45)
	at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:152)
	at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:162)
	at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.getPartitionKeyExtractor$1(ShuffleExchangeExec.scala:408)

Reply (Contributor): Oh, it's at execution time. We should fail earlier, at analysis time, e.g. in CheckAnalysis.

Reply (Contributor Author): I added DirectShufflePartitionID back, and it now throws the error at analysis time :)

[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] ······ The first parameter requires the "BIGINT" type, however "s" has the type "STRING"

}
}

case class GroupByKey(a: Int, b: Int)