Commit 482ed57

address comment
1 parent f4cdc5f commit 482ed57

File tree

8 files changed, +89 -45 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 25 additions & 32 deletions

@@ -81,11 +81,6 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     // The key state encoder spec should be available for all operators except stream-stream joins
     val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
       stateStoreReaderInfo.keyStateEncoderSpecOpt.get
-    } else if (isReadAllColFamiliesOnJoinV3(sourceOptions)) {
-      // Create keyStateEncoderSpec here because getStoreMetadataAndRunChecks
-      // doesn't assign it in stateStoreReaderInfo.keyStateEncoderSpecOpt
-      NoPrefixKeyStateEncoderSpec(
-        stateStoreReaderInfo.allColumnFamiliesReaderInfo.colFamilySchemas.head.keySchema)
     } else {
       val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
       NoPrefixKeyStateEncoderSpec(keySchema)
@@ -119,17 +114,10 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
           sourceOptions.operatorId, RightSide, oldSchemaFilePaths)

       case JoinSideValues.none =>
-        if (isReadAllColFamiliesOnJoinV3(sourceOptions)) {
-          // readAllColumnFamiliesReader on joinV3 reads schema with StreamStreamJoinStateHelper
-          StreamStreamJoinStateHelper.readKeyValueSchema(session,
-            stateCheckpointLocation.toString,
-            sourceOptions.operatorId, LeftSide, oldSchemaFilePaths)
-        } else {
-          // we should have the schema for the state store if joinSide is none
-          require(stateStoreReaderInfo.stateStoreColFamilySchemaOpt.isDefined)
-          val resultSchema = stateStoreReaderInfo.stateStoreColFamilySchemaOpt.get
-          (resultSchema.keySchema, resultSchema.valueSchema)
-        }
+        // we should have the schema for the state store if joinSide is none
+        require(stateStoreReaderInfo.stateStoreColFamilySchemaOpt.isDefined)
+        val resultSchema = stateStoreReaderInfo.stateStoreColFamilySchemaOpt.get
+        (resultSchema.keySchema, resultSchema.valueSchema)
     }

     SchemaUtil.getSourceSchema(sourceOptions, keySchema,
@@ -148,8 +136,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
   * Returns true if this is a read-all-column-families request for a stream-stream join
   * that uses virtual column families (state format version 3).
   */
-  private def isReadAllColFamiliesOnJoinV3(sourceOptions: StateSourceOptions): Boolean = {
-    val storeMetadata = StateDataSource.getStateStoreMetadata(sourceOptions, hadoopConf)
+  private def isReadAllColFamiliesOnJoinV3(
+      sourceOptions: StateSourceOptions,
+      storeMetadata: Array[StateMetadataTableEntry]): Boolean = {
     sourceOptions.internalOnlyReadAllColumnFamilies &&
       storeMetadata.head.operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME &&
       StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
@@ -270,6 +259,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
       StateStoreReaderInfo = {
     val storeMetadata = StateDataSource.getStateStoreMetadata(sourceOptions, hadoopConf)
     if (!sourceOptions.internalOnlyReadAllColumnFamilies) {
+      // skipping runStateVarChecks for StatePartitionAllColumnFamiliesReader because
+      // we won't specify any stateVars when querying a TWS operator
      runStateVarChecks(sourceOptions, storeMetadata)
     }

@@ -299,15 +290,15 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
         if (sourceOptions.readRegisteredTimers) {
           stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1
         }
-        if (!sourceOptions.internalOnlyReadAllColumnFamilies) {
+        if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+          stateVariableInfos = operatorProperties.stateVariables
+        } else {
           val stateVarInfoList = operatorProperties.stateVariables
             .filter(stateVar => stateVar.stateName == stateVarName)
           require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " +
             s"for state variable $stateVarName in operator ${sourceOptions.operatorId}")
           val stateVarInfo = stateVarInfoList.head
           transformWithStateVariableInfoOpt = Some(stateVarInfo)
-        } else {
-          stateVariableInfos = operatorProperties.stateVariables
         }
         val schemaFilePaths = storeMetadataEntry.stateSchemaFilePaths
         val stateSchemaMetadata = StateSchemaMetadata.createStateSchemaMetadata(
@@ -335,8 +326,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
       val storeId = new StateStoreId(stateCheckpointLocation.toString, sourceOptions.operatorId,
         partitionId, sourceOptions.storeName)
       val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())
-      val manager = new StateSchemaCompatibilityChecker(
-        providerId, hadoopConf, oldSchemaFilePaths)
+      val manager = new StateSchemaCompatibilityChecker(providerId, hadoopConf,
+        oldSchemaFilePaths = oldSchemaFilePaths)
       val stateSchema = manager.readSchemaFile()

       if (sourceOptions.internalOnlyReadAllColumnFamilies) {
@@ -345,16 +336,18 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
       }
       // When reading all column families for Join V3, no specific state variable is targeted,
       // so stateVarName defaults to DEFAULT_COL_FAMILY_NAME.
-      // However, Join V3 does not have a "default" column family. Therefore, we skip populating
-      // keyStateEncoderSpec and stateStoreColFamilySchemaOpt in this case, as there is no
-      // matching schema for the default column family name.
-      if (!isReadAllColFamiliesOnJoinV3(sourceOptions)) {
-        // Based on the version and read schema, populate the keyStateEncoderSpec used for
-        // reading the column families
-        val resultSchema = stateSchema.filter(_.colFamilyName == stateVarName).head
-        keyStateEncoderSpecOpt = Some(getKeyStateEncoderSpec(resultSchema, storeMetadata))
-        stateStoreColFamilySchemaOpt = Some(resultSchema)
+      // However, Join V3 does not have a "default" column family. Therefore, we pick the first
+      // schema as resultSchema which will be used as placeholder schema for default schema
+      // in StatePartitionAllColumnFamiliesReader
+      val resultSchema = if (isReadAllColFamiliesOnJoinV3(sourceOptions, storeMetadata)) {
+        stateSchema.head
+      } else {
+        stateSchema.filter(_.colFamilyName == stateVarName).head
       }
+      // Based on the version and read schema, populate the keyStateEncoderSpec used for
+      // reading the column families
+      keyStateEncoderSpecOpt = Some(getKeyStateEncoderSpec(resultSchema, storeMetadata))
+      stateStoreColFamilySchemaOpt = Some(resultSchema)
     } catch {
       case NonFatal(ex) =>
         throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, ex)
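
For illustration, a minimal, self-contained sketch of the resultSchema selection introduced in the last hunk above. The ColFamilySchema case class and pickResultSchema helper are simplified stand-ins rather than the actual Spark types: for Join V3 with the read-all-column-families option there is no "default" column family, so the first schema entry serves as a placeholder; otherwise the entry matching the requested state variable is used.

object ResultSchemaSketch extends App {
  // Simplified stand-in for StateStoreColFamilySchema (hypothetical fields).
  case class ColFamilySchema(colFamilyName: String, keySchema: String, valueSchema: String)

  def pickResultSchema(
      stateSchema: Seq[ColFamilySchema],
      stateVarName: String,
      isReadAllColFamiliesOnJoinV3: Boolean): ColFamilySchema = {
    if (isReadAllColFamiliesOnJoinV3) {
      // Join V3 has no "default" column family; the first entry acts as a placeholder.
      stateSchema.head
    } else {
      // Otherwise use the schema entry matching the requested state variable.
      stateSchema.filter(_.colFamilyName == stateVarName).head
    }
  }

  // Example: a Join V3 checkpoint exposing only join-side column families.
  val schemas = Seq(
    ColFamilySchema("left-keyToNumValues", "joinKey", "count"),
    ColFamilySchema("right-keyToNumValues", "joinKey", "count"))
  println(pickResultSchema(schemas, "default", isReadAllColFamiliesOnJoinV3 = true))
}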

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 37 additions & 13 deletions

@@ -30,8 +30,8 @@ import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{NextIterator, SerializableConfiguration}

 case class AllColumnFamiliesReaderInfo(
-  colFamilySchemas: List[StateStoreColFamilySchema] = List.empty,
-  stateVariableInfos: List[TransformWithStateVariableInfo] = List.empty)
+    colFamilySchemas: List[StateStoreColFamilySchema] = List.empty,
+    stateVariableInfos: List[TransformWithStateVariableInfo] = List.empty)

 /**
  * An implementation of [[PartitionReaderFactory]] for State data source. This is used to support
@@ -55,9 +55,10 @@ class StatePartitionReaderFactory(
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
     if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) {
+      require(allColumnFamiliesReaderInfo.isDefined)
       new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf,
-        stateStoreInputPartition, schema, keyStateEncoderSpec,
-        allColumnFamiliesReaderInfo.getOrElse(AllColumnFamiliesReaderInfo()))
+        stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt,
+        stateSchemaProviderOpt, allColumnFamiliesReaderInfo.get)
     } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
       new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
@@ -87,23 +88,25 @@ abstract class StatePartitionReaderBase(
   extends PartitionReader[InternalRow] with Logging {
   // Used primarily as a placeholder for the value schema in the context of
   // state variables used within the transformWithState operator.
-  private val dummySchema: StructType =
+  private val schemaForValueRow: StructType =
     StructType(Array(StructField("__dummy__", NullType)))

   protected val keySchema : StructType = {
     if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) {
       SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
     } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
-      stateStoreColFamilySchemaOpt.map(_.keySchema).getOrElse(dummySchema)
+      require(stateStoreColFamilySchemaOpt.isDefined)
+      stateStoreColFamilySchemaOpt.get.keySchema
     } else {
       SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
     }
   }

   protected val valueSchema : StructType = if (stateVariableInfoOpt.isDefined) {
-    dummySchema
+    schemaForValueRow
   } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
-    stateStoreColFamilySchemaOpt.map(_.valueSchema).getOrElse(dummySchema)
+    require(stateStoreColFamilySchemaOpt.isDefined)
+    stateStoreColFamilySchemaOpt.get.valueSchema
   } else {
     SchemaUtil.getSchemaAsDataType(
       schema, "value").asInstanceOf[StructType]
@@ -262,14 +265,15 @@ class StatePartitionAllColumnFamiliesReader(
     partition: StateStoreInputPartition,
     schema: StructType,
     keyStateEncoderSpec: KeyStateEncoderSpec,
+    defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
     allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
   extends StatePartitionReaderBase(
     storeConf,
     hadoopConf, partition, schema,
     keyStateEncoderSpec, None,
-    allColumnFamiliesReaderInfo.colFamilySchemas.find(
-      _.colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME),
-    None, None) {
+    defaultStateStoreColFamilySchemaOpt,
+    stateSchemaProviderOpt, None) {

   private val stateStoreColFamilySchemas = allColumnFamiliesReaderInfo.colFamilySchemas
   private val stateVariableInfos = allColumnFamiliesReaderInfo.stateVariableInfos
@@ -280,7 +284,6 @@ class StatePartitionAllColumnFamiliesReader(
       StateVariableType.ListState)
   }

-  // Override provider to register ALL column families
   override protected lazy val provider: StateStoreProvider = {
     val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
       partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
@@ -289,7 +292,7 @@ class StatePartitionAllColumnFamiliesReader(
     StateStoreProvider.createAndInit(
       stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
       useColumnFamilies, storeConf, hadoopConf.value,
-      useMultipleValuesPerKey = false, stateSchemaProvider = None)
+      useMultipleValuesPerKey = false, stateSchemaProviderOpt)
   }

   // Use a single store instance for both registering column families and iteration.
@@ -303,8 +306,29 @@ class StatePartitionAllColumnFamiliesReader(
     getStartStoreUniqueId
   )

+  def checkAllColFamiliesExist(colFamilyNames: List[String]) = {
+    // Filter out DEFAULT column family from validation for two reasons:
+    // 1. Some operators (e.g., stream-stream join v3) don't include DEFAULT in their schema
+    //    because the underlying RocksDB creates "default" column family automatically
+    // 2. The default column family schema is handled separately via
+    //    defaultStateStoreColFamilySchemaOpt, so no need to verify it here
+    val actualCFs = colFamilyNames.toSet.filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
+    val expectedCFs = stateStore.allColumnFamilyNames
+      .filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
+
+    // Validation: All column families found in the checkpoint must be declared in the schema.
+    // It's acceptable if some schema CFs are not in expectedCFs - this just means those
+    // column families have no data yet in the checkpoint
+    // (they'll be created during registration).
+    // However, if the checkpoint contains CFs not in the schema, it indicates a mismatch.
+    require(expectedCFs.subsetOf(actualCFs),
+      s"Checkpoint contains unexpected column families. " +
+      s"Column families in checkpoint but not in schema: ${expectedCFs.diff(actualCFs)}")
+  }
+
   // Register all column families from the schema
   if (stateStoreColFamilySchemas.length > 1) {
+    checkAllColFamiliesExist(stateStoreColFamilySchemas.map(_.colFamilyName))
     stateStoreColFamilySchemas.foreach { cfSchema =>
       cfSchema.colFamilyName match {
         case StateStore.DEFAULT_COL_FAMILY_NAME => // createAndInit has registered default
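
The checkAllColFamiliesExist helper added above boils down to a set comparison: every column family found in the checkpointed store (other than "default") must be declared in the schema, while declared column families with no data yet are allowed. A minimal sketch of just that check follows, using clearer stand-in names (declared/inCheckpoint) in place of actualCFs/expectedCFs.

object ColFamilyCheckSketch extends App {
  val DefaultCF = "default"

  def checkAllColFamiliesExist(schemaCFs: List[String], checkpointCFs: Set[String]): Unit = {
    // The "default" column family is excluded: RocksDB creates it implicitly and its
    // schema is handled separately.
    val declared = schemaCFs.toSet.filter(_ != DefaultCF)
    val inCheckpoint = checkpointCFs.filter(_ != DefaultCF)
    // Extra declared CFs are fine (no data yet); unknown CFs in the checkpoint are not.
    require(inCheckpoint.subsetOf(declared),
      s"Checkpoint contains unexpected column families: ${inCheckpoint.diff(declared)}")
  }

  checkAllColFamiliesExist(List("default", "countState", "timerState"), Set("countState")) // passes
  // checkAllColFamiliesExist(List("countState"), Set("countState", "mystery"))            // fails
}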

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 3 additions & 0 deletions

@@ -146,6 +146,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName)
     }

+    override def allColumnFamilyNames: collection.Set[String] =
+      collection.Set[String](StateStore.DEFAULT_COL_FAMILY_NAME)
+
     // Multiple col families are not supported with HDFSBackedStateStoreProvider. Throw an exception
     // if the user tries to use a non-default col family.
     private def assertUseOfDefaultColFamily(colFamilyName: String): Unit = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

Lines changed: 8 additions & 0 deletions

@@ -336,6 +336,14 @@ class RocksDB(
     colFamilyNameToInfoMap.asScala.values.toSeq.count(_.isInternal == isInternal)
   }

+  /**
+   * Returns all column family names currently registered in RocksDB.
+   * This includes column families loaded from checkpoint metadata.
+   */
+  def allColumnFamilyNames: collection.Set[String] = {
+    colFamilyNameToInfoMap.asScala.keySet.toSet
+  }
+
   private val rocksDBFileMapping: RocksDBFileMapping = new RocksDBFileMapping()

   // We send snapshots that needs to be uploaded by the maintenance thread to this queue
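
A side note on the trailing .toSet in allColumnFamilyNames above: keySet on the wrapped map is backed by the map itself, so converting to an immutable Set presumably returns a stable snapshot rather than a live view. The sketch below demonstrates the difference on a plain ConcurrentHashMap standing in for colFamilyNameToInfoMap; this is an assumption about intent, not something stated in the diff.

import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object KeySetSnapshotSketch extends App {
  val cfInfo = new ConcurrentHashMap[String, String]()
  cfInfo.put("default", "info")

  val liveView = cfInfo.asScala.keySet        // backed by the underlying map
  val snapshot = cfInfo.asScala.keySet.toSet  // independent copy, as in allColumnFamilyNames

  cfInfo.put("countState", "info")
  println(liveView.size) // 2 - reflects the newly registered column family
  println(snapshot.size) // 1 - unchanged snapshot
}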

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 4 additions & 0 deletions

@@ -610,6 +610,10 @@ private[sql] class RocksDBStateStoreProvider

     override def hasCommitted: Boolean = state == COMMITTED

+    override def allColumnFamilyNames: collection.Set[String] = {
+      rocksDB.allColumnFamilyNames
+    }
+
     override def toString: String = {
       s"RocksDBStateStore[stateStoreId=$stateStoreId_, version=$version]"
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 7 additions & 0 deletions

@@ -310,6 +310,13 @@ trait StateStore extends ReadStateStore {
    * Whether all updates have been committed
    */
   def hasCommitted: Boolean
+
+  /**
+   * Returns all column family names in this state store.
+   *
+   * @return Set of all column family names
+   */
+  def allColumnFamilyNames: collection.Set[String]
 }

 /** Wraps the instance of StateStore to make the instance read-only. */
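
The allColumnFamilyNames member added to the StateStore trait above gives callers a uniform way to ask a store which column families it currently holds; the other files in this commit implement it for the HDFS-backed, RocksDB, in-memory test, and wrapper stores. What follows is a hedged sketch of how a consumer might use the contract, with hypothetical trait and class names standing in for the real ones.

object AllColumnFamilyNamesSketch extends App {
  // Hypothetical, trimmed-down stand-in for the StateStore trait: only the new member.
  trait StoreWithColFamilies {
    def allColumnFamilyNames: collection.Set[String]
  }

  // A store that only ever exposes the default column family, similar to MemoryStateStore
  // or HDFSBackedStateStoreProvider elsewhere in this commit.
  class DefaultOnlyStore extends StoreWithColFamilies {
    override def allColumnFamilyNames: collection.Set[String] = Set("default")
  }

  // A consumer can compare what the store holds against what the schema declares.
  def undeclared(store: StoreWithColFamilies, declared: Set[String]): collection.Set[String] =
    store.allColumnFamilyNames.filter(_ != "default").diff(declared)

  println(undeclared(new DefaultOnlyStore, declared = Set("countState"))) // Set()
}

The checkpoint-vs-schema validation in StatePartitionAllColumnFamiliesReader above is one such consumer of this contract.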

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala

Lines changed: 3 additions & 0 deletions

@@ -42,6 +42,9 @@ class MemoryStateStore extends StateStore() {
     throw StateStoreErrors.multipleColumnFamiliesNotSupported("MemoryStateStoreProvider")
   }

+  override def allColumnFamilyNames: collection.Set[String] =
+    collection.Set[String](StateStore.DEFAULT_COL_FAMILY_NAME)
+
   override def removeColFamilyIfExists(colFamilyName: String): Boolean = {
     throw StateStoreErrors.removingColumnFamiliesNotSupported("MemoryStateStoreProvider")
   }

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala

Lines changed: 2 additions & 0 deletions

@@ -122,6 +122,8 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta
     )
   }

+  override def allColumnFamilyNames: collection.Set[String] = innerStore.allColumnFamilyNames
+
   override def put(
     key: UnsafeRow,
     value: UnsafeRow,
