@@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.read.streaming.SparkDataStream
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2ScanRelation
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeq
import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqBase
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.execution.streaming.runtime.{MicroBatchExecution, StreamExecution, StreamingExecutionRelation}
import org.apache.spark.sql.execution.streaming.runtime.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED}
@@ -854,7 +854,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with
true
},
AssertOnQuery { q =>
val latestOffset: Option[(Long, OffsetSeq)] = q.offsetLog.getLatest()
val latestOffset: Option[(Long, OffsetSeqBase)] = q.offsetLog.getLatest()
latestOffset.exists { offset =>
!offset._2.offsets.exists(_.exists(_.json == "{}"))
}
@@ -126,7 +126,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
val offsetLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).offsetLog
offsetLog.get(batchId) match {
case Some(value) =>
val metadata = value.metadata.getOrElse(
val metadata = value.metadataOpt.getOrElse(
throw StateDataSourceErrors.offsetMetadataLogUnavailable(batchId, checkpointLocation)
)

@@ -81,7 +81,7 @@ class AsyncOffsetSeqLog(
* the async write of the batch is completed. Future may also be completed exceptionally
* to indicate some write error.
*/
def addAsync(batchId: Long, metadata: OffsetSeq): CompletableFuture[(Long, Boolean)] = {
def addAsync(batchId: Long, metadata: OffsetSeqBase): CompletableFuture[(Long, Boolean)] = {
require(metadata != null, "'null' metadata cannot be written to a metadata log")

def issueAsyncWrite(batchId: Long): CompletableFuture[Long] = {
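The scaladoc above notes that the future returned by addAsync may complete exceptionally to signal a write error. Below is a minimal, standalone sketch of consuming such a future, using only java.util.concurrent; the recovery shown (printing and substituting a sentinel value) is illustrative and not how Spark's callers actually react to a failed offset-log write.

import java.util.concurrent.CompletableFuture

// Sketch: a future of (batchId, persistedToDurableStorage) that fails, handled via exceptionally.
object AsyncWriteErrorSketch {
  def main(args: Array[String]): Unit = {
    val failed = new CompletableFuture[(Long, Boolean)]()
    failed.completeExceptionally(new RuntimeException("offset log write failed"))

    failed
      .exceptionally((t: Throwable) => {
        // Substitute a sentinel result; a real caller would more likely fail the streaming query.
        println(s"async offset write failed: ${t.getMessage}")
        (-1L, false)
      })
      .thenAccept((tuple: (Long, Boolean)) => println(s"completed with $tuple"))
  }
}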
@@ -35,30 +35,67 @@ import org.apache.spark.sql.execution.streaming.runtime.{MultipleWatermarkPolicy,
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._

trait OffsetSeqBase {
def offsets: Seq[Option[OffsetV2]]

/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
* [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
* vector clock that must progress linearly forward.
*/
case class OffsetSeq(offsets: Seq[Option[OffsetV2]], metadata: Option[OffsetSeqMetadata] = None) {
def metadataOpt: Option[OffsetSeqMetadata]

override def toString: String = this match {
case offsetMap: OffsetMap =>
offsetMap.offsetsMap.map { case (sourceId, offsetOpt) =>
s"$sourceId: ${offsetOpt.map(_.json).getOrElse("-")}"
}.mkString("{", ", ", "}")
case _ =>
offsets.map(_.map(_.json).getOrElse("-")).mkString("[", ", ", "]")
}

/**
* Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of
* sources.
* Unpacks an offset into [[StreamProgress]] by associating each offset with the
* ordered list of sources.
*
* This method is typically used to associate a serialized offset with actual sources (which
* cannot be serialized).
* This method is typically used to associate a serialized offset with actual
* sources (which cannot be serialized).
*/
def toStreamProgress(sources: Seq[SparkDataStream]): StreamProgress = {
assert(!this.isInstanceOf[OffsetMap], "toStreamProgress(sources) must not be called on OffsetMap; use the source ID map variant")
assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " +
s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " +
s"Cannot continue.")
s"checkpoint offsets and now there are [${sources.size}] sources requested by " +
s"the query. Cannot continue.")
new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) }
}

override def toString: String =
offsets.map(_.map(_.json).getOrElse("-")).mkString("[", ", ", "]")
/**
* Converts OffsetMap to StreamProgress using source ID mapping.
* This method is specific to OffsetMap and requires a mapping from sourceId to SparkDataStream.
*/
def toStreamProgress(
sources: Seq[SparkDataStream],
sourceIdToSourceMap: Map[String, SparkDataStream]): StreamProgress = {
this match {
case offsetMap: OffsetMap =>
val streamProgressEntries = for {
(sourceId, offsetOpt) <- offsetMap.offsetsMap
offset <- offsetOpt
source <- sourceIdToSourceMap.get(sourceId)
} yield source -> offset
new StreamProgress ++ streamProgressEntries
case _ =>
// Fallback to original method for backward compatibility
toStreamProgress(sources)
}
}
}

/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
* [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
* vector clock that must progress linearly forward.
*/
case class OffsetSeq(
offsets: Seq[Option[OffsetV2]],
metadata: Option[OffsetSeqMetadata] = None) extends OffsetSeqBase {

override def metadataOpt: Option[OffsetSeqMetadata] = metadata
}

object OffsetSeq {
@@ -79,6 +116,23 @@ object OffsetSeq {
}


/**
* A map-based collection of offsets, used to track the progress of processing data from one or more
* streaming sources. Each source is identified by a string key (initially sourceId.toString()).
* This replaces the sequence-based approach with a more flexible map-based approach to support
* named source identities.
*/
case class OffsetMap(
offsetsMap: Map[String, Option[OffsetV2]],
metadataOpt: Option[OffsetSeqMetadata] = None) extends OffsetSeqBase {

// OffsetMap does not support sequence-based access
override def offsets: Seq[Option[OffsetV2]] = {
throw new UnsupportedOperationException(
"OffsetMap does not support sequence-based offsets access. Use offsetsMap directly.")
}
}

/**
* Contains metadata associated with a [[OffsetSeq]]. This information is
* persisted to the offset log in the checkpoint location via the [[OffsetSeq]] metadata field.
@@ -97,7 +151,8 @@ object OffsetSeq {
case class OffsetSeqMetadata(
batchWatermarkMs: Long = 0,
batchTimestampMs: Long = 0,
conf: Map[String, String] = Map.empty) {
conf: Map[String, String] = Map.empty,
version: Int = 1) {
def json: String = Serialization.write(this)(OffsetSeqMetadata.format)
}

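To make the new map-based toStreamProgress concrete, here is a standalone sketch of the same for-comprehension using plain String stand-ins for SparkDataStream and for the offset JSON; it is illustrative only, not code from this patch. A source whose id is absent from the map, or whose offset is None, simply contributes nothing to the resulting progress.

// Sketch of OffsetSeqBase.toStreamProgress(sources, sourceIdToSourceMap) with stand-in types.
object OffsetMapProgressSketch {
  type Source = String // stand-in for SparkDataStream
  type Offset = String // stand-in for OffsetV2 (its json form)

  def toProgress(
      offsetsMap: Map[String, Option[Offset]],
      sourceIdToSourceMap: Map[String, Source]): Map[Source, Offset] = {
    (for {
      (sourceId, offsetOpt) <- offsetsMap
      offset <- offsetOpt
      source <- sourceIdToSourceMap.get(sourceId)
    } yield source -> offset).toMap
  }

  def main(args: Array[String]): Unit = {
    val offsetsMap = Map("0" -> Some("""{"kafka":42}"""), "1" -> None)
    val sources = Map("0" -> "kafkaSource")
    // Prints Map(kafkaSource -> {"kafka":42}); source "1" is skipped because its offset is None.
    println(toProgress(offsetsMap, sources))
  }
}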
@@ -43,56 +43,109 @@ import org.apache.spark.sql.execution.streaming.runtime.SerializedOffset
* - // No offset for this source i.e., an invalid JSON string
* {2} // LongOffset 2
* ...
*
* Version 2 format (OffsetMap):
* v2 // version 2
* metadata
* 0:{0} // sourceId:offset
* 1:{3} // sourceId:offset
* ...
*/
class OffsetSeqLog(sparkSession: SparkSession, path: String)
extends HDFSMetadataLog[OffsetSeq](sparkSession, path) {
extends HDFSMetadataLog[OffsetSeqBase](sparkSession, path) {

override protected def deserialize(in: InputStream): OffsetSeq = {
override protected def deserialize(in: InputStream): OffsetSeqBase = {
// called inside a try-finally where the underlying stream is closed in the caller
def parseOffset(value: String): OffsetV2 = value match {
case OffsetSeqLog.SERIALIZED_VOID_OFFSET => null
case json => SerializedOffset(json)
}
val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
if (!lines.hasNext) {
throw new IllegalStateException("Incomplete log file")
}

validateVersion(lines.next(), OffsetSeqLog.VERSION)
val versionStr = lines.next()
val versionInt = validateVersion(versionStr, OffsetSeqLog.MAX_VERSION)

// read metadata
val metadata = lines.next().trim match {
case "" => None
case md => Some(md)
}
import org.apache.spark.util.ArrayImplicits._
OffsetSeq.fill(metadata, lines.map(parseOffset).toArray.toImmutableArraySeq: _*)
if (versionInt == OffsetSeqLog.VERSION_2) {
// deserialize the remaining lines into the offset map
val remainingLines = lines.toArray
// New OffsetMap format: sourceId:offset
val offsetsMap = remainingLines.map { line =>
val colonIndex = line.indexOf(':')
if (colonIndex == -1) {
throw new IllegalStateException(s"Invalid OffsetMap format: $line")
}
val sourceId = line.substring(0, colonIndex)
val offsetStr = line.substring(colonIndex + 1)
val offset = if (offsetStr == OffsetSeqLog.SERIALIZED_VOID_OFFSET) {
None
} else {
Some(OffsetSeqLog.parseOffset(offsetStr))
}
sourceId -> offset
}.toMap
OffsetMap(offsetsMap, metadata.map(OffsetSeqMetadata.apply))
} else {
OffsetSeq.fill(metadata,
lines.map(OffsetSeqLog.parseOffset).toArray.toImmutableArraySeq: _*)
}
}

override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = {
override protected def serialize(offsetSeq: OffsetSeqBase, out: OutputStream): Unit = {
// called inside a try-finally where the underlying stream is closed in the caller
out.write(("v" + OffsetSeqLog.VERSION).getBytes(UTF_8))
out.write(("v" + offsetSeq.metadataOpt.map(_.version).getOrElse(OffsetSeqLog.VERSION_1))
.getBytes(UTF_8))

// write metadata
out.write('\n')
out.write(offsetSeq.metadata.map(_.json).getOrElse("").getBytes(UTF_8))
out.write(offsetSeq.metadataOpt.map(_.json).getOrElse("").getBytes(UTF_8))

// write offsets, one per line
offsetSeq.offsets.map(_.map(_.json)).foreach { offset =>
out.write('\n')
offset match {
case Some(json: String) => out.write(json.getBytes(UTF_8))
case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8))
}
offsetSeq match {
case offsetMap: OffsetMap =>
// For OffsetMap, write sourceId:offset pairs, one per line
offsetMap.offsetsMap.foreach { case (sourceId, offsetOpt) =>
out.write('\n')
out.write(sourceId.getBytes(UTF_8))
out.write(':')
offsetOpt match {
case Some(offset) => out.write(offset.json.getBytes(UTF_8))
case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8))
}
}
case _ =>
// Original sequence-based serialization
offsetSeq.offsets.map(_.map(_.json)).foreach { offset =>
out.write('\n')
offset match {
case Some(json: String) => out.write(json.getBytes(UTF_8))
case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8))
}
}
}
}

def offsetSeqMetadataForBatchId(batchId: Long): Option[OffsetSeqMetadata] = {
if (batchId < 0) None else get(batchId).flatMap(_.metadata)
if (batchId < 0) {
None
} else {
get(batchId).flatMap(_.metadataOpt)
}
}
}

object OffsetSeqLog {
private[streaming] val VERSION = 1
private val SERIALIZED_VOID_OFFSET = "-"
private[streaming] val VERSION_1 = 1
private[streaming] val VERSION_2 = 2
private[streaming] val VERSION = VERSION_1 // Default version for backward compatibility
private[streaming] val MAX_VERSION = VERSION_2
private[streaming] val SERIALIZED_VOID_OFFSET = "-"

private[checkpointing] def parseOffset(value: String): OffsetV2 = value match {
case SERIALIZED_VOID_OFFSET => null
case json => SerializedOffset(json)
}
}
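The following is a self-contained sketch of reading the version-2 layout documented at the top of this file: a "v2" line, a metadata line, then one sourceId:offset line per source. It splits each line at the first colon so offset JSON containing ':' survives, mirroring the deserialize branch above; the types are plain strings rather than Spark's classes, and the sample metadata JSON is illustrative only.

// Sketch: parse a v2 offset log body into (metadata, sourceId -> offset) with plain strings.
object OffsetLogV2Sketch {
  private val SerializedVoidOffset = "-"

  def parseV2(lines: Iterator[String]): (Option[String], Map[String, Option[String]]) = {
    require(lines.hasNext, "Incomplete log file")
    val version = lines.next()
    require(version == "v2", s"unexpected version line: $version")

    // The metadata line may be empty.
    val metadata = lines.next().trim match {
      case "" => None
      case md => Some(md)
    }

    // Split each remaining line at the FIRST colon so offset JSON containing ':' is preserved.
    val offsetsMap = lines.map { line =>
      val colonIndex = line.indexOf(':')
      require(colonIndex != -1, s"Invalid OffsetMap format: $line")
      val offsetStr = line.substring(colonIndex + 1)
      line.substring(0, colonIndex) ->
        (if (offsetStr == SerializedVoidOffset) None else Some(offsetStr))
    }.toMap
    (metadata, offsetsMap)
  }

  def main(args: Array[String]): Unit = {
    val file =
      """v2
        |{"batchWatermarkMs":0,"batchTimestampMs":0,"conf":{},"version":2}
        |0:{"topic-0":{"0":23}}
        |1:-""".stripMargin
    // Prints (Some(...metadata json...), Map(0 -> Some({"topic-0":{"0":23}}), 1 -> None))
    println(parseV2(file.linesIterator))
  }
}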
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability}
import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, PartitionOffset, ReadLimit}
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, PartitionOffset, ReadLimit, SparkDataStream}
import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.SQLExecution
@@ -65,6 +65,8 @@ class ContinuousExecution(

@volatile protected var sources: Seq[ContinuousStream] = Seq()

def sourceToIdMap: Map[SparkDataStream, String] = Map.empty

// For use only in test harnesses.
private[sql] var currentEpochCoordinatorId: String = _

@@ -186,7 +188,7 @@ class ContinuousExecution(
val nextOffsets = offsetLog.get(latestEpochId).getOrElse {
throw new IllegalStateException(
s"Batch $latestEpochId was committed without end epoch offsets!")
}
}.asInstanceOf[OffsetSeq]
committedOffsets = nextOffsets.toStreamProgress(sources)
execCtx.batchId = latestEpochId + 1

@@ -210,7 +212,8 @@ class ContinuousExecution(
val execCtx = latestExecutionContext

if (execCtx.batchId > 0) {
AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources(Some(offsets), sources)
AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources(
Some(offsets), sources, Map.empty[String, SparkDataStream])
}

val withNewSources: LogicalPlan = logicalPlan transform {
@@ -20,18 +20,19 @@ package org.apache.spark.sql.execution.streaming.runtime
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, SparkDataStream}
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeq
import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqBase

/**
* This feeds "latest seen offset" to the sources that implement AcceptsLatestSeenOffset.
*/
object AcceptsLatestSeenOffsetHandler {
def setLatestSeenOffsetOnSources(
offsets: Option[OffsetSeq],
sources: Seq[SparkDataStream]): Unit = {
offsets: Option[OffsetSeqBase],
sources: Seq[SparkDataStream],
sourceIdMap: Map[String, SparkDataStream]): Unit = {
assertNoAcceptsLatestSeenOffsetWithDataSourceV1(sources)

offsets.map(_.toStreamProgress(sources)) match {
offsets.map(_.toStreamProgress(sources, sourceIdMap)) match {
case Some(streamProgress) =>
streamProgress.foreach {
case (src: AcceptsLatestSeenOffset, offset) =>
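For context on the new sourceIdMap parameter: the OffsetMap scaladoc earlier in this diff says keys are initially sourceId.toString(). Here is a hedged sketch of one way such a String-keyed map could be derived from an ordered source list; the Source type is a stand-in for SparkDataStream, and how MicroBatchExecution actually builds its sourceIdMap is not shown in this diff.

// Sketch: assign positional string ids ("0", "1", ...) to an ordered list of sources.
object SourceIdMapSketch {
  type Source = String // stand-in for SparkDataStream

  def sourceIdMap(sources: Seq[Source]): Map[String, Source] =
    sources.zipWithIndex.map { case (source, id) => id.toString -> source }.toMap

  def main(args: Array[String]): Unit = {
    // Prints Map(0 -> kafkaSource, 1 -> rateSource)
    println(sourceIdMap(Seq("kafkaSource", "rateSource")))
  }
}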
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.streaming.WriteToStream
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, OneTimeTrigger, ProcessingTimeTrigger}
import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog, CommitMetadata, OffsetSeq}
import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog, CommitMetadata, OffsetSeqBase}
import org.apache.spark.sql.execution.streaming.operators.stateful.StateStoreWriter
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.util.{Clock, ThreadUtils}
@@ -49,7 +49,7 @@ class AsyncProgressTrackingMicroBatchExecution(
// Offsets that are ready to be committed by the source.
// This is needed so that we can call source commit in the same thread as micro-batch execution
// to be thread safe
private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeqBase]()

// to cache the batch id of the last batch written to storage
private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
@@ -104,7 +104,7 @@ class AsyncProgressTrackingMicroBatchExecution(
// perform quick validation to fail faster
validateAndGetTrigger()

override def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeq] = {
override def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeqBase] = {
/* Initialize committed offsets to a committed batch, which at this point
* is the second latest batch id in the offset log.
* The offset log may not be contiguous */
@@ -137,14 +137,15 @@
// Because we are using a thread pool with only one thread, async writes to the offset log
// are still written in a serial / in order fashion
offsetLog
.addAsync(execCtx.batchId, execCtx.endOffsets.toOffsetSeq(sources, execCtx.offsetSeqMetadata))
.thenAccept(tuple => {
val (batchId, persistedToDurableStorage) = tuple
.addAsync(execCtx.batchId,
execCtx.endOffsets.toOffsets(sources, sourceIdMap, execCtx.offsetSeqMetadata))
.thenAccept((tuple: (Long, Boolean)) => {
val (batchId: Long, persistedToDurableStorage: Boolean) = tuple
if (persistedToDurableStorage) {
// batch id cache not initialized
if (lastBatchPersistedToDurableStorage.get == -1) {
lastBatchPersistedToDurableStorage.set(
offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1L))
}

if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
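The thenAccept callback above destructures the (batchId, persistedToDurableStorage) pair produced by the async offset-log write. A minimal, self-contained sketch of that consumption pattern follows, with the future completed inline rather than by an offset log; the lastBatchPersistedToDurableStorage bookkeeping is reduced to a println.

import java.util.concurrent.CompletableFuture

// Sketch: consume a (batchId, persistedToDurableStorage) future the way the code above does.
object AsyncWriteCallbackSketch {
  def main(args: Array[String]): Unit = {
    val writeFuture: CompletableFuture[(Long, Boolean)] =
      CompletableFuture.completedFuture((42L, true))

    writeFuture.thenAccept((tuple: (Long, Boolean)) => {
      val (batchId, persistedToDurableStorage) = tuple
      if (persistedToDurableStorage) {
        // In the real code this is where the cached id of the last durably persisted batch is updated.
        println(s"batch $batchId persisted durably")
      }
    })
  }
}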