
Commit c474783

add test for internal column
1 parent f1807d7 commit c474783

File tree

4 files changed (+58, -22 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 11 additions & 3 deletions
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqMetadata
 import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
-import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
 import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
 import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata

@@ -193,7 +193,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging

     val stateVars = twsOperatorProperties.stateVariables
     val stateVarInfo = stateVars.filter(stateVar => stateVar.stateName == stateVarName)
-    if (stateVarInfo.size != 1) {
+    if (stateVarInfo.size != 1 &&
+      !StateStoreColumnFamilySchemaUtils.isInternalColFamilyTestOnly(stateVarName)) {
       throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
         s"State variable $stateVarName is not defined for the transformWithState operator.")
     }

@@ -293,8 +294,15 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
       if (sourceOptions.internalOnlyReadAllColumnFamilies) {
         stateVariableInfos = operatorProperties.stateVariables
       } else {
-        val stateVarInfoList = operatorProperties.stateVariables
+        var stateVarInfoList = operatorProperties.stateVariables
           .filter(stateVar => stateVar.stateName == stateVarName)
+        if (stateVarInfoList.isEmpty &&
+            StateStoreColumnFamilySchemaUtils.isInternalColFamilyTestOnly(stateVarName)) {
+          // Pass a dummy TransformWithStateVariableInfo for this TWS internal column family during testing.
+          stateVarInfoList = List(TransformWithStateVariableInfo(
+            stateVarName, StateVariableType.ValueState, false
+          ))
+        }
         require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " +
           s"for state variable $stateVarName in operator ${sourceOptions.operatorId}")
         val stateVarInfo = stateVarInfoList.head
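
With the validation relaxed this way, a test can point the state data source directly at an internal column family by passing its "$"-prefixed name as the state variable. A minimal sketch of such a read, assuming a SparkSession named spark and a checkpoint directory tempDir like the one used in the new test further below ("$ttl_listState" is the TTL index that a list state with TTL maintains):

  // Only works in a test JVM: isInternalColFamilyTestOnly gates the fallback
  // that accepts "$"-prefixed names for STATE_VAR_NAME.
  val ttlIndexDf = spark.read
    .format("statestore")
    .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
    .option(StateSourceOptions.STATE_VAR_NAME, "$ttl_listState")
    .load()
    .selectExpr("partition_id", "key", "value")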

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 4 additions & 3 deletions
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
-import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
 import org.apache.spark.sql.types.{NullType, StructField, StructType}

@@ -143,14 +143,15 @@ abstract class StatePartitionReaderBase(
       useColumnFamilies = useColFamilies, storeConf, hadoopConf.value,
       useMultipleValuesPerKey = useMultipleValuesPerKey, stateSchemaProviderOpt)

-    val isInternal = partition.sourceOptions.readRegisteredTimers
-
     if (useColFamilies) {
       val store = provider.getStore(
         partition.sourceOptions.batchId + 1,
         getEndStoreUniqueId)
       require(stateStoreColFamilySchemaOpt.isDefined)
       val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get
+      val isInternal = partition.sourceOptions.readRegisteredTimers ||
+        StateStoreColumnFamilySchemaUtils.isInternalColFamilyTestOnly(
+          stateStoreColFamilySchema.colFamilyName)
       require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined)
       store.createColFamilyIfAbsent(
         stateStoreColFamilySchema.colFamilyName,
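
Note that isInternal can no longer be computed before the schema lookup: it now depends on the resolved column family name, not just the read options, because "$"-prefixed names are reserved for internal column families and must be registered as internal when the reader recreates them. A self-contained sketch of the combined check (not the Spark source; isTesting stands in for Utils.isTesting):

  // Per-column-family decision the reader now makes before createColFamilyIfAbsent.
  def markAsInternal(
      readRegisteredTimers: Boolean,
      colFamilyName: String,
      isTesting: Boolean): Boolean = {
    // Timer reads always target internal families; otherwise a "$"-prefixed
    // name is treated as internal only while running tests.
    readRegisteredTimers || (isTesting && colFamilyName.startsWith("$"))
  }

  markAsInternal(false, "$ttl_listState", isTesting = true)  // true
  markAsInternal(false, "listState", isTesting = true)       // false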

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala

Lines changed: 11 additions & 0 deletions
@@ -99,6 +99,17 @@ object StateStoreColumnFamilySchemaUtils {
   def getStateNameFromCountIndexCFName(colFamilyName: String): String =
     getStateName(COUNT_INDEX_PREFIX, colFamilyName)

+  /**
+   * Returns true if the column family is internal (starts with "$") and we are in testing mode.
+   * This is used to allow internal column families to be read during tests.
+   *
+   * @param colFamilyName The name of the column family to check
+   * @return true if this is an internal column family and Utils.isTesting is true
+   */
+  def isInternalColFamilyTestOnly(colFamilyName: String): Boolean = {
+    org.apache.spark.util.Utils.isTesting && colFamilyName.startsWith("$")
+  }
+
   def getValueStateSchema[T](
       stateName: String,
       keyEncoder: ExpressionEncoder[Any],
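
The helper deliberately keys off only two things: the reserved "$" prefix that internal column families share (for example "$ttl_listState", "$min_listState", and "$count_listState" in the test below) and Utils.isTesting, so it cannot widen what is readable in a production job. Expected behaviour, as a sketch:

  // Inside a test JVM (Utils.isTesting == true):
  StateStoreColumnFamilySchemaUtils.isInternalColFamilyTestOnly("$ttl_listState")  // true
  StateStoreColumnFamilySchemaUtils.isInternalColFamilyTestOnly("listState")       // false
  // Outside tests, Utils.isTesting is false and the helper returns false for every name.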

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala

Lines changed: 32 additions & 16 deletions
@@ -826,22 +826,38 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
         "listState",
         groupByKeySchema,
         listStateValueSchema)
-
-      // Validate that TTL-related column families have the expected number of entries
-      val ttlIndexRows = allBytesData.filter(_.getString(3) == "$ttl_listState")
-      val minExpiryRows = allBytesData.filter(_.getString(3) == "$min_listState")
-      val countIndexRows = allBytesData.filter(_.getString(3) == "$count_listState")
-
-      // We have 2 grouping keys (a, b), so each secondary index should have entries
-      // TTL index has one entry per unique (expirationMs, groupingKey) pair
-      // Min expiry and count indexes have one entry per grouping key
-      assert(minExpiryRows.length == 2,
-        s"Expected 2 min expiry entries (one per key), got ${minExpiryRows.length}")
-      assert(countIndexRows.length == 2,
-        s"Expected 2 count index entries (one per key), got ${countIndexRows.length}")
-      // TTL index entries depend on batching - we processed 2 batches with different timestamps
-      assert(ttlIndexRows.length >= 2,
-        s"Expected at least 2 TTL index entries, got ${ttlIndexRows.length}")
+      val dummyValueSchema = StructType(Array(StructField("__dummy__", NullType)))
+      val ttlIndexKeySchema = StructType(Array(
+        StructField("expirationMs", LongType, nullable = false),
+        StructField("elementKey", groupByKeySchema)
+      ))
+      val minExpiryValueSchema = StructType(Array(
+        StructField("minExpiry", LongType)
+      ))
+      val countValueSchema = StructType(Array(
+        StructField("count", LongType)
+      ))
+      val columnFamilyAndKeyValueSchema = Seq(
+        ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+        ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+        ("$count_listState", groupByKeySchema, countValueSchema)
+      )
+      columnFamilyAndKeyValueSchema.foreach(pair => {
+        val normalDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+          .load()
+          .selectExpr("partition_id", "key", "value")
+
+        compareNormalAndBytesData(
+          normalDf.collect(),
+          allBytesData,
+          pair._1,
+          pair._2,
+          pair._3)
+      }
+      )
     }
   }
