Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -5099,6 +5099,12 @@
],
"sqlState" : "38000"
},
"PYTHON_EXCEPTION" : {
"message": [
"Exception raised from Python worker: <msg>"
],
"sqlState" : "38000"
},
"PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR" : {
"message" : [
"Failed when Python streaming data source perform <action>: <msg>"
Expand Down
30 changes: 28 additions & 2 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,34 @@ private[spark] case class SimplePythonFunction(
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

/** Thrown for exceptions in user Python code. */
private[spark] class PythonException(msg: String, cause: Throwable)
extends RuntimeException(msg, cause)
private[spark] class PythonException(
msg: String,
cause: Throwable,
errorClass: Option[String],
messageParameters: Map[String, String],
context: Array[QueryContext])
extends RuntimeException(msg, cause) with SparkThrowable {

def this(
errorClass: String,
messageParameters: Map[String, String],
cause: Throwable = null,
context: Array[QueryContext] = Array.empty,
summary: String = "") = {
this(
SparkThrowableHelper.getMessage(errorClass, messageParameters, summary),
cause,
Option(errorClass),
messageParameters,
context
)
}

override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava

override def getCondition: String = errorClass.orNull
override def getQueryContext: Array[QueryContext] = context
}

/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected def handlePythonException(): PythonException = {
// Signals that an exception has been thrown in python
val msg = PythonWorkerUtils.readUTF(stream)
new PythonException(msg, writer.exception.orNull)
new PythonException(
errorClass = "PYTHON_EXCEPTION",
messageParameters = Map("msg" -> msg),
cause = writer.exception.orNull)
}

protected def handleEndOfDataSection(): Unit = {
Expand Down
28 changes: 11 additions & 17 deletions python/pyspark/errors/exceptions/captured.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,26 +234,20 @@ def _convert_exception(e: "Py4JJavaError") -> CapturedException:
return SparkUpgradeException(origin=e)
elif is_instance_of(gw, e, "org.apache.spark.SparkNoSuchElementException"):
return SparkNoSuchElementException(origin=e)

c: "Py4JJavaError" = e.getCause()
stacktrace: str = getattr(jvm, "org.apache.spark.util.Utils").exceptionString(e)
if c is not None and (
is_instance_of(gw, c, "org.apache.spark.api.python.PythonException")
elif is_instance_of(gw, e, "org.apache.spark.api.python.PythonException"):
# To make sure this only catches Python UDFs.
and any(
stacktrace = getattr(jvm, "org.apache.spark.util.Utils").exceptionString(e)
if any(
map(
lambda v: "org.apache.spark.sql.execution.python" in v.toString(), c.getStackTrace()
lambda v: "org.apache.spark.sql.execution.python" in v.toString(), e.getStackTrace()
)
)
):
msg = (
"\n An exception was thrown from the Python worker. "
"Please see the stack trace below.\n%s" % c.getMessage()
)
return PythonException(msg, stacktrace)

return UnknownException(desc=e.toString(), stackTrace=stacktrace, cause=c)

):
msg = (
"\n An exception was thrown from the Python worker. "
"Please see the stack trace below.\n%s" % e.getMessage()
)
return PythonException(msg, stacktrace)
return UnknownException(desc=e.toString(), stackTrace=getattr(jvm, "org.apache.spark.util.Utils").exceptionString(e), cause=e.getCause())

def capture_sql_exception(f: Callable[..., Any]) -> Callable[..., Any]:
def deco(*a: Any, **kw: Any) -> Any:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,12 @@ object StreamingForeachBatchHelper extends Logging {
log"completed (ret: 0)")
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
val msg = PythonWorkerUtils.readUTF(dataIn)
throw new PythonException(
val errorMsg =
s"[session: ${sessionHolder.sessionId}] [userId: ${sessionHolder.userId}] " +
s"Found error inside foreachBatch Python process: $msg",
null)
s"Found error inside foreachBatch Python process: $msg"
throw new PythonException(
errorClass = "PYTHON_EXCEPTION",
messageParameters = Map("msg" -> errorMsg))
case otherValue =>
throw new IllegalStateException(
s"[session: ${sessionHolder.sessionId}] [userId: ${sessionHolder.userId}] " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder
log"completed (ret: 0)")
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
val msg = PythonWorkerUtils.readUTF(dataIn)
val errorMsg = s"Found error inside Streaming query listener Python " +
s"process for function $functionName: $msg"
throw new PythonException(
s"Found error inside Streaming query listener Python " +
s"process for function $functionName: $msg",
null)
errorClass = "PYTHON_EXCEPTION",
messageParameters = Map("msg" -> errorMsg))
case otherValue =>
throw new IllegalStateException(
s"Unexpected return value $otherValue from the " +
Expand Down