@@ -25,7 +25,8 @@ import scala.jdk.CollectionConverters._
import scala.math.BigDecimal.RoundingMode
import scala.util.control.NonFatal

import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.databind.{DeserializationContext, DeserializationFeature, JsonDeserializer, JsonNode, ObjectMapper}
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}
import org.json4s._
@@ -191,7 +192,19 @@ class StreamingQueryProgress private[spark] (
("stateOperators" -> JArray(stateOperators.map(_.jsonValue).toList)) ~
("sources" -> JArray(sources.map(_.jsonValue).toList)) ~
("sink" -> sink.jsonValue) ~
("observedMetrics" -> safeMapToJValue[Row](observedMetrics, (_, row) => row.jsonValue))
("observedMetrics" -> {
// TODO: SPARK-54391
// In Spark Connect, observedMetrics is serialized but not deserialized properly when sent
// back to the client, so the schema is null and calling row.jsonValue throws an exception,
// which we catch here, returning JNothing. Row.jsonValue is a one-way conversion: there is
// no reverse method to turn the JSON back into a Row.
try {
safeMapToJValue[Row](observedMetrics, (_, row) => row.jsonValue)
} catch {
case NonFatal(_) => JNothing
}
})
}
}
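
The guard above degrades gracefully instead of failing the whole progress JSON when one row cannot be rendered. A minimal, self-contained sketch of the same pattern, assuming only json4s on the classpath (FakeRow and toJValueOrNothing are hypothetical stand-ins for Row and safeMapToJValue, not part of this patch):

import org.json4s._
import scala.util.control.NonFatal

object ObservedMetricsSketch {
  // Hypothetical stand-in for Row: jsonValue throws when the schema is missing,
  // mimicking a Row that arrived over Spark Connect with a null schema.
  final case class FakeRow(schema: Option[String], value: String) {
    def jsonValue: JValue = schema match {
      case Some(_) => JString(value)
      case None    => throw new NullPointerException("schema is null")
    }
  }

  // Render each entry to JSON, but fall back to JNothing if any row cannot be
  // rendered, so the surrounding progress JSON is still produced.
  def toJValueOrNothing(rows: Map[String, FakeRow]): JValue =
    try JObject(rows.toList.map { case (k, r) => k -> r.jsonValue })
    catch { case NonFatal(_) => JNothing }
}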

@@ -210,6 +223,19 @@ private[spark] object StreamingQueryProgress {
mapper.readValue[StreamingQueryProgress](json)
}

// SPARK-54390: Custom deserializer that converts JSON objects to strings for offset fields
private class ObjectToStringDeserializer extends JsonDeserializer[String] {
override def deserialize(parser: JsonParser, context: DeserializationContext): String = {
val node: JsonNode = parser.readValueAsTree()
if (node.isTextual) {
node.asText()
} else {
// Convert JSON object/array to string representation
node.toString
}
}
}
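
To see what the custom deserializer buys us: a Kafka-style offset is serialized as a JSON object, which Jackson cannot bind to a String field on its own. A hedged sketch of the wiring, assuming jackson-module-scala and that the deserializer class is visible in scope (OffsetHolder and OffsetDemo are hypothetical, not part of this patch; the real annotation sites are the SourceProgress constructor parameters below):

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import com.fasterxml.jackson.module.scala.DefaultScalaModule

// Hypothetical holder mirroring SourceProgress's offset fields.
case class OffsetHolder(
    @JsonDeserialize(using = classOf[ObjectToStringDeserializer]) startOffset: String)

object OffsetDemo extends App {
  val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
  // A Kafka-style offset serialized as a JSON object rather than a JSON string.
  val json = """{"startOffset": {"topic-a": {"0": 123}}}"""
  val holder = mapper.readValue(json, classOf[OffsetHolder])
  println(holder.startOffset) // prints the object as a string: {"topic-a":{"0":123}}
}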

/**
* Information about progress made for a source in the execution of a [[StreamingQuery]] during a
* trigger. See [[StreamingQueryProgress]] for more information.
@@ -233,12 +259,19 @@ private[spark] object StreamingQueryProgress {
@Evolving
class SourceProgress protected[spark] (
val description: String,
// SPARK-54390: Use a custom deserializer to convert the JSON object to a string.
@JsonDeserialize(using = classOf[ObjectToStringDeserializer])
val startOffset: String,
@JsonDeserialize(using = classOf[ObjectToStringDeserializer])
val endOffset: String,
@JsonDeserialize(using = classOf[ObjectToStringDeserializer])
val latestOffset: String,
val numInputRows: Long,
val inputRowsPerSecond: Double,
val processedRowsPerSecond: Double,
// NaN marks a value that was not set during deserialization, and NaN-valued fields are
// omitted from the JSON output (a sketch of this behavior follows this file's diff).
// In Spark Connect the default must therefore be Double.NaN rather than 0.0.
val inputRowsPerSecond: Double = Double.NaN,
val processedRowsPerSecond: Double = Double.NaN,
val metrics: ju.Map[String, String] = Map[String, String]().asJava)
extends Serializable {

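A short sketch of the NaN-omission behavior referenced in the comment above, assuming json4s (the helper name safeDoubleToJValue is modeled on the one this file uses; treat the exact name as an assumption):

import org.json4s._

// NaN (or infinity) means "not set": render it as JNothing so json4s drops the
// field entirely from the serialized progress instead of emitting a spurious 0.0.
def safeDoubleToJValue(value: Double): JValue =
  if (value.isNaN || value.isInfinity) JNothing else JDouble(value)

With the Double.NaN defaults above, a SourceProgress deserialized from JSON that lacked inputRowsPerSecond will omit that field again when re-serialized, keeping the round trip stable.
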
@@ -21,6 +21,7 @@ import java.util.UUID

import scala.collection.mutable

import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.scalactic.{Equality, TolerantNumerics}
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.PatienceConfiguration.Timeout
@@ -286,6 +287,12 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
)
}

private def removeFieldFromJson(jsonString: String, fieldName: String): String = {
val jv = parse(jsonString, useBigDecimalForDouble = true)
val removed = jv.removeField { case (name, _) => name == fieldName }
compact(render(removed))
}
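
For reference, a quick usage sketch of this helper (the input JSON is hypothetical; json4s's removeField matches fields anywhere in the tree):

val json = """{"id":"q1","observedMetrics":{"m":{"count":1}},"batchId":7}"""
removeFieldFromJson(json, "observedMetrics")
// roughly: {"id":"q1","batchId":7}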

test("QueryProgressEvent serialization") {
def testSerialization(event: QueryProgressEvent): Unit = {
import scala.jdk.CollectionConverters._
@@ -294,9 +301,24 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
assert(newEvent.progress.json === event.progress.json) // json as a proxy for equality
assert(newEvent.progress.durationMs.asScala === event.progress.durationMs.asScala)
assert(newEvent.progress.eventTime.asScala === event.progress.eventTime.asScala)

// Verify that we can reconstruct the event from its JSON string. This matters for Spark
// Connect and the StreamingQueryListenerBus: fromJson is the method used to deserialize
// the event in StreamingQueryListenerBus.queryEventHandler.
val eventFromNewEvent = QueryProgressEvent.fromJson(newEvent.json)
// TODO: Remove after SC-206585 is fixed
// The observedMetrics field is removed because it is not serialized properly when going
// through the listener bus, so this test verifies that everything except the
// observedMetrics field is equal in the JSON string.
val eventWithoutObservedMetrics = removeFieldFromJson(event.progress.json, "observedMetrics")
assert(eventFromNewEvent.progress.json === eventWithoutObservedMetrics)
}
testSerialization(new QueryProgressEvent(StreamingQueryStatusAndProgressSuite.testProgress1))
testSerialization(new QueryProgressEvent(StreamingQueryStatusAndProgressSuite.testProgress2))
testSerialization(new QueryProgressEvent(StreamingQueryStatusAndProgressSuite.testProgress3))
testSerialization(new QueryProgressEvent(StreamingQueryStatusAndProgressSuite.testProgress4))
testSerialization(new QueryProgressEvent(StreamingQueryStatusAndProgressSuite.testProgress5))
testSerialization(new QueryProgressEvent(StreamingQueryStatusAndProgressSuite.testProgress6))
}

test("QueryTerminatedEvent serialization") {