From 1c14d5644a1d5e317a58e36625d2bdd08f75e9f8 Mon Sep 17 00:00:00 2001
From: Tian Gao
Date: Mon, 17 Nov 2025 11:14:00 -0800
Subject: [PATCH 1/2] Refactor PythonException so it can take errorClass with sqlstate

---
 .../resources/error/error-conditions.json     |  6 ++++
 .../apache/spark/api/python/PythonRDD.scala   | 30 +++++++++++++++++--
 .../spark/api/python/PythonRunner.scala       |  5 +++-
 .../planner/StreamingForeachBatchHelper.scala |  8 +++--
 .../StreamingQueryListenerHelper.scala        |  7 +++--
 5 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 1a69503aea4e..a5f19ec4b4ec 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5099,6 +5099,12 @@
     ],
     "sqlState" : "38000"
   },
+  "PYTHON_EXCEPTION" : {
+    "message" : [
+      "Exception raised from Python worker: <msg>"
+    ],
+    "sqlState" : "38000"
+  },
   "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR" : {
     "message" : [
       "Failed when Python streaming data source perform <action>: <msg>"
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index cf0169fed60c..029d9fad9d08 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -126,8 +126,34 @@ private[spark] case class SimplePythonFunction(
 private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
 
 /** Thrown for exceptions in user Python code. */
-private[spark] class PythonException(msg: String, cause: Throwable)
-  extends RuntimeException(msg, cause)
+private[spark] class PythonException(
+    msg: String,
+    cause: Option[Throwable],
+    errorClass: Option[String],
+    messageParameters: Map[String, String],
+    context: Array[QueryContext])
+  extends RuntimeException(msg, cause.orNull) with SparkThrowable {
+
+  def this(
+      errorClass: String,
+      messageParameters: Map[String, String],
+      cause: Throwable = null,
+      context: Array[QueryContext] = Array.empty,
+      summary: String = "") = {
+    this(
+      SparkThrowableHelper.getMessage(errorClass, messageParameters, summary),
+      Option(cause),
+      Option(errorClass),
+      messageParameters,
+      context
+    )
+  }
+
+  override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava
+
+  override def getCondition: String = errorClass.orNull
+  override def getQueryContext: Array[QueryContext] = context
+}
 
 /**
  * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 66e204fee44b..3cce99d1db36 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -636,7 +636,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     protected def handlePythonException(): PythonException = {
       // Signals that an exception has been thrown in python
       val msg = PythonWorkerUtils.readUTF(stream)
-      new PythonException(msg, writer.exception.orNull)
+      new PythonException(
+        errorClass = "PYTHON_EXCEPTION",
+        messageParameters = Map("msg" -> msg),
+        cause = writer.exception.orNull)
     }
 
     protected def handleEndOfDataSection(): Unit = {
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index a4da5ea99838..24d01e63c114 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -177,10 +177,12 @@ object StreamingForeachBatchHelper extends Logging {
             log"completed (ret: 0)")
         case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
           val msg = PythonWorkerUtils.readUTF(dataIn)
-          throw new PythonException(
+          val errorMsg =
             s"[session: ${sessionHolder.sessionId}] [userId: ${sessionHolder.userId}] " +
-              s"Found error inside foreachBatch Python process: $msg",
-            null)
+              s"Found error inside foreachBatch Python process: $msg"
+          throw new PythonException(
+            errorClass = "PYTHON_EXCEPTION",
+            messageParameters = Map("msg" -> errorMsg))
         case otherValue =>
           throw new IllegalStateException(
             s"[session: ${sessionHolder.sessionId}] [userId: ${sessionHolder.userId}] " +
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index f994ada920ec..532789a0426b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -92,10 +92,11 @@ class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder
             log"completed (ret: 0)")
         case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
           val msg = PythonWorkerUtils.readUTF(dataIn)
+          val errorMsg = s"Found error inside Streaming query listener Python " +
+            s"process for function $functionName: $msg"
           throw new PythonException(
-            s"Found error inside Streaming query listener Python " +
-              s"process for function $functionName: $msg",
-            null)
+            errorClass = "PYTHON_EXCEPTION",
+            messageParameters = Map("msg" -> errorMsg))
         case otherValue =>
           throw new IllegalStateException(
             s"Unexpected return value $otherValue from the " +
From fdd8da613809656a65b3bfb714097b1828ecea43 Mon Sep 17 00:00:00 2001
From: Tian Gao
Date: Tue, 18 Nov 2025 14:37:14 -0800
Subject: [PATCH 2/2] Fix how Python side handles PythonException

---
 .../apache/spark/api/python/PythonRDD.scala  |  6 ++--
 python/pyspark/errors/exceptions/captured.py | 28 ++++++++-----------
 2 files changed, 14 insertions(+), 20 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 029d9fad9d08..45bf30675148 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -128,11 +128,11 @@ private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
 
 /** Thrown for exceptions in user Python code. */
 private[spark] class PythonException(
     msg: String,
-    cause: Option[Throwable],
+    cause: Throwable,
     errorClass: Option[String],
     messageParameters: Map[String, String],
     context: Array[QueryContext])
-  extends RuntimeException(msg, cause.orNull) with SparkThrowable {
+  extends RuntimeException(msg, cause) with SparkThrowable {
 
   def this(
@@ -142,7 +142,7 @@ private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
       summary: String = "") = {
     this(
       SparkThrowableHelper.getMessage(errorClass, messageParameters, summary),
-      Option(cause),
+      cause,
       Option(errorClass),
       messageParameters,
       context
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 0f76e3b5f6a0..98dc747c764a 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -234,26 +234,20 @@ def _convert_exception(e: "Py4JJavaError") -> CapturedException:
         return SparkUpgradeException(origin=e)
     elif is_instance_of(gw, e, "org.apache.spark.SparkNoSuchElementException"):
         return SparkNoSuchElementException(origin=e)
-
-    c: "Py4JJavaError" = e.getCause()
-    stacktrace: str = getattr(jvm, "org.apache.spark.util.Utils").exceptionString(e)
-    if c is not None and (
-        is_instance_of(gw, c, "org.apache.spark.api.python.PythonException")
+    elif is_instance_of(gw, e, "org.apache.spark.api.python.PythonException"):
         # To make sure this only catches Python UDFs.
-        and any(
+        stacktrace = getattr(jvm, "org.apache.spark.util.Utils").exceptionString(e)
+        if any(
             map(
-                lambda v: "org.apache.spark.sql.execution.python" in v.toString(), c.getStackTrace()
+                lambda v: "org.apache.spark.sql.execution.python" in v.toString(), e.getStackTrace()
             )
-        )
-    ):
-        msg = (
-            "\n  An exception was thrown from the Python worker. "
-            "Please see the stack trace below.\n%s" % c.getMessage()
-        )
-        return PythonException(msg, stacktrace)
-
-    return UnknownException(desc=e.toString(), stackTrace=stacktrace, cause=c)
-
+        ):
+            msg = (
+                "\n  An exception was thrown from the Python worker. "
+                "Please see the stack trace below.\n%s" % e.getMessage()
+            )
+            return PythonException(msg, stacktrace)
+    return UnknownException(desc=e.toString(), stackTrace=getattr(jvm, "org.apache.spark.util.Utils").exceptionString(e), cause=e.getCause())
 
 def capture_sql_exception(f: Callable[..., Any]) -> Callable[..., Any]:
     def deco(*a: Any, **kw: Any) -> Any:
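
Note on the resulting API (editor's sketch, not part of the patches): after PATCH 1/2,
a Python worker failure surfaces as a SparkThrowable carrying an error condition and
SQLSTATE rather than a bare RuntimeException message. A minimal illustration of the
intended surface, assuming code compiled where the private[spark] constructor is
visible (e.g. a test under org.apache.spark.api.python); the traceback text is made up:

  // Mirrors the call site added to PythonRunner.handlePythonException().
  val workerTraceback = "ZeroDivisionError: division by zero"  // hypothetical example
  val exc = new PythonException(
    errorClass = "PYTHON_EXCEPTION",
    messageParameters = Map("msg" -> workerTraceback))

  assert(exc.getCondition == "PYTHON_EXCEPTION")
  assert(exc.getSqlState == "38000")  // resolved from error-conditions.json
  // exc.getMessage is rendered by SparkThrowableHelper from the
  // "Exception raised from Python worker: <msg>" template.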