[SPARK-48037][CORE] Fix SortShuffleWriter lacks shuffle write related metrics resulting in potentially inaccurate data #46273

Closed · wants to merge 4 commits · changes shown from 3 commits
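
In brief, as the diff below shows: SortShuffleWriter used to ignore the metrics reporter passed to SortShuffleManager.getWriter and always wrote to context.taskMetrics().shuffleWriteMetrics. For SQL queries that reporter is a SQLShuffleWriteMetricsReporter, which also updates the SQL-level shuffle write metrics consumed by the UI and by adaptive query execution, so those metrics could be missing or inaccurate. The change threads the reporter through SortShuffleWriter and ExternalSorter.writePartitionedMapOutput and updates the affected tests.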

SortShuffleManager.scala
@@ -176,7 +176,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
metrics,
shuffleExecutorComponents)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
- new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
+ new SortShuffleWriter(other, mapId, context, metrics, shuffleExecutorComponents)
}
}

SortShuffleWriter.scala
@@ -21,13 +21,15 @@ import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
+ import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.util.collection.ExternalSorter

private[spark] class SortShuffleWriter[K, V, C](
handle: BaseShuffleHandle[K, V, C],
mapId: Long,
context: TaskContext,
+ writeMetrics: ShuffleWriteMetricsReporter,
shuffleExecutorComponents: ShuffleExecutorComponents)
extends ShuffleWriter[K, V] with Logging {

@@ -46,8 +48,6 @@ private[spark] class SortShuffleWriter[K, V, C](

private var partitionLengths: Array[Long] = _

- private val writeMetrics = context.taskMetrics().shuffleWriteMetrics

/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
@@ -67,7 +67,7 @@
// (see SPARK-3570).
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
- sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
+ sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter, writeMetrics)
partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}

ExternalSorter.scala
@@ -30,7 +30,7 @@ import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer._
- import org.apache.spark.shuffle.ShufflePartitionPairsWriter
+ import org.apache.spark.shuffle.{ShufflePartitionPairsWriter, ShuffleWriteMetricsReporter}
import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
@@ -693,7 +693,8 @@ private[spark] class ExternalSorter[K, V, C](
def writePartitionedMapOutput(
shuffleId: Int,
mapId: Long,
- mapOutputWriter: ShuffleMapOutputWriter): Unit = {
+ mapOutputWriter: ShuffleMapOutputWriter,
+ writeMetrics: ShuffleWriteMetricsReporter): Unit = {
if (spills.isEmpty) {
// Case where we only have in-memory data
val collection = if (aggregator.isDefined) map else buffer
@@ -710,7 +711,7 @@
serializerManager,
serInstance,
blockId,
- context.taskMetrics().shuffleWriteMetrics,

Review comment (Member):
Hmm, the metrics passed to getWriter also look like they come from context.taskMetrics().shuffleWriteMetrics; aren't they the same?

Reply (Contributor Author):
It may be a SQLShuffleWriteMetricsReporter, which may not be the same as context.taskMetrics().shuffleWriteMetrics:

    writer = manager.getWriter[Any, Any](
      dep.shuffleHandle,
      mapId,
      context,
      createMetricsReporter(context))
    writer.write(inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]])

    def createShuffleWriteProcessor(metrics: Map[String, SQLMetric]): ShuffleWriteProcessor = {
      new ShuffleWriteProcessor {
        override protected def createMetricsReporter(
            context: TaskContext): ShuffleWriteMetricsReporter = {
          new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics)
        }
      }
    }
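
For context, a minimal self-contained sketch of the wrapping pattern discussed above. The names below are illustrative stand-ins (assumptions), not the actual Spark classes (ShuffleWriteMetricsReporter, ShuffleWriteMetrics, SQLShuffleWriteMetricsReporter): the SQL-side reporter forwards each update to the task-level metrics and to its own SQL counters, so a writer that goes straight to context.taskMetrics().shuffleWriteMetrics keeps the task metrics correct but leaves the SQL-side counters untouched.

    // Illustrative stand-ins only (assumed, simplified shapes of the real interfaces).
    trait WriteMetricsReporter {
      def incBytesWritten(v: Long): Unit
      def incRecordsWritten(v: Long): Unit
    }

    // Stand-in for the task-level metrics behind context.taskMetrics().shuffleWriteMetrics.
    class TaskWriteMetrics extends WriteMetricsReporter {
      var bytesWritten = 0L
      var recordsWritten = 0L
      def incBytesWritten(v: Long): Unit = bytesWritten += v
      def incRecordsWritten(v: Long): Unit = recordsWritten += v
    }

    // Stand-in for the SQL-side reporter: every update goes to the task metrics and to the
    // SQL counters that feed the UI and AQE. A writer that bypasses this wrapper leaves the
    // SQL counters at zero even though the task metrics look correct.
    class SqlWriteMetrics(task: TaskWriteMetrics) extends WriteMetricsReporter {
      var sqlBytesWritten = 0L
      var sqlRecordsWritten = 0L
      def incBytesWritten(v: Long): Unit = { task.incBytesWritten(v); sqlBytesWritten += v }
      def incRecordsWritten(v: Long): Unit = { task.incRecordsWritten(v); sqlRecordsWritten += v }
    }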

+ writeMetrics,
if (partitionChecksums.nonEmpty) partitionChecksums(partitionId) else null)
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(partitionPairsWriter)
@@ -734,7 +735,7 @@
serializerManager,
serInstance,
blockId,
- context.taskMetrics().shuffleWriteMetrics,
+ writeMetrics,
if (partitionChecksums.nonEmpty) partitionChecksums(id) else null)
if (elements.hasNext) {
for (elem <- elements) {

SortShuffleWriterSuite.scala
@@ -85,6 +85,7 @@ class SortShuffleWriterSuite
shuffleHandle,
mapId = 1,
context,
+ context.taskMetrics().shuffleWriteMetrics,
shuffleExecutorComponents)
writer.write(Iterator.empty)
writer.stop(success = true)
@@ -102,6 +103,7 @@ class SortShuffleWriterSuite
shuffleHandle,
mapId = 2,
context,
+ context.taskMetrics().shuffleWriteMetrics,
shuffleExecutorComponents)
writer.write(records.iterator)
writer.stop(success = true)
@@ -158,6 +160,7 @@ class SortShuffleWriterSuite
shuffleHandle,
mapId = 0,
context,
+ context.taskMetrics().shuffleWriteMetrics,
new LocalDiskShuffleExecutorComponents(
conf, shuffleBlockResolver._blockManager, shuffleBlockResolver))
writer.write(records.iterator)

UnsafeRowSerializerSuite.scala
@@ -130,7 +130,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession {
assert(sorter.numSpills > 0)

// Merging spilled files should not throw assertion error
- sorter.writePartitionedMapOutput(0, 0, mapOutputWriter)
+ sorter.writePartitionedMapOutput(0, 0, mapOutputWriter,
+   taskContext.taskMetrics.shuffleWriteMetrics)
}

test("SPARK-10403: unsafe row serializer with SortShuffleManager") {

AdaptiveQueryExecSuite.scala
@@ -2502,6 +2502,26 @@ class AdaptiveQueryExecSuite
}
}

test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related metrics " +
"resulting in potentially inaccurate data") {
withTable("t3") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.SHUFFLE_PARTITIONS.key -> "16777217") {
sql("CREATE TABLE t3 USING PARQUET AS SELECT id FROM range(2)")
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"""
|SELECT id, count(*)
|FROM t3
|GROUP BY id
|LIMIT 1
|""".stripMargin, skipCheckAnswer = true)

Review comment (Contributor):
Do we have to skip checking the answer? I think the query could use LIMIT 10 so that the result is deterministic.

Reply (Contributor Author):
Because AQE is disabled when the results are checked, there are many partitions and the row order is not guaranteed; when AQE is enabled they are merged into one partition. For example:

set spark.sql.adaptive.enabled=false;
set spark.sql.shuffle.partitions=1000;
create table foo as select id from range(2);
select id, count(*) from foo group by id limit 1;

Output:

1	1

Review comment (Contributor):
AFAIK the checkAnswer util will sort the data before comparison.

Reply (Contributor Author, @cxzl25, May 7, 2024):
It does sort, but only locally on the collected result; because the rows returned by the LIMIT are themselves nondeterministic, the local sort does not make the comparison deterministic:

    def getErrorMessageInCheckAnswer(
        df: DataFrame,
        expectedAnswer: Seq[Row],
        checkToRDD: Boolean = true): Option[String] = {
      val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
      if (checkToRDD) {
        SQLExecution.withSQLConfPropagated(df.sparkSession) {
          df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791]
        }
      }
      val sparkAnswer = try df.collect().toSeq catch {

Review comment (Contributor):
Could we make it deterministic by making the limit larger than the number of result rows?

Reply (Contributor Author):
In this test case the answer check runs with spark.sql.adaptive.enabled=false, so partitions are not merged; with such a large number of partitions and tasks, it would take a large amount of driver memory to execute successfully.

assert(findTopLevelLimit(plan).size == 1)
assert(findTopLevelLimit(adaptivePlan).size == 1)

Review comment (Member):
Hmm, this only verifies that a specific operator (i.e., Limit) is in the query plan; how is it related to the metrics you want to fix?

Follow-up (Member):
Oh, I see. It is due to AQE's use of runtime metrics.

}
}
}
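
A note on the test setup, as an inference from reading the shuffle manager code rather than something stated in this thread: spark.sql.shuffle.partitions is set to 16777217, one more than 2^24, so that neither the bypass-merge-sort path nor the serialized (unsafe) path is eligible and the shuffle falls back to SortShuffleWriter, the writer this PR fixes; AQE's runtime statistics for the stage are then derived from the metrics that writer reports. A simplified sketch of the writer selection, with thresholds that are assumptions from my reading of SortShuffleManager:

    // Simplified sketch; see SortShuffleManager.registerShuffle and
    // SortShuffleWriter.shouldBypassMergeSort for the authoritative logic.
    def chooseWriter(
        numPartitions: Int,
        mapSideCombine: Boolean,
        serializerSupportsRelocation: Boolean): String = {
      val bypassMergeThreshold = 200         // assumed default of spark.shuffle.sort.bypassMergeThreshold
      val maxSerializedPartitions = 1 << 24  // assumed PackedRecordPointer partition-id limit (16777216)
      if (!mapSideCombine && numPartitions <= bypassMergeThreshold) {
        "BypassMergeSortShuffleWriter"
      } else if (!mapSideCombine && serializerSupportsRelocation &&
          numPartitions <= maxSerializedPartitions) {
        "UnsafeShuffleWriter"
      } else {
        "SortShuffleWriter" // 16777217 reduce partitions always ends up here
      }
    }

    // e.g. chooseWriter(16777217, mapSideCombine = false, serializerSupportsRelocation = true)
    // returns "SortShuffleWriter"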

test("SPARK-37063: OptimizeSkewInRebalancePartitions support optimize non-root node") {
withTempView("v") {
withSQLConf(