Commit 996e27b

all tests pass

1 parent e0ee499

6 files changed: +187 -80 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 1 addition & 1 deletion

@@ -135,7 +135,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
         valueSchema,
         stateVarInfo,
         stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
-        storeMetadata,
+        storeMetadata.headOption.map(_.operatorName).getOrElse(""),
         stateFormatVersion)
     } catch {
       case NonFatal(e) =>
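The call site now passes just the operator name instead of the whole metadata array. A minimal sketch of how the `headOption` expression behaves when the metadata array is empty (the case the old `require(storeMetadata.nonEmpty)` in SchemaUtil used to reject); `Entry` is a stripped-down stand-in for `StateMetadataTableEntry`, keeping only the field the diff reads:

    // Stand-in for StateMetadataTableEntry with just the field used above.
    case class Entry(operatorName: String)

    val empty: Array[Entry] = Array.empty
    // headOption turns the empty-array case into "" instead of an exception.
    assert(empty.headOption.map(_.operatorName).getOrElse("") == "")

    val nonEmpty = Array(Entry("transformWithStateExec")) // example name only
    assert(nonEmpty.headOption.map(_.operatorName).getOrElse("") == "transformWithStateExec")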

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 61 additions & 23 deletions

@@ -21,8 +21,10 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
-import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateVariableInfo, TransformWithStateVariableUtils}
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
 import org.apache.spark.sql.types.{NullType, StructField, StructType}

@@ -283,13 +285,46 @@ class StatePartitionAllColumnFamiliesReader(
   private val stateVariableInfos = allColumnFamiliesReaderInfo.stateVariableInfos
   private val operatorName = allColumnFamiliesReaderInfo.operatorName
 
-  // Create the extractor for partition key extraction
-  private lazy val partitionKeyExtractor = SchemaUtil.getExtractor(
-    operatorName,
-    keySchema,
-    partition.sourceOptions.storeName,
-    stateVariableInfos.headOption,
-    stateFormatVersion)
+  def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: String): Boolean = {
+    StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName) &&
+      colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME
+  }
+
+  // Create extractors for each column family - each column family may have different key schema
+  private lazy val partitionKeyExtractors: Map[String, StatePartitionKeyExtractor] = {
+    stateStoreColFamilySchemas
+      .filter(schema => !isDefaultColFamilyInTWS(operatorName, schema.colFamilyName))
+      .map { cfSchema =>
+        val colFamilyName = cfSchema.colFamilyName
+        val colFamilyNameToCheck = if (
+          StateStoreColumnFamilySchemaUtils.isTtlColFamilyName(colFamilyName)) {
+          StateStoreColumnFamilySchemaUtils.getStateNameFromTtlColFamily(colFamilyName)
+        } else if (StateStoreColumnFamilySchemaUtils.isMinExpiryIndexCFName(colFamilyName)) {
+          StateStoreColumnFamilySchemaUtils.getStateNameFromMinExpiryIndexCFName(colFamilyName)
+        } else if (StateStoreColumnFamilySchemaUtils.isCountIndexCFName(colFamilyName)) {
+          StateStoreColumnFamilySchemaUtils.getStateNameFromCountIndexCFName(colFamilyName)
+        } else if (TransformWithStateVariableUtils.isRowCounterCFName(colFamilyName)) {
+          TransformWithStateVariableUtils.getStateNameFromRowCounterCFName(colFamilyName)
+        } else {
+          colFamilyName
+        }
+        var stateVarInfo =
+          stateVariableInfos.find(_.stateName == colFamilyNameToCheck)
+        if (stateVarInfo.isEmpty) {
+          if (TimerStateUtils.isTimerSecondaryIndexCF(colFamilyName)) {
+            stateVarInfo = Some(TransformWithStateVariableUtils.getTimerState(colFamilyName))
+          }
+        }
+        val extractor = SchemaUtil.getExtractor(
+          operatorName,
+          cfSchema.keySchema,
+          partition.sourceOptions.storeName,
+          colFamilyName,
+          stateVarInfo,
+          stateFormatVersion)
+        colFamilyName -> extractor
+      }.toMap
+  }
 
   private def isListType(colFamilyName: String): Boolean = {
     SchemaUtil.checkVariableType(

@@ -368,22 +403,25 @@ class StatePartitionAllColumnFamiliesReader(
 
   override lazy val iter: Iterator[InternalRow] = {
     // Iterate all column families and concatenate results
-    stateStoreColFamilySchemas.iterator.flatMap { cfSchema =>
-      if (isListType(cfSchema.colFamilyName)) {
-        store.iterator(cfSchema.colFamilyName).flatMap(
-          pair =>
-            store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
-              value =>
-                SchemaUtil.unifyStateRowPairAsRawBytes(
-                  (pair.key, value), cfSchema.colFamilyName, partitionKeyExtractor)
-            }
-        )
-      } else {
-        store.iterator(cfSchema.colFamilyName).map { pair =>
-          SchemaUtil.unifyStateRowPairAsRawBytes(
-            (pair.key, pair.value), cfSchema.colFamilyName, partitionKeyExtractor)
+    stateStoreColFamilySchemas.iterator
+      .filter(schema => !isDefaultColFamilyInTWS(operatorName, schema.colFamilyName))
+      .flatMap { cfSchema =>
+        val extractor = partitionKeyExtractors(cfSchema.colFamilyName)
+        if (isListType(cfSchema.colFamilyName)) {
+          store.iterator(cfSchema.colFamilyName).flatMap(
+            pair =>
+              store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
                value =>
+                  SchemaUtil.unifyStateRowPairAsRawBytes(
+                    (pair.key, value), cfSchema.colFamilyName, extractor)
+              }
+          )
+        } else {
+          store.iterator(cfSchema.colFamilyName).map { pair =>
+            SchemaUtil.unifyStateRowPairAsRawBytes(
+              (pair.key, pair.value), cfSchema.colFamilyName, extractor)
+          }
         }
-      }
     }
   }
 
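One design point worth noting: the same `isDefaultColFamilyInTWS` filter guards both the construction of `partitionKeyExtractors` and the iteration in `iter`, so the map lookup can never miss. A toy model of that invariant, with stand-in types in place of Spark's:

    // Stand-ins for the column family schemas and extractors in the diff.
    case class Extractor(colFamily: String)
    val colFamilies = Seq("valueState", "$ttl_valueState", "default")
    def skip(cf: String): Boolean = cf == "default" // stand-in for isDefaultColFamilyInTWS

    // Built once, lazily in the real code: one extractor per surviving family.
    val extractors: Map[String, Extractor] =
      colFamilies.filterNot(skip).map(cf => cf -> Extractor(cf)).toMap

    // Iteration applies the identical filter, so extractors(cf) is total here.
    val rows = colFamilies.iterator.filterNot(skip).map(cf => extractors(cf))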

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala

Lines changed: 21 additions & 12 deletions

@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions}
-import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorsUtils, StatePartitionKeyExtractorFactory}
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.LeftSide
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager

@@ -54,13 +53,29 @@ object SchemaUtil {
       valueSchema: StructType,
       transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo],
       stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
-      storeMetadata: Array[StateMetadataTableEntry],
+      operatorName: String,
       stateFormatVersion: Option[Int] = None): StructType = {
     if (sourceOptions.internalOnlyReadAllColumnFamilies) {
-      // Extract partition key schema using StatePartitionKeyExtractor
-      require(storeMetadata.nonEmpty)
+      val colFamilyName: String =
+        if (
+          operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
+        ) {
+          SymmetricHashJoinStateManager.allStateStoreNames(LeftSide).head
+        } else if (
+          StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES
+            .contains(operatorName)
+        ) {
+          require(
+            transformWithStateVariableInfoOpt.isDefined,
+            "transformWithStateVariableInfo is required for TransformWithState"
+          )
+          transformWithStateVariableInfoOpt.get.stateName
+        } else {
+          StateStore.DEFAULT_COL_FAMILY_NAME
+        }
       val extractor = getExtractor(
-        storeMetadata.head.operatorName, keySchema, sourceOptions.storeName,
+        operatorName, keySchema, sourceOptions.storeName,
+        colFamilyName,
         transformWithStateVariableInfoOpt, stateFormatVersion)
       new StructType()
         .add("partition_key", extractor.partitionKeySchema)

@@ -94,15 +109,9 @@ object SchemaUtil {
       operatorName: String,
       keySchema: StructType,
       storeName: String,
+      colFamilyName: String,
       transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo],
       stateFormatVersion: Option[Int]): StatePartitionKeyExtractor = {
-    val colFamilyName: String =
-      if (operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME) {
-        SymmetricHashJoinStateManager.allStateStoreNames(LeftSide).head
-      } else {
-        transformWithStateVariableInfoOpt.map(_.stateName)
-          .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
-      }
     StatePartitionKeyExtractorFactory.create(
       operatorName,
       keySchema,
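With the resolution hoisted out of `getExtractor`, callers decide the column family name themselves. A hedged sketch of that decision, assuming a plain string-based model; the literal operator and store names below are placeholders for the constants referenced in the diff (`SYMMETRIC_HASH_JOIN_EXEC_OP_NAME`, `TRANSFORM_WITH_STATE_OP_NAMES`, `StateStore.DEFAULT_COL_FAMILY_NAME`):

    // Placeholder values; the real ones come from StatefulOperatorsUtils,
    // SymmetricHashJoinStateManager and StateStore.
    val SymmetricHashJoinOpName = "symmetricHashJoin"
    val TransformWithStateOpNames = Set("transformWithStateExec")
    val DefaultColFamilyName = "default"

    def resolveColFamilyName(
        operatorName: String,
        stateNameOpt: Option[String]): String = {
      if (operatorName == SymmetricHashJoinOpName) {
        "left-keyToNumValues" // placeholder for the left-side join store name
      } else if (TransformWithStateOpNames.contains(operatorName)) {
        // Mirrors the diff: TransformWithState requires explicit variable info.
        require(stateNameOpt.isDefined,
          "transformWithStateVariableInfo is required for TransformWithState")
        stateNameOpt.get
      } else {
        DefaultColFamilyName
      }
    }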

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateVariableUtils.scala

Lines changed: 4 additions & 0 deletions

@@ -63,6 +63,10 @@ object TransformWithStateVariableUtils {
   def isRowCounterCFName(colFamilyName: String): Boolean = {
     colFamilyName.startsWith(ROW_COUNTER_CF_PREFIX)
   }
+
+  def getStateNameFromRowCounterCFName(colFamilyName: String): String = {
+    colFamilyName.substring(ROW_COUNTER_CF_PREFIX.length)
+  }
 }

 // Enum of possible State Variable types
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala

Lines changed: 6 additions & 1 deletion

@@ -61,9 +61,14 @@ object TimerStateUtils {
   }
 
   def isTimerSecondaryIndexCF(colFamilyName: String): Boolean = {
-    assert(isTimerCFName(colFamilyName), s"Column family name must be for a timer: $colFamilyName")
     colFamilyName.endsWith(TIMESTAMP_TO_KEY_CF)
   }
+
+  def getPrimaryIndexFromSecondaryIndexCF(colFamilyName: String): String = {
+    assert(isTimerSecondaryIndexCF(colFamilyName),
+      s"Column family name must be for a timer secondary index: $colFamilyName")
+    colFamilyName.replace(TIMESTAMP_TO_KEY_CF, KEY_TO_TIMESTAMP_CF)
+  }
 }
 
 /**
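Dropping the `assert(isTimerCFName(...))` turns `isTimerSecondaryIndexCF` into a plain predicate that StatePartitionReader can safely probe against any column family name. A small sketch of both behaviors, with assumed suffix values since the `TIMESTAMP_TO_KEY_CF` / `KEY_TO_TIMESTAMP_CF` constants are not shown here:

    // Assumed suffix values, for illustration only.
    val TIMESTAMP_TO_KEY_CF = "-timestampToKey"
    val KEY_TO_TIMESTAMP_CF = "-keyToTimestamp"

    def isTimerSecondaryIndexCF(cf: String): Boolean =
      cf.endsWith(TIMESTAMP_TO_KEY_CF)

    // Safe to call on a non-timer name now; it simply returns false.
    assert(!isTimerSecondaryIndexCF("valueState"))
    // And the new helper maps a secondary-index name to its primary index:
    assert("procTimers-timestampToKey".replace(TIMESTAMP_TO_KEY_CF, KEY_TO_TIMESTAMP_CF)
      == "procTimers-keyToTimestamp")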
