
Commit d234a97

initial commit
reduce duplicate code; all tests pass

1 parent: b4fe10f

6 files changed: +273 additions, -98 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 54 additions & 28 deletions
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpoint
 import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors
 import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.streaming.TimeMode
 import org.apache.spark.sql.types.StructType
@@ -75,8 +76,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
         sourceOptions.resolvedCpLocation,
         stateConf.providerClass)
     }
-    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
-      sourceOptions)
+    val stateStoreReaderInfo = getStoreMetadataAndRunChecks(sourceOptions)

     // The key state encoder spec should be available for all operators except stream-stream joins
     val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
@@ -98,9 +98,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, options))

-    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
-      sourceOptions)
+    val stateStoreReaderInfo = getStoreMetadataAndRunChecks(sourceOptions)
     val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
+    val allCFReaderInfo = stateStoreReaderInfo.allColumnFamiliesReaderInfo

     val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
     try {
@@ -120,10 +120,13 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
         (resultSchema.keySchema, resultSchema.valueSchema)
       }

+      val stateVarInfo = stateStoreReaderInfo.transformWithStateVariableInfoOpt
       SchemaUtil.getSourceSchema(sourceOptions, keySchema,
         valueSchema,
-        stateStoreReaderInfo.transformWithStateVariableInfoOpt,
-        stateStoreReaderInfo.stateStoreColFamilySchemaOpt)
+        stateVarInfo,
+        stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
+        allCFReaderInfo.operatorName,
+        allCFReaderInfo.stateFormatVersion)
     } catch {
       case NonFatal(e) =>
         throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
@@ -132,6 +135,22 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging

   override def supportsExternalMetadata(): Boolean = false

+  /**
+   * Returns the state format version for SYMMETRIC_HASH_JOIN operators. This currently
+   * supports only join operators, because this function is used solely by
+   * PartitionKeyExtractor, and PartitionKeyExtractor needs the state format version
+   * only for join operators.
+   */
+  private def getStateFormatVersion(
+      storeMetadata: Array[StateMetadataTableEntry]): Option[Int] = {
+    if (storeMetadata.nonEmpty &&
+      storeMetadata.head.operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME) {
+      Some(session.conf.get(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION))
+    } else {
+      None
+    }
+  }
+
   /**
    * Returns true if this is a read-all-column-families request for a stream-stream join
    * that uses virtual column families (state format version 3).
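
For readers outside the Spark tree, a minimal standalone sketch of the shape of the new helper above. `MetadataEntry`, `JOIN_OP_NAME`, and `configuredJoinFormatVersion` are hypothetical stand-ins for StateMetadataTableEntry, StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME, and the SQLConf lookup; this is not the Spark source.

```scala
// Standalone sketch of getStateFormatVersion: return a version only when the
// first metadata entry describes a stream-stream join operator.
final case class MetadataEntry(operatorName: String)

object StateFormatVersionSketch {
  val JOIN_OP_NAME = "symmetricHashJoin" // hypothetical operator-name string
  def configuredJoinFormatVersion: Int = 3 // stands in for the SQLConf lookup

  // headOption.filter expresses the nonEmpty + head.operatorName check in one chain.
  def stateFormatVersion(storeMetadata: Array[MetadataEntry]): Option[Int] =
    storeMetadata.headOption
      .filter(_.operatorName == JOIN_OP_NAME)
      .map(_ => configuredJoinFormatVersion)
}
```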
@@ -260,8 +279,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     }
   }

-  private def getStoreMetadataAndRunChecks(sourceOptions: StateSourceOptions):
-    StateStoreReaderInfo = {
+  private def getStoreMetadataAndRunChecks(
+      sourceOptions: StateSourceOptions): StateStoreReaderInfo = {
     val storeMetadata = StateDataSource.getStateStoreMetadata(sourceOptions, hadoopConf)
     if (!sourceOptions.internalOnlyReadAllColumnFamilies) {
       // Skip runStateVarChecks when reading all column families (for repartitioning) because:
@@ -296,29 +315,33 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging

         if (sourceOptions.readRegisteredTimers) {
           stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1
+        } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+          // When reading all column families (for repartitioning) for the TWS operator,
+          // choose an arbitrary state as a placeholder for the default column family,
+          // because matching stateVariableInfo and stateStoreColFamilySchemaOpt are
+          // needed to inferSchema (the partitionKey in particular) later.
+          stateVarName = operatorProperties.stateVariables.head.stateName
         }
-        // When reading all column families (for repartitioning), we collect all state variable
-        // infos instead of validating a specific stateVarName. This skips the normal validation
-        // logic because we're not reading a specific state variable - we're reading all of them.
+
         if (sourceOptions.internalOnlyReadAllColumnFamilies) {
           stateVariableInfos = operatorProperties.stateVariables
-        } else {
-          var stateVarInfoList = operatorProperties.stateVariables
-            .filter(stateVar => stateVar.stateName == stateVarName)
-          if (stateVarInfoList.isEmpty &&
-            StateStoreColumnFamilySchemaUtils.isTestingInternalColFamily(stateVarName)) {
-            // pass this dummy TWSStateVariableInfo for TWS internal column family during testing,
-            // because internalColumns are not register in operatorProperties.stateVariables,
-            // thus stateVarInfoList will be empty.
-            stateVarInfoList = List(TransformWithStateVariableInfo(
-              stateVarName, StateVariableType.ValueState, false
-            ))
-          }
-          require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " +
-            s"for state variable $stateVarName in operator ${sourceOptions.operatorId}")
-          val stateVarInfo = stateVarInfoList.head
-          transformWithStateVariableInfoOpt = Some(stateVarInfo)
         }
+        var stateVarInfoList = operatorProperties.stateVariables
+          .filter(stateVar => stateVar.stateName == stateVarName)
+        if (stateVarInfoList.isEmpty &&
+          StateStoreColumnFamilySchemaUtils.isTestingInternalColFamily(stateVarName)) {
+          // Pass this dummy TWSStateVariableInfo for the TWS internal column family during
+          // testing: internal columns are not registered in operatorProperties.stateVariables,
+          // so stateVarInfoList would otherwise be empty.
+          stateVarInfoList = List(TransformWithStateVariableInfo(
+            stateVarName, StateVariableType.ValueState, false
+          ))
+        }
+        require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " +
+          s"for state variable $stateVarName in operator ${sourceOptions.operatorId}")
+        val stateVarInfo = stateVarInfoList.head
+        transformWithStateVariableInfoOpt = Some(stateVarInfo)

         val schemaFilePaths = storeMetadataEntry.stateSchemaFilePaths
         val stateSchemaMetadata = StateSchemaMetadata.createStateSchemaMetadata(
           sourceOptions.stateCheckpointLocation.toString,
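
The key invariant in this hunk: after the optional dummy fallback, exactly one state variable must match. A reduced sketch of that resolve-or-fail flow; `StateVar` and `isTestingInternal` are simplified stand-ins for TransformWithStateVariableInfo and isTestingInternalColFamily, not the Spark API.

```scala
// Filter by name, fall back to a dummy entry for the testing-only internal
// family, then require exactly one match -- mirroring the diff's control flow.
final case class StateVar(stateName: String)

def resolveStateVar(
    vars: List[StateVar],
    name: String,
    isTestingInternal: String => Boolean): StateVar = {
  var candidates = vars.filter(_.stateName == name)
  if (candidates.isEmpty && isTestingInternal(name)) {
    candidates = List(StateVar(name)) // dummy placeholder for the internal family
  }
  require(candidates.size == 1,
    s"Failed to find unique state variable info for state variable $name")
  candidates.head
}
```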
@@ -374,13 +397,16 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
       }
     }

+    val operatorName = if (storeMetadata.nonEmpty) storeMetadata.head.operatorName else ""
+    val stateFormatVersion = getStateFormatVersion(storeMetadata)
     StateStoreReaderInfo(
       keyStateEncoderSpecOpt,
       stateStoreColFamilySchemaOpt,
       transformWithStateVariableInfoOpt,
       stateSchemaProvider,
       joinColFamilyOpt,
-      AllColumnFamiliesReaderInfo(stateStoreColFamilySchemas, stateVariableInfos)
+      AllColumnFamiliesReaderInfo(
+        stateStoreColFamilySchemas, stateVariableInfos, operatorName, stateFormatVersion)
     )
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 74 additions & 16 deletions
@@ -21,8 +21,10 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
-import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateVariableInfo, TransformWithStateVariableUtils}
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
 import org.apache.spark.sql.types.{NullType, StructField, StructType}
@@ -31,7 +33,9 @@ import org.apache.spark.util.{NextIterator, SerializableConfiguration}

 case class AllColumnFamiliesReaderInfo(
     colFamilySchemas: Set[StateStoreColFamilySchema] = Set.empty,
-    stateVariableInfos: List[TransformWithStateVariableInfo] = List.empty)
+    stateVariableInfos: List[TransformWithStateVariableInfo] = List.empty,
+    operatorName: String = "",
+    stateFormatVersion: Option[Int] = None)

 /**
  * An implementation of [[PartitionReaderFactory]] for State data source. This is used to support
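
A hedged usage sketch of the widened case class: a non-join operator can rely on the new defaults, while a stream-stream join carries the configured version alongside its operator name. The operator-name strings below are hypothetical; the diff does not show their exact values.

```scala
// Non-join operator: stateFormatVersion stays at its None default.
val nonJoinInfo = AllColumnFamiliesReaderInfo(
  operatorName = "transformWithStateExec") // hypothetical name string

// Stream-stream join: both new fields populated, version from the join conf.
val joinInfo = AllColumnFamiliesReaderInfo(
  operatorName = "symmetricHashJoin",      // hypothetical name string
  stateFormatVersion = Some(3))
```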
@@ -278,6 +282,56 @@ class StatePartitionAllColumnFamiliesReader(

   private val stateStoreColFamilySchemas = allColumnFamiliesReaderInfo.colFamilySchemas
   private val stateVariableInfos = allColumnFamiliesReaderInfo.stateVariableInfos
+  private val operatorName = allColumnFamiliesReaderInfo.operatorName
+  private val stateFormatVersion = allColumnFamiliesReaderInfo.stateFormatVersion
+
+  private def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: String): Boolean = {
+    StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName) &&
+      colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME
+  }
+
+  /**
+   * Extracts the base state variable name from internal column family names.
+   */
+  private def getBaseStateName(colFamilyName: String): String = {
+    if (StateStoreColumnFamilySchemaUtils.isTtlColFamilyName(colFamilyName)) {
+      StateStoreColumnFamilySchemaUtils.getStateNameFromTtlColFamily(colFamilyName)
+    } else if (StateStoreColumnFamilySchemaUtils.isMinExpiryIndexCFName(colFamilyName)) {
+      StateStoreColumnFamilySchemaUtils.getStateNameFromMinExpiryIndexCFName(colFamilyName)
+    } else if (StateStoreColumnFamilySchemaUtils.isCountIndexCFName(colFamilyName)) {
+      StateStoreColumnFamilySchemaUtils.getStateNameFromCountIndexCFName(colFamilyName)
+    } else if (TransformWithStateVariableUtils.isRowCounterCFName(colFamilyName)) {
+      TransformWithStateVariableUtils.getStateNameFromRowCounterCFName(colFamilyName)
+    } else {
+      colFamilyName
+    }
+  }
+
+  private def getStateVarInfo(
+      colFamilyName: String): Option[TransformWithStateVariableInfo] = {
+    if (TimerStateUtils.isTimerSecondaryIndexCF(colFamilyName)) {
+      Some(TransformWithStateVariableUtils.getTimerState(colFamilyName))
+    } else {
+      stateVariableInfos.find(_.stateName == getBaseStateName(colFamilyName))
+    }
+  }
+
+  // Create extractors for each column family; each column family may have a different key schema.
+  private lazy val partitionKeyExtractors: Map[String, StatePartitionKeyExtractor] = {
+    stateStoreColFamilySchemas
+      .filter(schema => !isDefaultColFamilyInTWS(operatorName, schema.colFamilyName))
+      .map { cfSchema =>
+        val extractor = SchemaUtil.getPartitionKeyExtractor(
+          operatorName,
+          cfSchema.keySchema,
+          partition.sourceOptions.storeName,
+          cfSchema.colFamilyName,
+          getStateVarInfo(cfSchema.colFamilyName),
+          stateFormatVersion)
+        cfSchema.colFamilyName -> extractor
+      }.toMap
+  }

   private def isListType(colFamilyName: String): Boolean = {
     SchemaUtil.checkVariableType(
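
The getBaseStateName dispatch above unwraps secondary-index column-family names back to the state variable that owns them. A standalone sketch of that shape; the `$...` prefixes are hypothetical, since the real naming scheme lives in StateStoreColumnFamilySchemaUtils and TransformWithStateVariableUtils and is not shown in this diff.

```scala
// Each recognized index-family prefix unwraps to its owning state variable;
// anything unrecognized is already a base state name.
def baseStateName(colFamilyName: String): String = colFamilyName match {
  case s if s.startsWith("$ttl_")        => s.stripPrefix("$ttl_")        // TTL index (hypothetical prefix)
  case s if s.startsWith("$min_")        => s.stripPrefix("$min_")        // min-expiry index (hypothetical)
  case s if s.startsWith("$count_")      => s.stripPrefix("$count_")      // count index (hypothetical)
  case s if s.startsWith("$rowCounter_") => s.stripPrefix("$rowCounter_") // row counter (hypothetical)
  case s                                 => s
}

assert(baseStateName("$ttl_listState") == "listState")
assert(baseStateName("valueState") == "valueState")
```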
@@ -357,21 +411,25 @@ class StatePartitionAllColumnFamiliesReader(

   override lazy val iter: Iterator[InternalRow] = {
     // Iterate all column families and concatenate results
-    stateStoreColFamilySchemas.iterator.flatMap { cfSchema =>
-      if (isListType(cfSchema.colFamilyName)) {
-        store.iterator(cfSchema.colFamilyName).flatMap(
-          pair =>
-            store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
-              value =>
-                SchemaUtil.unifyStateRowPairAsRawBytes((pair.key, value), cfSchema.colFamilyName)
-            }
-        )
-      } else {
-        store.iterator(cfSchema.colFamilyName).map { pair =>
-          SchemaUtil.unifyStateRowPairAsRawBytes(
-            (pair.key, pair.value), cfSchema.colFamilyName)
+    stateStoreColFamilySchemas.iterator
+      .filter(schema => !isDefaultColFamilyInTWS(operatorName, schema.colFamilyName))
+      .flatMap { cfSchema =>
+        val extractor = partitionKeyExtractors(cfSchema.colFamilyName)
+        if (isListType(cfSchema.colFamilyName)) {
+          store.iterator(cfSchema.colFamilyName).flatMap(
+            pair =>
+              store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
+                value =>
+                  SchemaUtil.unifyStateRowPairAsRawBytes(
+                    (pair.key, value), cfSchema.colFamilyName, extractor)
+              }
+          )
+        } else {
+          store.iterator(cfSchema.colFamilyName).map { pair =>
+            SchemaUtil.unifyStateRowPairAsRawBytes(
+              (pair.key, pair.value), cfSchema.colFamilyName, extractor)
+          }
         }
-      }
     }
   }
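
Reduced to plain collections, the new iteration reads as filter-then-flatten: skip families handled elsewhere, then flatten every family's rows into a single iterator, pairing each row with its family's extractor. A minimal sketch with stand-in names (`Family` is hypothetical; the real code works over StateStore iterators).

```scala
// Skip filtered families, then concatenate the rest, tagging each row with
// its family name (standing in for the per-family extractor lookup).
final case class Family(name: String, rows: List[String])

def allRows(families: List[Family], skip: String => Boolean): Iterator[(String, String)] =
  families.iterator
    .filterNot(f => skip(f.name))
    .flatMap(f => f.rows.iterator.map(row => f.name -> row))

// Usage: the "default" family is filtered out, the rest are concatenated in order.
val rows = allRows(
  List(Family("a", List("r1", "r2")), Family("default", List("r3"))),
  skip = _ == "default").toList
// rows == List(("a", "r1"), ("a", "r2"))
```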
