[SPARK-48037][CORE] Fix SortShuffleWriter lacks shuffle write related metrics resulting in potentially inaccurate data

### What changes were proposed in this pull request?
This PR fixes the missing shuffle write related metrics in SortShuffleWriter, which can result in inaccurate query results.

### Why are the changes needed?
When the shuffle writer is SortShuffleWriter, it does not use SQLShuffleWriteMetricsReporter to update metrics, so when AQE collects runtime statistics the reported rowCount is 0.

Some optimization rules, such as `EliminateLimits`, rely on the rowCount statistic. Because rowCount is 0, the rule removes the limit operator, and the query then returns results without the limit applied.

https://github.com/apache/spark/blob/59d5946cfd377e9203ccf572deb34f87fab7510c/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala#L168-L172
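For intuition, the linked `ShuffleExchangeExec` code wires SQL-level write metrics into the shuffle writer by wrapping the task-level reporter. Below is a minimal, hypothetical Scala sketch of that wrapper idea (simplified names, not Spark's actual `SQLShuffleWriteMetricsReporter` API), showing why a writer that bypasses the caller-supplied reporter leaves the SQL-visible row count at 0.

```scala
// Hypothetical sketch of the reporter-decorator idea; these are not Spark's real classes.
trait WriteMetricsReporter {
  def incRecordsWritten(n: Long): Unit
}

// Plain task-level metrics, analogous in spirit to context.taskMetrics().shuffleWriteMetrics.
final class TaskWriteMetrics extends WriteMetricsReporter {
  var recordsWritten = 0L
  override def incRecordsWritten(n: Long): Unit = recordsWritten += n
}

// SQL-side decorator in the spirit of SQLShuffleWriteMetricsReporter: it forwards to the
// task-level reporter and also updates a counter that the SQL layer (and AQE) can read back.
final class SqlWriteMetricsReporter(delegate: WriteMetricsReporter) extends WriteMetricsReporter {
  var sqlRecordsWritten = 0L
  override def incRecordsWritten(n: Long): Unit = {
    sqlRecordsWritten += n
    delegate.incRecordsWritten(n)
  }
}

object ReporterDemo {
  def main(args: Array[String]): Unit = {
    val taskMetrics = new TaskWriteMetrics
    val sqlReporter: WriteMetricsReporter = new SqlWriteMetricsReporter(taskMetrics)

    // Before the fix: SortShuffleWriter reached for the task metrics directly, so only
    // the task-level counter moved and the SQL-visible row count stayed at 0.
    taskMetrics.incRecordsWritten(2)

    // After the fix: the writer reports through whatever reporter the caller passed in,
    // so the SQL-visible counter is updated as well.
    sqlReporter.incRecordsWritten(2)
  }
}
```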

https://github.com/apache/spark/blob/59d5946cfd377e9203ccf572deb34f87fab7510c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L2067-L2070
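The linked `EliminateLimits` rule removes a `Limit` whose child reports a `maxRows` no larger than the limit; with a runtime rowCount stuck at 0 that check is trivially satisfied. The following is a simplified, self-contained sketch of that idea with hypothetical classes, not the real optimizer code.

```scala
// Hypothetical, self-contained model of the EliminateLimits idea (not the real rule).
object EliminateLimitsSketch {
  sealed trait Plan { def maxRows: Option[Long] }

  // Stand-in for a finished shuffle query stage whose statistics come from the
  // runtime shuffle write metrics. If the writer never updated them, this stays 0.
  case class ShuffleStage(runtimeRowCount: Long) extends Plan {
    override def maxRows: Option[Long] = Some(runtimeRowCount)
  }

  case class Limit(limit: Long, child: Plan) extends Plan {
    override def maxRows: Option[Long] =
      Some(child.maxRows.map(math.min(_, limit)).getOrElse(limit))
  }

  // The limit is considered redundant when the child can never produce more rows than it.
  def eliminateLimits(plan: Plan): Plan = plan match {
    case Limit(n, child) if child.maxRows.exists(_ <= n) => child
    case other => other
  }

  def main(args: Array[String]): Unit = {
    // rowCount reported as 0 -> the LIMIT 1 is dropped and the query can return every row.
    assert(eliminateLimits(Limit(1, ShuffleStage(runtimeRowCount = 0))) == ShuffleStage(0))
    // rowCount reported correctly as 2 -> the LIMIT 1 is kept.
    assert(eliminateLimits(Limit(1, ShuffleStage(runtimeRowCount = 2))) ==
      Limit(1, ShuffleStage(2)))
  }
}
```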

### Does this PR introduce _any_ user-facing change?
Yes

### How was this patch tested?
Production environment verification.

**master metrics**
<img width="296" alt="image" src="https://github.com/apache/spark/assets/3898450/dc9b6e8a-93ec-4f59-a903-71aa5b11962c">

**PR metrics**

<img width="276" alt="image" src="https://github.com/apache/spark/assets/3898450/2d73b773-2dcc-4d23-81de-25dcadac86c1">

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#46273 from cxzl25/SPARK-48037.

Authored-by: sychen <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
cxzl25 authored and JacobZheng0927 committed May 11, 2024
1 parent 9815de8 commit 7b9fc15
Showing 6 changed files with 37 additions and 9 deletions.
@@ -176,7 +176,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           metrics,
           shuffleExecutorComponents)
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
+        new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents)
     }
   }

@@ -21,13 +21,15 @@ import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.util.collection.ExternalSorter

 private[spark] class SortShuffleWriter[K, V, C](
     handle: BaseShuffleHandle[K, V, C],
     mapId: Long,
     context: TaskContext,
+    writeMetrics: ShuffleWriteMetricsReporter,
     shuffleExecutorComponents: ShuffleExecutorComponents)
   extends ShuffleWriter[K, V] with Logging {

@@ -46,8 +48,6 @@ private[spark] class SortShuffleWriter[K, V, C](

   private var partitionLengths: Array[Long] = _

-  private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
-
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[Product2[K, V]]): Unit = {
     sorter = if (dep.mapSideCombine) {
@@ -67,7 +67,7 @@
     // (see SPARK-3570).
     val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
       dep.shuffleId, mapId, dep.partitioner.numPartitions)
-    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
+    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter, writeMetrics)
     partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
   }
@@ -30,7 +30,7 @@ import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
-import org.apache.spark.shuffle.ShufflePartitionPairsWriter
+import org.apache.spark.shuffle.{ShufflePartitionPairsWriter, ShuffleWriteMetricsReporter}
 import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
 import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport
 import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
@@ -693,7 +693,8 @@ private[spark] class ExternalSorter[K, V, C](
   def writePartitionedMapOutput(
       shuffleId: Int,
       mapId: Long,
-      mapOutputWriter: ShuffleMapOutputWriter): Unit = {
+      mapOutputWriter: ShuffleMapOutputWriter,
+      writeMetrics: ShuffleWriteMetricsReporter): Unit = {
     if (spills.isEmpty) {
       // Case where we only have in-memory data
       val collection = if (aggregator.isDefined) map else buffer
@@ -710,7 +711,7 @@
           serializerManager,
           serInstance,
           blockId,
-          context.taskMetrics().shuffleWriteMetrics,
+          writeMetrics,
           if (partitionChecksums.nonEmpty) partitionChecksums(partitionId) else null)
         while (it.hasNext && it.nextPartition() == partitionId) {
           it.writeNext(partitionPairsWriter)
@@ -734,7 +735,7 @@
           serializerManager,
           serInstance,
           blockId,
-          context.taskMetrics().shuffleWriteMetrics,
+          writeMetrics,
           if (partitionChecksums.nonEmpty) partitionChecksums(id) else null)
         if (elements.hasNext) {
           for (elem <- elements) {
@@ -85,6 +85,7 @@ class SortShuffleWriterSuite
       shuffleHandle,
       mapId = 1,
       context,
+      context.taskMetrics().shuffleWriteMetrics,
       shuffleExecutorComponents)
     writer.write(Iterator.empty)
     writer.stop(success = true)
@@ -102,6 +103,7 @@
       shuffleHandle,
       mapId = 2,
       context,
+      context.taskMetrics().shuffleWriteMetrics,
       shuffleExecutorComponents)
     writer.write(records.iterator)
     writer.stop(success = true)
@@ -158,6 +160,7 @@
       shuffleHandle,
       mapId = 0,
       context,
+      context.taskMetrics().shuffleWriteMetrics,
       new LocalDiskShuffleExecutorComponents(
         conf, shuffleBlockResolver._blockManager, shuffleBlockResolver))
     writer.write(records.iterator)
@@ -130,7 +130,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession {
       assert(sorter.numSpills > 0)

       // Merging spilled files should not throw assertion error
-      sorter.writePartitionedMapOutput(0, 0, mapOutputWriter)
+      sorter.writePartitionedMapOutput(0, 0, mapOutputWriter,
+        taskContext.taskMetrics.shuffleWriteMetrics)
     }

   test("SPARK-10403: unsafe row serializer with SortShuffleManager") {
@@ -27,6 +27,7 @@ import org.scalatest.time.SpanSugar._
 import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
+import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.{DataFrame, Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -2502,6 +2503,28 @@ class AdaptiveQueryExecSuite
     }
   }

+  test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related metrics " +
+    "resulting in potentially inaccurate data") {
+    withTable("t3") {
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+        SQLConf.SHUFFLE_PARTITIONS.key -> (SortShuffleManager
+          .MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1).toString) {
+        sql("CREATE TABLE t3 USING PARQUET AS SELECT id FROM range(2)")
+        val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+          """
+            |SELECT id, count(*)
+            |FROM t3
+            |GROUP BY id
+            |LIMIT 1
+            |""".stripMargin, skipCheckAnswer = true)
+        // The shuffle stage produces two rows and the limit operator should not be optimized out.
+        assert(findTopLevelLimit(plan).size == 1)
+        assert(findTopLevelLimit(adaptivePlan).size == 1)
+      }
+    }
+  }
+
   test("SPARK-37063: OptimizeSkewInRebalancePartitions support optimize non-root node") {
     withTempView("v") {
       withSQLConf(
