
Conversation

@shujingyang-db (Contributor) commented Aug 28, 2025

What changes were proposed in this pull request?

Currently, Spark's DataFrame repartition() API only supports hash-based and range-based partitioning strategies. Users who need precise control over which partition each row goes to (similar to RDD's partitionBy with custom partitioners) have no direct way to achieve this at the DataFrame level.

This PR introduces a new DataFrame API, repartitionById(col, numPartitions), which allows users to directly specify the target partition ID for each row when repartitioning a DataFrame:

// Partition rows based on a computed partition ID
val df = spark.range(100).withColumn("partition_id", col("id") % 10)
val repartitioned = df.repartitionById($"partition_id", 10)

Why are the changes needed?

This enables precise control over which partition each row goes to (similar to RDD's partitionBy with custom partitioners) at the DataFrame level.
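
For context, the closest RDD-level equivalent users have to write today is a custom Partitioner (a minimal sketch using the public Partitioner API; the IdPartitioner name is illustrative):

import org.apache.spark.Partitioner

// A partitioner whose key already is the target partition id.
class IdPartitioner(override val numPartitions: Int) extends Partitioner {
  override def getPartition(key: Any): Int = key.asInstanceOf[Long].toInt
}

val rdd = spark.sparkContext.range(0, 100)
  .map(id => (id % 10, id))            // (target partition id, value)
  .partitionBy(new IdPartitioner(10))  // RDD-level direct passthrough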

Does this PR introduce any user-facing change?

Yes.

// Partition rows based on a computed partition ID
val df = spark.range(100).withColumn("partition_id", col("id") % 10)
val repartitioned = df.repartitionById($"partition_id", 10)

How was this patch tested?

New unit tests.

Was this patch authored or co-authored using generative AI tooling?

No

@github-actions github-actions bot added the SQL label Aug 28, 2025
@HyukjinKwon HyukjinKwon changed the title [DRAFT][ SPARK-53401] Enable Direct Passthrough Partitioning in the DataFrame API [DRAFT][SPARK-53401] Enable Direct Passthrough Partitioning in the DataFrame API Aug 28, 2025
@@ -2045,6 +2045,19 @@ object functions {
*/
def spark_partition_id(): Column = Column.fn("spark_partition_id")

/**
* Returns the partition ID specified by the given column expression for direct shuffle
* partitioning. The input expression must evaluate to an integral type and must not be null.
Contributor:
will this partition id be changed by AQE?

@shujingyang-db shujingyang-db marked this pull request as ready for review August 28, 2025 07:07
@shujingyang-db shujingyang-db changed the title [DRAFT][SPARK-53401] Enable Direct Passthrough Partitioning in the DataFrame API [SPARK-53401] Enable Direct Passthrough Partitioning in the DataFrame API Aug 28, 2025
*
* This partitioning maps directly to the PartitionIdPassthrough RDD partitioner.
*/
case class ShufflePartitionIdPassThrough(
Reviewer:
Could creating this on a column with high cardinality lead to a sudden increase in partitions? Will subsequent AQE rules try to act and reduce the number of partitions?

Member:

Nope, it will not reuse or remove shuffles. This is more to replace RDD's Partitioner API so people can completely migrate to the DataFrame API. In terms of performance and efficiency, it won't be super useful.

* @group typedrel
* @since 4.1.0
*/
def repartitionById(partitionIdExpr: Column): Dataset[T] = {
Contributor:
I feel it's risky to provide a default numPartitions. Can we always ask users to specify numPartitions?
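
A sketch of the concern (assuming the single-argument overload falls back to spark.sql.shuffle.partitions, as repartition(col) does): the effective partition count then depends silently on session configuration.

// Assumption: numPartitions defaults to the session's shuffle partition count
spark.conf.set("spark.sql.shuffle.partitions", "200")
val a = df.repartitionById($"partition_id")  // 200 target partitions here
spark.conf.set("spark.sql.shuffle.partitions", "50")
val b = df.repartitionById($"partition_id")  // but only 50 here; id 120 wraps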

val e = intercept[SparkException] {
  repartitioned.collect()
}
assert(e.getCause.isInstanceOf[IllegalArgumentException])
Contributor:
what's the actual error? if the error message is not clear we should do explicit null check, or simply treat null as partition id 0.
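
If the API keeps rejecting nulls, a caller-side workaround is to map null ids up front (a sketch; the column names are illustrative):

import org.apache.spark.sql.functions.{coalesce, lit}

// Treat a null partition id as partition 0 before repartitioning.
val safe = df.withColumn("pid", coalesce($"partition_id", lit(0)))
val repartitioned = safe.repartitionById(10, $"pid")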

@@ -1406,6 +1406,87 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
assert(planned.exists(_.isInstanceOf[GlobalLimitExec]))
assert(planned.exists(_.isInstanceOf[LocalLimitExec]))
}

test("SPARK-53401: repartitionById should throw an exception for negative partition id") {
Contributor:
hmm, shall we use pmod then? then the partition id is always positive, see https://docs.databricks.com/aws/en/sql/language-manual/functions/pmod
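
For illustration, pmod always yields a non-negative result, unlike %, so it keeps every id inside [0, numPartitions):

import org.apache.spark.sql.functions.{col, lit, pmod}

spark.range(-5, 5)
  .select(col("id") % 3 as "mod", pmod(col("id"), lit(3)) as "pmod")
  .show()
// e.g. for id = -5: mod is -2, while pmod is 1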

assert(e.getMessage.contains("Index -5 out of bounds"))
}

test("SPARK-53401: repartitionById should throw an exception for partition id >= numPartitions") {
Contributor:
wait, how can this happen if we do mod/pmod?

val df = spark.range(100).select($"id" % 10 as "key", $"id" as "value")
val grouped =
  df.repartitionById(10, $"key")
    .filter($"value" > 50).groupBy($"key").count()
@cloud-fan (Contributor) commented Aug 29, 2025:
so what this test proves is that Filter can propagate child's output partitioning, which is already proven by other tests and we don't need to verify it again here.

checkShuffleCount(grouped, 1)
}

test("SPARK-53401: shuffle reuse after a join that preserves partitioning") {
Contributor:
ditto

Contributor:
I think a more interesting test is to prove that a join with id pass-through and hash partitioning will still do a shuffle on the id pass-through side.
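
A sketch of that suggested test (assuming the repartitionById(numPartitions, col) signature and the checkShuffleCount helper used elsewhere in this suite):

test("SPARK-53401: join of pass-through and hash partitioning re-shuffles the pass-through side") {
  val left = spark.range(100).select($"id" % 10 as "key").repartitionById(10, $"key")
  val right = spark.range(100).select($"id" % 10 as "key").repartition(10, $"key")
  val joined = left.join(right, "key")
  // one shuffle per explicit repartition, plus one more to bring the
  // pass-through side into hash partitioning for the join
  checkShuffleCount(joined, 3)
}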

@HyukjinKwon HyukjinKwon changed the title [SPARK-53401] Enable Direct Passthrough Partitioning in the DataFrame API [SPARK-53401][SQL] Enable Direct Passthrough Partitioning in the DataFrame API Aug 30, 2025
numPartitions: Int) extends Expression with Partitioning with Unevaluable {

// TODO(SPARK-53401): Support Shuffle Spec in Direct Partition ID Pass Through
def partitionIdExpression: Expression = Pmod(expr.child, Literal(numPartitions))
Contributor:
Here we strip DirectShufflePartitionID, which means we can make DirectShufflePartitionID extend Unevaluable.

@cloud-fan (Contributor) commented Sep 3, 2025:

@shujingyang-db please go through all the open review comments and make sure they are all addressed.

@github-actions github-actions bot removed the INFRA label Sep 3, 2025
@cloud-fan (Contributor):
thanks, merging to master!

@cloud-fan cloud-fan closed this in b017473 Sep 3, 2025