@@ -21,8 +21,10 @@ import org.apache.spark.sql.catalyst.InternalRow
2121import org .apache .spark .sql .catalyst .expressions .{GenericInternalRow , UnsafeRow }
2222import org .apache .spark .sql .connector .read .{InputPartition , PartitionReader , PartitionReaderFactory }
2323import org .apache .spark .sql .execution .datasources .v2 .state .utils .SchemaUtil
24+ import org .apache .spark .sql .execution .streaming .operators .stateful .StatefulOperatorsUtils
2425import org .apache .spark .sql .execution .streaming .operators .stateful .join .SymmetricHashJoinStateManager
25- import org .apache .spark .sql .execution .streaming .operators .stateful .transformwithstate .{StateStoreColumnFamilySchemaUtils , StateVariableType , TransformWithStateVariableInfo }
26+ import org .apache .spark .sql .execution .streaming .operators .stateful .transformwithstate .{StateStoreColumnFamilySchemaUtils , StateVariableType , TransformWithStateVariableInfo , TransformWithStateVariableUtils }
27+ import org .apache .spark .sql .execution .streaming .operators .stateful .transformwithstate .timers .TimerStateUtils
2628import org .apache .spark .sql .execution .streaming .state ._
2729import org .apache .spark .sql .execution .streaming .state .RecordType .{getRecordTypeAsString , RecordType }
2830import org .apache .spark .sql .types .{NullType , StructField , StructType }
@@ -283,13 +285,46 @@ class StatePartitionAllColumnFamiliesReader(
283285 private val stateVariableInfos = allColumnFamiliesReaderInfo.stateVariableInfos
284286 private val operatorName = allColumnFamiliesReaderInfo.operatorName
285287
286- // Create the extractor for partition key extraction
287- private lazy val partitionKeyExtractor = SchemaUtil .getExtractor(
288- operatorName,
289- keySchema,
290- partition.sourceOptions.storeName,
291- stateVariableInfos.headOption,
292- stateFormatVersion)
/**
 * Tells whether a column family is the default column family of a
 * transformWithState operator.
 *
 * @param operatorName name of the stateful operator being inspected
 * @param colFamilyName column family name to test
 * @return true only when `operatorName` is contained in
 *         `StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES` and `colFamilyName`
 *         equals `StateStore.DEFAULT_COL_FAMILY_NAME`
 */
def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: String): Boolean = {
  val isTransformWithStateOp =
    StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName)
  isTransformWithStateOp && colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME
}
292+
// Partition key extractors keyed by column family name. Each column family may have a
// different key schema, so one extractor is built per column family from its own key schema.
// Column families for which isDefaultColFamilyInTWS returns true are left out of the map.
private lazy val partitionKeyExtractors: Map[String, StatePartitionKeyExtractor] = {
  stateStoreColFamilySchemas
    .filter(schema => !isDefaultColFamilyInTWS(operatorName, schema.colFamilyName))
    .map { cfSchema =>
      val colFamilyName = cfSchema.colFamilyName

      // When the column family name is recognised as a TTL, min-expiry-index, count-index or
      // row-counter column family, look up the state variable under the name returned by the
      // matching getStateNameFrom* helper; otherwise use the column family name itself.
      val colFamilyNameToCheck =
        if (StateStoreColumnFamilySchemaUtils.isTtlColFamilyName(colFamilyName)) {
          StateStoreColumnFamilySchemaUtils.getStateNameFromTtlColFamily(colFamilyName)
        } else if (StateStoreColumnFamilySchemaUtils.isMinExpiryIndexCFName(colFamilyName)) {
          StateStoreColumnFamilySchemaUtils.getStateNameFromMinExpiryIndexCFName(colFamilyName)
        } else if (StateStoreColumnFamilySchemaUtils.isCountIndexCFName(colFamilyName)) {
          StateStoreColumnFamilySchemaUtils.getStateNameFromCountIndexCFName(colFamilyName)
        } else if (TransformWithStateVariableUtils.isRowCounterCFName(colFamilyName)) {
          TransformWithStateVariableUtils.getStateNameFromRowCounterCFName(colFamilyName)
        } else {
          colFamilyName
        }

      // Prefer the declared state variable with that name. Only when none exists, fall back
      // to the timer state for timer secondary index column families. `orElse` takes its
      // argument by name, so the timer check runs only when the lookup found nothing.
      val stateVarInfo = stateVariableInfos
        .find(_.stateName == colFamilyNameToCheck)
        .orElse {
          if (TimerStateUtils.isTimerSecondaryIndexCF(colFamilyName)) {
            Some(TransformWithStateVariableUtils.getTimerState(colFamilyName))
          } else {
            None
          }
        }

      colFamilyName -> SchemaUtil.getExtractor(
        operatorName,
        cfSchema.keySchema,
        partition.sourceOptions.storeName,
        colFamilyName,
        stateVarInfo,
        stateFormatVersion)
    }.toMap
}
293328
294329 private def isListType (colFamilyName : String ): Boolean = {
295330 SchemaUtil .checkVariableType(
@@ -368,22 +403,25 @@ class StatePartitionAllColumnFamiliesReader(
368403
// Lazily walks every column family schema for which isDefaultColFamilyInTWS is false and
// chains the resulting rows into one iterator. Every (key, value) pair is passed through
// SchemaUtil.unifyStateRowPairAsRawBytes together with the column family name and that
// column family's own partition key extractor.
override lazy val iter: Iterator[InternalRow] = {
  val readableSchemas = stateStoreColFamilySchemas.iterator
    .filterNot(cf => isDefaultColFamilyInTWS(operatorName, cf.colFamilyName))

  readableSchemas.flatMap { cf =>
    val cfName = cf.colFamilyName
    val keyExtractor = partitionKeyExtractors(cfName)
    if (isListType(cfName)) {
      // List-typed column family: emit one row per value returned by valuesIterator for
      // each key yielded by the store iterator.
      store.iterator(cfName).flatMap { keyed =>
        store.valuesIterator(keyed.key, cfName).map { listValue =>
          SchemaUtil.unifyStateRowPairAsRawBytes((keyed.key, listValue), cfName, keyExtractor)
        }
      }
    } else {
      // Any other column family: emit one row per pair yielded by the store iterator.
      store.iterator(cfName).map { keyed =>
        SchemaUtil.unifyStateRowPairAsRawBytes((keyed.key, keyed.value), cfName, keyExtractor)
      }
    }
  }
}
389427
0 commit comments