Commit 0c3dacb

[SPARK-53207][SDP] Send Pipeline Event to Client Asynchronously
### What changes were proposed in this pull request?
Created `PipelineEventSender` to send pipeline events back to the client in a background thread so that pipeline execution is not blocked. New events are queued into the sender and processed sequentially. The sender waits until all events have been sent back to the client before shutting down the processing loop. Implemented a logical capacity check that examines event types before submission, ensuring RunProgress and terminal FlowProgress events are always queued, while other events may be dropped when the queue is full. This prevents the buffer from overflowing and impacting pipeline execution when the number of events is very large. The queue size can be controlled by the Spark conf `spark.sql.pipelines.event.queue.capacity`, which currently defaults to `1000`.

### Why are the changes needed?
To ensure that event sending does not block pipeline execution.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
New and existing tests to ensure events are delivered and that delivery is not blocking.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #51956 from JiaqiWang18/SPARK-53207-async-event-response.

Lead-authored-by: Jacky Wang <[email protected]>
Co-authored-by: Jacky Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
1 parent a5f76ca commit 0c3dacb
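Editor's note: a minimal tuning sketch for the conf introduced by this commit, assuming a Spark Connect client with an active session named `spark` and that the conf is settable at the session level (both are assumptions for illustration, not part of the change):

  // Raise the pipeline event queue capacity from its default of 1000 before
  // triggering a pipeline run; non-terminal FlowProgress events are only
  // dropped once this many events are waiting to be sent to the client.
  spark.conf.set("spark.sql.pipelines.event.queue.capacity", "5000")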

File tree

5 files changed: +473 −68 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
@@ -6207,6 +6207,16 @@ object SQLConf {
       .createWithDefault(2)
   }

+  val PIPELINES_EVENT_QUEUE_CAPACITY = {
+    buildConf("spark.sql.pipelines.event.queue.capacity")
+      .doc("Capacity of the event queue used in pipelined execution. When the queue is full, " +
+        "non-terminal FlowProgressEvents will be dropped.")
+      .version("4.1.0")
+      .intConf
+      .checkValue(v => v > 0, "Event queue capacity must be positive.")
+      .createWithDefault(1000)
+  }
+
   val HADOOP_LINE_RECORD_READER_ENABLED =
     buildConf("spark.sql.execution.datasources.hadoopLineRecordReader.enabled")
       .internal()
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelineEventSender.scala

Lines changed: 188 additions & 0 deletions (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.pipelines

import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.atomic.AtomicBoolean

import scala.util.control.NonFatal

import com.google.protobuf.{Timestamp => ProtoTimestamp}
import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.pipelines.common.FlowStatus
import org.apache.spark.sql.pipelines.logging.{FlowProgress, PipelineEvent, RunProgress}
import org.apache.spark.util.ThreadUtils

/**
 * Handles sending pipeline events to the client in a background thread. This prevents pipeline
 * execution from blocking on streaming events.
 */
class PipelineEventSender(
    responseObserver: StreamObserver[ExecutePlanResponse],
    sessionHolder: SessionHolder)
    extends Logging
    with AutoCloseable {

  private final val queueCapacity: Int =
    sessionHolder.session.conf
      .get(SQLConf.PIPELINES_EVENT_QUEUE_CAPACITY.key)
      .toInt

  // ExecutorService for background event processing
  private val executor: ThreadPoolExecutor =
    ThreadUtils.newDaemonSingleThreadExecutor(threadName =
      s"PipelineEventSender-${sessionHolder.sessionId}")

  /*
   * Atomic flag to track the state of the sender
   * - `isShutdown`: Indicates if the sender has been shut down; if true, no new events
   *   can be accepted, and the executor will be shut down after processing all submitted events.
   */
  private val isShutdown = new AtomicBoolean(false)

  /**
   * Send an event asynchronously by submitting it to the executor, if the sender is not shut
   * down. Otherwise, throws an IllegalStateException to surface the shutdown state.
   *
   * For RunProgress events, we ensure they are always queued even if the queue is full. For other
   * events, we may drop them if the queue is at capacity to prevent blocking.
   */
  def sendEvent(event: PipelineEvent): Unit = synchronized {
    if (!isShutdown.get()) {
      if (shouldEnqueueEvent(event)) {
        executor.submit(new Runnable {
          override def run(): Unit = {
            try {
              sendEventToClient(event)
            } catch {
              case NonFatal(e) =>
                logError(
                  log"Failed to send pipeline event to client: " +
                    log"${MDC(LogKeys.ERROR, event.message)}",
                  e)
            }
          }
        })
      }
    } else {
      throw new IllegalStateException(
        s"Cannot send event after shutdown for session ${sessionHolder.sessionId}")
    }
  }

  private def shouldEnqueueEvent(event: PipelineEvent): Boolean = {
    event.details match {
      case _: RunProgress =>
        // For RunProgress events, always enqueue the event
        true
      case flowProgress: FlowProgress if FlowStatus.isTerminal(flowProgress.status) =>
        // For FlowProgress events that are terminal, always enqueue the event
        true
      case _ =>
        // For other events, check if we have capacity
        executor.getQueue.size() < queueCapacity
    }
  }

  // Implementing AutoCloseable to allow for try-with-resources usage.
  // This will ensure that the sender is properly shut down and all resources are released
  // without requiring explicit shutdown calls in user code.
  override def close(): Unit = shutdown()

  /**
   * Shut down the event sender: stop taking new events and wait for processing to complete.
   * This method blocks until all queued events have been processed. Idempotent operation:
   * calling this multiple times has no effect after the first call.
   */
  def shutdown(): Unit = {
    if (isShutdown.compareAndSet(false, true)) {
      // Request a shutdown of the executor, which waits for all tasks to complete
      executor.shutdown()
      // Blocks until all tasks have completed execution after a shutdown request;
      // disregard the timeout since we want all events to be processed
      if (!executor.awaitTermination(Long.MaxValue, java.util.concurrent.TimeUnit.MILLISECONDS)) {
        logError(
          log"Pipeline event sender for session " +
            log"${MDC(LogKeys.SESSION_ID, sessionHolder.sessionId)} failed to terminate")
        executor.shutdownNow()
      }
      logInfo(
        log"Pipeline event sender shutdown completed for session " +
          log"${MDC(LogKeys.SESSION_ID, sessionHolder.sessionId)}")
    }
  }

  /**
   * Send a single event to the client.
   */
  private[connect] def sendEventToClient(event: PipelineEvent): Unit = {
    try {
      val protoEvent = constructProtoEvent(event)
      responseObserver.onNext(
        proto.ExecutePlanResponse
          .newBuilder()
          .setSessionId(sessionHolder.sessionId)
          .setServerSideSessionId(sessionHolder.serverSessionId)
          .setPipelineEventResult(proto.PipelineEventResult.newBuilder
            .setEvent(protoEvent)
            .build())
          .build())
    } catch {
      case NonFatal(e) =>
        logError(
          log"Failed to send pipeline event to client: " +
            log"${MDC(LogKeys.ERROR, event.message)}",
          e)
    }
  }

  private def constructProtoEvent(event: PipelineEvent): proto.PipelineEvent = {
    val message = if (event.error.nonEmpty) {
      // Returns the message associated with a Throwable and all its causes
      def getExceptionMessages(throwable: Throwable): Seq[String] = {
        throwable.getMessage +:
          Option(throwable.getCause).map(getExceptionMessages).getOrElse(Nil)
      }
      val errorMessages = getExceptionMessages(event.error.get)
      s"""${event.message}
         |Error: ${errorMessages.mkString("\n")}""".stripMargin
    } else {
      event.message
    }
    val protoEventBuilder = proto.PipelineEvent
      .newBuilder()
      .setTimestamp(
        ProtoTimestamp
          .newBuilder()
          // java.sql.Timestamp normalizes its internal fields: getTime() returns
          // the full timestamp in milliseconds, while getNanos() returns the
          // fractional seconds (0-999,999,999 ns). This ensures no precision is
          // lost or double-counted.
          .setSeconds(event.timestamp.getTime / 1000)
          .setNanos(event.timestamp.getNanos)
          .build())
      .setMessage(message)
    protoEventBuilder.build()
  }
}
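Editor's note: a minimal lifecycle sketch of the sender above. The values `responseObserver`, `sessionHolder`, and `events` are assumed to be in scope for illustration; the actual handler change below manages the sender with `Using.resource` instead of an explicit try/finally.

  // Illustrative sketch only, not part of the diff.
  val sender = new PipelineEventSender(responseObserver, sessionHolder)
  try {
    // Events are handed off to a background thread. RunProgress and terminal
    // FlowProgress events are always enqueued; other events may be dropped once
    // the queue reaches spark.sql.pipelines.event.queue.capacity entries.
    events.foreach(sender.sendEvent)
  } finally {
    // close() delegates to shutdown(), which blocks until every queued event
    // has been delivered to the client.
    sender.close()
  }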

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala

Lines changed: 38 additions & 68 deletions
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.connect.pipelines

 import scala.jdk.CollectionConverters._
+import scala.util.Using

-import com.google.protobuf.{Timestamp => ProtoTimestamp}
 import io.grpc.stub.StreamObserver

 import org.apache.spark.connect.proto

@@ -237,77 +237,47 @@ private[connect] object PipelinesHandler extends Logging {
       sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
     val tableFiltersResult = createTableFilters(cmd, graphElementRegistry, sessionHolder)

-    // We will use this variable to store the run failure event if it occurs. This will be set
-    // by the event callback.
-    @volatile var runFailureEvent = Option.empty[PipelineEvent]
-    // Define a callback which will stream logs back to the SparkConnect client when an internal
-    // pipeline event is emitted during pipeline execution. We choose to pass a callback rather the
-    // responseObserver to the pipelines execution code so that the pipelines module does not need
-    // to take a dependency on SparkConnect.
-    val eventCallback = { event: PipelineEvent =>
-      val message = if (event.error.nonEmpty) {
-        // Returns the message associated with a Throwable and all its causes
-        def getExceptionMessages(throwable: Throwable): Seq[String] = {
-          throwable.getMessage +:
-            Option(throwable.getCause).map(getExceptionMessages).getOrElse(Nil)
+    // Use the PipelineEventSender to send events back to the client asynchronously.
+    Using.resource(new PipelineEventSender(responseObserver, sessionHolder)) { eventSender =>
+      // We will use this variable to store the run failure event if it occurs. This will be set
+      // by the event callback.
+      @volatile var runFailureEvent = Option.empty[PipelineEvent]
+      // Define a callback which will stream logs back to the SparkConnect client when an internal
+      // pipeline event is emitted during pipeline execution. We choose to pass a callback rather
+      // the responseObserver to the pipelines execution code so that the pipelines module does not
+      // need to take a dependency on SparkConnect.
+      val eventCallback = { event: PipelineEvent =>
+        event.details match {
+          // Failed runs are recorded in the event log. We do not pass these to the SparkConnect
+          // client since the failed run will already result in an unhandled exception that is
+          // propagated to the SparkConnect client. This special handling ensures that the client
+          // does not see the same error twice for a failed run.
+          case RunProgress(state) if state == FAILED => runFailureEvent = Some(event)
+          case RunProgress(state) if state == CANCELED =>
+            throw new RuntimeException("Pipeline run was canceled.")
+          case _ =>
+            eventSender.sendEvent(event)
         }
-        val errorMessages = getExceptionMessages(event.error.get)
-        s"""${event.message}
-           |Error: ${errorMessages.mkString("\n")}""".stripMargin
-      } else {
-        event.message
       }
-      event.details match {
-        // Failed runs are recorded in the event log. We do not pass these to the SparkConnect
-        // client since the failed run will already result in an unhandled exception that is
-        // propagated to the SparkConnect client. This special handling ensures that the client
-        // does not see the same error twice for a failed run.
-        case RunProgress(state) if state == FAILED => runFailureEvent = Some(event)
-        case RunProgress(state) if state == CANCELED =>
-          throw new RuntimeException("Pipeline run was canceled.")
-        case _ =>
-          responseObserver.onNext(
-            proto.ExecutePlanResponse
-              .newBuilder()
-              .setSessionId(sessionHolder.sessionId)
-              .setServerSideSessionId(sessionHolder.serverSessionId)
-              .setPipelineEventResult(
-                proto.PipelineEventResult.newBuilder
-                  .setEvent(
-                    proto.PipelineEvent
-                      .newBuilder()
-                      .setTimestamp(
-                        ProtoTimestamp
-                          .newBuilder()
-                          // java.sql.Timestamp normalizes its internal fields: getTime() returns
-                          // the full timestamp in milliseconds, while getNanos() returns the
-                          // fractional seconds (0-999,999,999 ns). This ensures no precision is
-                          // lost or double-counted.
-                          .setSeconds(event.timestamp.getTime / 1000)
-                          .setNanos(event.timestamp.getNanos)
-                          .build())
-                      .setMessage(message)
-                      .build())
-                  .build())
-              .build())
+
+      val pipelineUpdateContext = new PipelineUpdateContextImpl(
+        graphElementRegistry.toDataflowGraph,
+        eventCallback,
+        tableFiltersResult.refresh,
+        tableFiltersResult.fullRefresh)
+      sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext)
+
+      if (cmd.getDry) {
+        pipelineUpdateContext.pipelineExecution.dryRunPipeline()
+      } else {
+        pipelineUpdateContext.pipelineExecution.runPipeline()
       }
-    }
-    val pipelineUpdateContext = new PipelineUpdateContextImpl(
-      graphElementRegistry.toDataflowGraph,
-      eventCallback,
-      tableFiltersResult.refresh,
-      tableFiltersResult.fullRefresh)
-    sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext)
-    if (cmd.getDry) {
-      pipelineUpdateContext.pipelineExecution.dryRunPipeline()
-    } else {
-      pipelineUpdateContext.pipelineExecution.runPipeline()
-    }

-    // Rethrow any exceptions that caused the pipeline run to fail so that the exception is
-    // propagated back to the SC client / CLI.
-    runFailureEvent.foreach { event =>
-      throw event.error.get
+      // Rethrow any exceptions that caused the pipeline run to fail so that the exception is
+      // propagated back to the SC client / CLI.
+      runFailureEvent.foreach { event =>
+        throw event.error.get
+      }
     }
   }
