Skip to content

Commit e0ee499

Browse files
committed
reduce duplicate code
1 parent c2ef0f2 commit e0ee499

File tree

3 files changed

+38
-26
lines changed

3 files changed

+38
-26
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
7979
val (stateStoreReaderInfo, storeMetadata) = getStoreMetadataAndRunChecks(
8080
sourceOptions)
8181

82-
// Extract stateFormatVersion from StateStoreConf for SYMMETRIC_HASH_JOIN operator
83-
val isJoin = (
84-
storeMetadata.head.operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME)
85-
val stateFormatVersion: Int = if (storeMetadata.nonEmpty && isJoin) {
86-
session.conf.get(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION)
87-
} else {
88-
1
89-
}
82+
val stateFormatVersion = getStateFormatVersion(storeMetadata)
9083

9184
// The key state encoder spec should be available for all operators except stream-stream joins
9285
val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
@@ -102,7 +95,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
10295
stateStoreReaderInfo.stateSchemaProviderOpt,
10396
stateStoreReaderInfo.joinColFamilyOpt,
10497
Option(stateStoreReaderInfo.allColumnFamiliesReaderInfo),
105-
Option(stateFormatVersion))
98+
stateFormatVersion)
10699
}
107100

108101
override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
@@ -112,14 +105,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
112105
val (stateStoreReaderInfo, storeMetadata) = getStoreMetadataAndRunChecks(sourceOptions)
113106
val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
114107

115-
// Extract stateFormatVersion from StateStoreConf for SYMMETRIC_HASH_JOIN operator
116-
val stateFormatVersion = if (storeMetadata.nonEmpty &&
117-
(storeMetadata.head.operatorName ==
118-
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME)) {
119-
Some(session.conf.get(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION))
120-
} else {
121-
None
122-
}
108+
val stateFormatVersion = getStateFormatVersion(storeMetadata)
123109

124110
val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
125111
try {
@@ -159,6 +145,20 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
159145

160146
override def supportsExternalMetadata(): Boolean = false
161147

148+
/**
149+
* Returns the streaming join state format version when the first operator in the
150+
* metadata is SYMMETRIC_HASH_JOIN; returns None for other operators or when the metadata is empty.
151+
*/
152+
private def getStateFormatVersion(
153+
storeMetadata: Array[StateMetadataTableEntry]): Option[Int] = {
154+
if (storeMetadata.nonEmpty &&
155+
storeMetadata.head.operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME) {
156+
Some(session.conf.get(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION))
157+
} else {
158+
None
159+
}
160+
}
161+
162162
/**
163163
* Returns true if this is a read-all-column-families request for a stream-stream join
164164
* that uses virtual column families (state format version 3).
@@ -389,13 +389,14 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
389389
}
390390
}
391391

392+
val operatorName = if (storeMetadata.nonEmpty) storeMetadata.head.operatorName else ""
392393
(StateStoreReaderInfo(
393394
keyStateEncoderSpecOpt,
394395
stateStoreColFamilySchemaOpt,
395396
transformWithStateVariableInfoOpt,
396397
stateSchemaProvider,
397398
joinColFamilyOpt,
398-
AllColumnFamiliesReaderInfo(stateStoreColFamilySchemas, stateVariableInfos)
399+
AllColumnFamiliesReaderInfo(stateStoreColFamilySchemas, stateVariableInfos, operatorName)
399400
), storeMetadata)
400401
}
401402

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ import org.apache.spark.util.{NextIterator, SerializableConfiguration}
3131

3232
case class AllColumnFamiliesReaderInfo(
3333
colFamilySchemas: List[StateStoreColFamilySchema] = List.empty,
34-
stateVariableInfos: List[TransformWithStateVariableInfo] = List.empty)
34+
stateVariableInfos: List[TransformWithStateVariableInfo] = List.empty,
35+
operatorName: String = "")
3536

3637
/**
3738
* An implementation of [[PartitionReaderFactory]] for State data source. This is used to support
@@ -280,6 +281,15 @@ class StatePartitionAllColumnFamiliesReader(
280281

281282
private val stateStoreColFamilySchemas = allColumnFamiliesReaderInfo.colFamilySchemas
282283
private val stateVariableInfos = allColumnFamiliesReaderInfo.stateVariableInfos
284+
private val operatorName = allColumnFamiliesReaderInfo.operatorName
285+
286+
// Lazily build the extractor that derives the partition key from each state row's key.
287+
private lazy val partitionKeyExtractor = SchemaUtil.getExtractor(
288+
operatorName,
289+
keySchema,
290+
partition.sourceOptions.storeName,
291+
stateVariableInfos.headOption,
292+
stateFormatVersion)
283293

284294
private def isListType(colFamilyName: String): Boolean = {
285295
SchemaUtil.checkVariableType(
@@ -364,13 +374,14 @@ class StatePartitionAllColumnFamiliesReader(
364374
pair =>
365375
store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
366376
value =>
367-
SchemaUtil.unifyStateRowPairAsRawBytes((pair.key, value), cfSchema.colFamilyName)
377+
SchemaUtil.unifyStateRowPairAsRawBytes(
378+
(pair.key, value), cfSchema.colFamilyName, partitionKeyExtractor)
368379
}
369380
)
370381
} else {
371382
store.iterator(cfSchema.colFamilyName).map { pair =>
372383
SchemaUtil.unifyStateRowPairAsRawBytes(
373-
(pair.key, pair.value), cfSchema.colFamilyName)
384+
(pair.key, pair.value), cfSchema.colFamilyName, partitionKeyExtractor)
374385
}
375386
}
376387
}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,18 +123,18 @@ object SchemaUtil {
123123

124124
/**
125125
* Returns an InternalRow representing
126-
* 1. partitionKey
126+
* 1. partitionKey (extracted using the StatePartitionKeyExtractor)
127127
* 2. key in bytes
128128
* 3. value in bytes
129129
* 4. column family name
130130
*/
131131
def unifyStateRowPairAsRawBytes(
132132
pair: (UnsafeRow, UnsafeRow),
133-
colFamilyName: String): InternalRow = {
133+
colFamilyName: String,
134+
extractor: StatePartitionKeyExtractor): InternalRow = {
134135
val row = new GenericInternalRow(4)
135-
// todo [SPARK-54443]: change keySchema to more specific type after we
136-
// can extract partition key from keySchema
137-
row.update(0, pair._1)
136+
val partitionKey = extractor.partitionKey(pair._1)
137+
row.update(0, partitionKey)
138138
row.update(1, pair._1.getBytes)
139139
row.update(2, pair._2.getBytes)
140140
row.update(3, UTF8String.fromString(colFamilyName))

0 commit comments

Comments
 (0)