
Commit 5a5f5ad

davidm-db and cloud-fan committed
[SPARK-52134] Move execution logic to SqlScriptingExecution and enable Spark Connect path
Move the script execution from `SparkSession#sql` to `QueryExecution#lazyAnalyzed`. This allows `QueryExecution` to receive the original parsed logical plan for scripting, which is used to detect script execution in Spark Connect and treat scripts as commands. Also move the `executeSqlScript` logic from `SparkSession` to the `SqlScriptingExecution` object.

Why are the changes needed? SQL Scripting improvements.

Does this PR introduce any user-facing change? No. It enables new functionality (execution through Spark Connect), but the results remain the same.

How was this patch tested? Existing tests confirm that the refactor of the execution logic does not change behavior; a test was added to confirm that execution through Spark Connect does not fail.

Was this patch authored or co-authored using generative AI tooling? No.

Closes apache#50895 from davidm-db/execute_sql_script_refactor.

Lead-authored-by: David Milicevic <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 46486cf)
Signed-off-by: Wenchen Fan <[email protected]>
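For context, a minimal usage sketch of the behavior this commit preserves: a SQL script submitted through `SparkSession#sql` with `spark.sql.scripting.enabled` set returns the result of the script's last statement. The session setup and script text below are illustrative only, not part of this commit.

import org.apache.spark.sql.SparkSession

object SqlScriptUsageSketch {
  def main(args: Array[String]): Unit = {
    // Local session for illustration; any SparkSession works the same way.
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // SQL scripting is gated behind this conf (also set in the new Connect test below).
    spark.conf.set("spark.sql.scripting.enabled", "true")

    // The result DataFrame holds the output of the last executed statement (here, SELECT 1).
    val result = spark.sql(
      """BEGIN
        |  IF 1 = 1 THEN
        |    SELECT 1;
        |  ELSE
        |    SELECT 2;
        |  END IF;
        |END""".stripMargin)

    result.show()
    spark.stop()
  }
}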
1 parent f683e23 commit 5a5f5ad

File tree

6 files changed: +164 -96 lines


sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 4 additions & 3 deletions
@@ -54,7 +54,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateStarAction}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, CompoundBody, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateStarAction}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -2614,8 +2614,9 @@ class SparkConnectPlanner(
         s"SQL command expects either a SQL or a WithRelations, but got $other")
     }

-    // Check if commands have been executed.
+    // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
+    val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
     val rows = df.logicalPlan match {
       case lr: LocalRelation => lr.data
       case cr: CommandResult => cr.rows
@@ -2627,7 +2628,7 @@
     val result = SqlCommandResult.newBuilder()
     // Only filled when isCommand
     val metrics = ExecutePlanResponse.Metrics.newBuilder()
-    if (isCommand) {
+    if (isCommand || isSqlScript) {
       // Convert the results to Arrow.
       val schema = df.schema
       val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala

Lines changed: 57 additions & 0 deletions
@@ -16,16 +16,19 @@
  */
 package org.apache.spark.sql.connect

+import java.io.ByteArrayInputStream
 import java.util.{TimeZone, UUID}

 import scala.reflect.runtime.universe.TypeTag

 import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.scalatest.concurrent.{Eventually, TimeLimits}
 import org.scalatest.time.Span
 import org.scalatest.time.SpanSugar._

 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.ExecutePlanResponse
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, SparkConnectStubState}
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
@@ -143,6 +146,21 @@ trait SparkConnectServerTest extends SharedSparkSession {
     proto.Plan.newBuilder().setRoot(dsl.sql(query)).build()
   }

+  protected def buildSqlCommandPlan(sqlCommand: String) = {
+    proto.Plan
+      .newBuilder()
+      .setCommand(
+        proto.Command
+          .newBuilder()
+          .setSqlCommand(
+            proto.SqlCommand
+              .newBuilder()
+              .setSql(sqlCommand)
+              .build())
+          .build())
+      .build()
+  }
+
   protected def buildLocalRelation[A <: Product: TypeTag](data: Seq[A]) = {
     val encoder = ScalaReflection.encoderFor[A]
     val arrowData =
@@ -305,4 +323,43 @@ trait SparkConnectServerTest extends SharedSparkSession {
     val plan = buildPlan(query)
     runQuery(plan, queryTimeout, iterSleep)
   }
+
+  protected def checkSqlCommandResponse(
+      result: ExecutePlanResponse.SqlCommandResult,
+      expected: Seq[Seq[Any]]): Unit = {
+    // Extract the serialized Arrow data as a byte array.
+    val dataBytes = result.getRelation.getLocalRelation.getData.toByteArray
+
+    // Create an ArrowStreamReader to deserialize the data.
+    val allocator = new RootAllocator(Long.MaxValue)
+    val inputStream = new ByteArrayInputStream(dataBytes)
+    val reader = new ArrowStreamReader(inputStream, allocator)
+
+    try {
+      // Read the schema and data.
+      val root = reader.getVectorSchemaRoot
+      // Load the first batch of data.
+      reader.loadNextBatch()
+
+      // Get dimensions.
+      val rowCount = root.getRowCount
+      val colCount = root.getFieldVectors.size
+      assert(rowCount == expected.length, "Row count mismatch")
+      assert(colCount == expected.head.length, "Column count mismatch")
+
+      // Compare to expected.
+      for (i <- 0 until rowCount) {
+        for (j <- 0 until colCount) {
+          val col = root.getFieldVectors.get(j)
+          val value = col.getObject(i)
+          print(value)
+          assert(value == expected(i)(j), s"Value mismatch at ($i, $j)")
+        }
+      }
+    } finally {
+      // Clean up resources.
+      reader.close()
+      allocator.close()
+    }
+  }
 }

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala

Lines changed: 21 additions & 0 deletions
@@ -33,6 +33,27 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
   // were all already in the buffer.
   val BIG_ENOUGH_QUERY = "select * from range(1000000)"

+  test("SQL Script over Spark Connect.") {
+    val sessionId = UUID.randomUUID.toString()
+    val userId = "ScriptUser"
+    val sqlScriptText =
+      """BEGIN
+        |IF 1 = 1 THEN
+        | SELECT 1;
+        |ELSE
+        | SELECT 2;
+        |END IF;
+        |END
+      """.stripMargin
+    withClient(sessionId = sessionId, userId = userId) { client =>
+      // this will create the session, and then ReleaseSession at the end of withClient.
+      val enableSqlScripting = client.execute(buildPlan("SET spark.sql.scripting.enabled=true"))
+      enableSqlScripting.hasNext // trigger execution
+      val query = client.execute(buildSqlCommandPlan(sqlScriptText))
+      checkSqlCommandResponse(query.next().getSqlCommandResult, Seq(Seq(1)))
+    }
+  }
+
   test("Execute is sent eagerly to the server upon iterator creation") {
     // This behavior changed with grpc upgrade from 1.56.0 to 1.59.0.
     // Testing to be aware of future changes.

sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala

Lines changed: 16 additions & 88 deletions
@@ -42,10 +42,9 @@ import org.apache.spark.sql.artifact.ArtifactManager
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.parser.ParserInterface
-import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, LogicalPlan, Range}
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, Range}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.classic.SparkSession.applyAndLoadExtensions
@@ -56,7 +55,6 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal._
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
-import org.apache.spark.sql.scripting.SqlScriptingExecution
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.ExecutionListenerManager
@@ -432,50 +430,6 @@
   | Everything else |
  * ----------------- */

-  /**
-   * Executes given script and return the result of the last statement.
-   * If script contains no queries, an empty `DataFrame` is returned.
-   *
-   * @param script A SQL script to execute.
-   * @param args A map of parameter names to SQL literal expressions.
-   *
-   * @return The result as a `DataFrame`.
-   */
-  private def executeSqlScript(
-      script: CompoundBody,
-      args: Map[String, Expression] = Map.empty): DataFrame = {
-    val sse = new SqlScriptingExecution(script, this, args)
-    sse.withLocalVariableManager {
-      var result: Option[Seq[Row]] = None
-
-      // We must execute returned df before calling sse.getNextResult again because sse.hasNext
-      // advances the script execution and executes all statements until the next result. We must
-      // collect results immediately to maintain execution order.
-      // This ensures we respect the contract of SqlScriptingExecution API.
-      var df: Option[DataFrame] = sse.getNextResult
-      var resultSchema: Option[StructType] = None
-      while (df.isDefined) {
-        sse.withErrorHandling {
-          // Collect results from the current DataFrame.
-          result = Some(df.get.collect().toSeq)
-          resultSchema = Some(df.get.schema)
-        }
-        df = sse.getNextResult
-      }
-
-      if (result.isEmpty) {
-        emptyDataFrame
-      } else {
-        // If `result` is defined, then `resultSchema` must be defined as well.
-        assert(resultSchema.isDefined)
-
-        val attributes = DataTypeUtils.toAttributes(resultSchema.get)
-        Dataset.ofRows(
-          self, LocalRelation.fromExternalRows(attributes, result.get))
-      }
-    }
-  }
-
   /**
    * Executes a SQL query substituting positional parameters by the given arguments,
    * returning the result as a `DataFrame`.
@@ -495,30 +449,17 @@
     withActive {
       val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
         val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
-        parsedPlan match {
-          case compoundBody: CompoundBody =>
-            if (args.nonEmpty) {
-              // Positional parameters are not supported for SQL scripting.
-              throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
-            }
-            compoundBody
-          case logicalPlan: LogicalPlan =>
-            if (args.nonEmpty) {
-              PosParameterizedQuery(logicalPlan, args.map(lit(_).expr).toImmutableArraySeq)
-            } else {
-              logicalPlan
-            }
+        if (args.nonEmpty) {
+          if (parsedPlan.isInstanceOf[CompoundBody]) {
+            // Positional parameters are not supported for SQL scripting.
+            throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
+          }
+          PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq)
+        } else {
+          parsedPlan
         }
       }
-
-      plan match {
-        case compoundBody: CompoundBody =>
-          // Execute the SQL script.
-          executeSqlScript(compoundBody)
-        case logicalPlan: LogicalPlan =>
-          // Execute the standalone SQL statement.
-          Dataset.ofRows(self, plan, tracker)
-      }
+      Dataset.ofRows(self, plan, tracker)
     }

   /** @inheritdoc */
@@ -549,26 +490,13 @@
     withActive {
       val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
         val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
-        parsedPlan match {
-          case compoundBody: CompoundBody =>
-            compoundBody
-          case logicalPlan: LogicalPlan =>
-            if (args.nonEmpty) {
-              NameParameterizedQuery(logicalPlan, args.transform((_, v) => lit(v).expr))
-            } else {
-              logicalPlan
-            }
+        if (args.nonEmpty) {
+          NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr))
+        } else {
+          parsedPlan
         }
       }
-
-      plan match {
-        case compoundBody: CompoundBody =>
-          // Execute the SQL script.
-          executeSqlScript(compoundBody, args.transform((_, v) => lit(v).expr))
-        case logicalPlan: LogicalPlan =>
-          // Execute the standalone SQL statement.
-          Dataset.ofRows(self, plan, tracker)
-      }
+      Dataset.ofRows(self, plan, tracker)
     }

   /** @inheritdoc */

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 15 additions & 4 deletions
@@ -31,10 +31,10 @@ import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row}
 import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{LazyExpression, UnsupportedOperationChecker}
+import org.apache.spark.sql.catalyst.analysis.{LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker}
 import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union}
 import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
 import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -46,6 +46,7 @@ import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, WatermarkPropagator}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.scripting.SqlScriptingExecution
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.{LazyTry, Utils}
 import org.apache.spark.util.ArrayImplicits._
@@ -93,16 +94,26 @@ class QueryExecution(
   }

   private val lazyAnalyzed = LazyTry {
+    val withScriptExecuted = logical match {
+      // Execute the SQL script. Script doesn't need to go through the analyzer as Spark will run
+      // each statement as individual query.
+      case NameParameterizedQuery(compoundBody: CompoundBody, argNames, argValues) =>
+        val args = argNames.zip(argValues).toMap
+        SqlScriptingExecution.executeSqlScript(sparkSession, compoundBody, args)
+      case compoundBody: CompoundBody =>
+        SqlScriptingExecution.executeSqlScript(sparkSession, compoundBody)
+      case _ => logical
+    }
     try {
       val plan = executePhase(QueryPlanningTracker.ANALYSIS) {
         // We can't clone `logical` here, which will reset the `_analyzed` flag.
-        sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker)
+        sparkSession.sessionState.analyzer.executeAndCheck(withScriptExecuted, tracker)
       }
       tracker.setAnalyzed(plan)
       plan
     } catch {
       case NonFatal(e) =>
-        tracker.setAnalysisFailed(logical)
+        tracker.setAnalysisFailed(withScriptExecuted)
         throw e
     }
   }
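The `executeSqlScript` helper called from this hunk lives on the `SqlScriptingExecution` companion object; its diff is the sixth changed file and is not included in this excerpt. Below is a hedged sketch of what that moved helper plausibly looks like, reconstructed from the method deleted from `SparkSession.scala` above and adjusted to return a `LogicalPlan` so it can feed the `executeAndCheck` call shown in this hunk. The exact signature and body in the committed file may differ.

// Hedged sketch, not the literal committed code: the executeSqlScript logic moved
// from SparkSession onto the SqlScriptingExecution companion object. The companion
// class SqlScriptingExecution (instantiated below) is assumed to be in scope.
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.types.StructType

object SqlScriptingExecution {
  /** Executes the given script and returns the last statement's result as a local plan. */
  def executeSqlScript(
      session: SparkSession,
      script: CompoundBody,
      args: Map[String, Expression] = Map.empty): LogicalPlan = {
    val sse = new SqlScriptingExecution(script, session, args)
    sse.withLocalVariableManager {
      var result: Option[Seq[Row]] = None
      var resultSchema: Option[StructType] = None

      // Execute each returned DataFrame before asking for the next result, so that
      // statements run in script order (the contract of the SqlScriptingExecution API).
      var df: Option[DataFrame] = sse.getNextResult
      while (df.isDefined) {
        sse.withErrorHandling {
          result = Some(df.get.collect().toSeq)
          resultSchema = Some(df.get.schema)
        }
        df = sse.getNextResult
      }

      result match {
        case None =>
          // The script produced no query results: return an empty relation.
          LocalRelation(Nil)
        case Some(rows) =>
          // Wrap the rows of the last statement in a LocalRelation so the caller
          // (QueryExecution.lazyAnalyzed) can analyze it like any other plan.
          val attributes = DataTypeUtils.toAttributes(resultSchema.get)
          LocalRelation.fromExternalRows(attributes, rows)
      }
    }
  }
}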
