Commit dcf296f

initial commit

1 parent 482ed57 commit dcf296f

File tree

5 files changed: +99 -28 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 43 additions & 14 deletions
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpoint
 import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors
 import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.streaming.TimeMode
 import org.apache.spark.sql.types.StructType
@@ -75,9 +76,18 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
        sourceOptions.resolvedCpLocation,
        stateConf.providerClass)
    }
-    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
+    val (stateStoreReaderInfo, storeMetadata) = getStoreMetadataAndRunChecks(
      sourceOptions)
 
+    // Extract stateFormatVersion from StateStoreConf for SYMMETRIC_HASH_JOIN operator
+    val isJoin = (
+      storeMetadata.head.operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME)
+    val stateFormatVersion: Int = if (storeMetadata.nonEmpty && isJoin) {
+      session.conf.get(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION)
+    } else {
+      1
+    }
+
    // The key state encoder spec should be available for all operators except stream-stream joins
    val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
      stateStoreReaderInfo.keyStateEncoderSpecOpt.get
@@ -91,19 +101,28 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
      stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
      stateStoreReaderInfo.stateSchemaProviderOpt,
      stateStoreReaderInfo.joinColFamilyOpt,
-      Option(stateStoreReaderInfo.allColumnFamiliesReaderInfo))
+      Option(stateStoreReaderInfo.allColumnFamiliesReaderInfo),
+      Option(stateFormatVersion))
  }
 
  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
    val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
      StateSourceOptions.apply(session, hadoopConf, options))
 
-    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
-      sourceOptions)
+    val (stateStoreReaderInfo, storeMetadata) = getStoreMetadataAndRunChecks(sourceOptions)
    val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
 
+    // Extract stateFormatVersion from StateStoreConf for SYMMETRIC_HASH_JOIN operator
+    val stateFormatVersion = if (storeMetadata.nonEmpty &&
+      (storeMetadata.head.operatorName ==
+        StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME)) {
+      Some(session.conf.get(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION))
+    } else {
+      None
+    }
+
    val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
-    try {
+    // try {
    val (keySchema, valueSchema) = sourceOptions.joinSide match {
      case JoinSideValues.left =>
        StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString,
@@ -120,14 +139,24 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
      (resultSchema.keySchema, resultSchema.valueSchema)
    }
 
+    println("transformWithStateVariableInfoOpt",
+      stateStoreReaderInfo.transformWithStateVariableInfoOpt)
+    val stateVarInfo: Option[TransformWithStateVariableInfo] = if (
+      sourceOptions.internalOnlyReadAllColumnFamilies) {
+      Option(stateStoreReaderInfo.allColumnFamiliesReaderInfo.stateVariableInfos.head)
+    } else {
+      stateStoreReaderInfo.transformWithStateVariableInfoOpt
+    }
    SchemaUtil.getSourceSchema(sourceOptions, keySchema,
      valueSchema,
-      stateStoreReaderInfo.transformWithStateVariableInfoOpt,
-      stateStoreReaderInfo.stateStoreColFamilySchemaOpt)
-    } catch {
-      case NonFatal(e) =>
-        throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
-    }
+      stateVarInfo,
+      stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
+      storeMetadata,
+      stateFormatVersion)
+    // } catch {
+    //   case NonFatal(e) =>
+    //     throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
+    // }
  }
 
  override def supportsExternalMetadata(): Boolean = false
@@ -256,7 +285,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
  }
 
  private def getStoreMetadataAndRunChecks(sourceOptions: StateSourceOptions):
-    StateStoreReaderInfo = {
+    (StateStoreReaderInfo, Array[StateMetadataTableEntry]) = {
    val storeMetadata = StateDataSource.getStateStoreMetadata(sourceOptions, hadoopConf)
    if (!sourceOptions.internalOnlyReadAllColumnFamilies) {
      // skipping runStateVarChecks for StatePartitionAllColumnFamiliesReader because
@@ -354,14 +383,14 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
      }
    }
 
-    StateStoreReaderInfo(
+    (StateStoreReaderInfo(
      keyStateEncoderSpecOpt,
      stateStoreColFamilySchemaOpt,
      transformWithStateVariableInfoOpt,
      stateSchemaProvider,
      joinColFamilyOpt,
      AllColumnFamiliesReaderInfo(stateStoreColFamilySchemas, stateVariableInfos)
-    )
+    ), storeMetadata)
  }
 
  private def getKeyStateEncoderSpec(
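
The getTable and inferSchema changes above resolve a state format version only when the checkpoint's operator metadata names a symmetric hash join, and fall back to a default otherwise. A minimal standalone sketch of that selection logic follows; OperatorMeta, the operator-name string and the conf value are hypothetical stand-ins for StateMetadataTableEntry, StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME and SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION, so only the branching is taken from the diff.

```scala
// Standalone sketch of the stateFormatVersion selection performed in StateDataSource above.
// OperatorMeta and the constants are simplified stand-ins for Spark's StateMetadataTableEntry,
// StatefulOperatorsUtils and the STREAMING_JOIN_STATE_FORMAT_VERSION session conf.
object StateFormatVersionSketch {
  final case class OperatorMeta(operatorName: String)

  val SymmetricHashJoinOpName = "symmetricHashJoin" // placeholder operator name
  val JoinStateFormatVersionConf = 2                // assumed value of the session conf

  // Resolve a format version only for stream-stream joins; other operators get None.
  // Checking nonEmpty before head avoids failing on checkpoints without operator metadata.
  def stateFormatVersion(storeMetadata: Seq[OperatorMeta]): Option[Int] = {
    if (storeMetadata.nonEmpty &&
        storeMetadata.head.operatorName == SymmetricHashJoinOpName) {
      Some(JoinStateFormatVersionConf)
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    println(stateFormatVersion(Seq(OperatorMeta("symmetricHashJoin")))) // Some(2)
    println(stateFormatVersion(Seq(OperatorMeta("dedupe"))))            // None
    println(stateFormatVersion(Nil))                                    // None
  }
}
```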

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 5 additions & 3 deletions
@@ -49,7 +49,8 @@ class StatePartitionReaderFactory(
    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
    stateSchemaProviderOpt: Option[StateSchemaProvider],
    joinColFamilyOpt: Option[String],
-    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo])
+    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
+    stateFormatVersion: Option[Int])
  extends PartitionReaderFactory {
 
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
@@ -58,7 +59,7 @@ class StatePartitionReaderFactory(
      require(allColumnFamiliesReaderInfo.isDefined)
      new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf,
        stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt,
-        stateSchemaProviderOpt, allColumnFamiliesReaderInfo.get)
+        stateSchemaProviderOpt, allColumnFamiliesReaderInfo.get, stateFormatVersion)
    } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
      new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
        stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
@@ -267,7 +268,8 @@ class StatePartitionAllColumnFamiliesReader(
    keyStateEncoderSpec: KeyStateEncoderSpec,
    defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
    stateSchemaProviderOpt: Option[StateSchemaProvider],
-    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
+    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo,
+    stateFormatVersion: Option[Int])
  extends StatePartitionReaderBase(
    storeConf,
    hadoopConf, partition, schema,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala

Lines changed: 6 additions & 4 deletions
@@ -47,10 +47,11 @@ class StateScanBuilder(
    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
    stateSchemaProviderOpt: Option[StateSchemaProvider],
    joinColFamilyOpt: Option[String],
-    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) extends ScanBuilder {
+    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
+    stateFormatVersion: Option[Int]) extends ScanBuilder {
  override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
    keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
-    joinColFamilyOpt, allColumnFamiliesReaderInfo)
+    joinColFamilyOpt, allColumnFamiliesReaderInfo, stateFormatVersion)
 }
 
 /** An implementation of [[InputPartition]] for State Store data source. */
@@ -70,7 +71,8 @@ class StateScan(
    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
    stateSchemaProviderOpt: Option[StateSchemaProvider],
    joinColFamilyOpt: Option[String],
-    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo])
+    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
+    stateFormatVersion: Option[Int])
  extends Scan with Batch {
 
  // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
@@ -146,7 +148,7 @@
    case JoinSideValues.none =>
      new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema,
        keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt,
-        stateSchemaProviderOpt, joinColFamilyOpt, allColumnFamiliesReaderInfo)
+        stateSchemaProviderOpt, joinColFamilyOpt, allColumnFamiliesReaderInfo, stateFormatVersion)
  }
 
  override def toBatch: Batch = this

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala

Lines changed: 3 additions & 2 deletions
@@ -46,7 +46,8 @@ class StateTable(
    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
    stateSchemaProviderOpt: Option[StateSchemaProvider],
    joinColFamilyOpt: Option[String],
-    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] = None)
+    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] = None,
+    stateFormatVersion: Option[Int] = None)
  extends Table with SupportsRead with SupportsMetadataColumns {
 
  import StateTable._
@@ -88,7 +89,7 @@ class StateTable(
  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
    new StateScanBuilder(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
      stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
-      joinColFamilyOpt, allColumnFamiliesReaderInfo)
+      joinColFamilyOpt, allColumnFamiliesReaderInfo, stateFormatVersion)
 
  override def properties(): util.Map[String, String] = Map.empty[String, String].asJava
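
StateTable sits at the top of the chain this commit threads the new parameter through: it defaults stateFormatVersion to None and forwards it to StateScanBuilder, which passes it on down to the scan and the partition reader factory. The sketch below uses hypothetical *Sketch classes (not Spark types) to illustrate why the None default keeps existing constructor call sites compiling while new call sites can supply a version.

```scala
// Simplified sketch (not the actual Spark classes) of how the new stateFormatVersion
// parameter is threaded StateTable -> StateScanBuilder -> StatePartitionReaderFactory.
// Defaulting it to None keeps pre-existing call sites source-compatible.
final case class TableSketch(name: String, stateFormatVersion: Option[Int] = None) {
  def newScanBuilder(): ScanBuilderSketch = ScanBuilderSketch(stateFormatVersion)
}

final case class ScanBuilderSketch(stateFormatVersion: Option[Int]) {
  def build(): ReaderFactorySketch = ReaderFactorySketch(stateFormatVersion)
}

final case class ReaderFactorySketch(stateFormatVersion: Option[Int])

object ParameterThreadingSketch {
  def main(args: Array[String]): Unit = {
    // Call site written before the change: no version supplied, None flows all the way down.
    println(TableSketch("old").newScanBuilder().build())           // ReaderFactorySketch(None)
    // Join call site: the resolved format version is forwarded unchanged.
    println(TableSketch("join", Some(2)).newScanBuilder().build()) // ReaderFactorySketch(Some(2))
  }
}
```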

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala

Lines changed: 42 additions & 5 deletions
@@ -25,9 +25,13 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions}
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
+import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorsUtils, StatePartitionKeyExtractorFactory}
+import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.LeftSide
+import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType._
-import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStoreColFamilySchema, UnsafeRowPair}
+import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StatePartitionKeyExtractor, StateStore, StateStoreColFamilySchema, UnsafeRowPair}
 import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType, LongType, MapType, StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
@@ -49,7 +53,10 @@ object SchemaUtil {
      keySchema: StructType,
      valueSchema: StructType,
      transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo],
-      stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]): StructType = {
+      stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+      storeMetadata: Array[StateMetadataTableEntry],
+      stateFormatVersion: Option[Int] = None): StructType = {
+    println("stateFormatVersion", stateFormatVersion)
    if (transformWithStateVariableInfoOpt.isDefined) {
      require(stateStoreColFamilySchemaOpt.isDefined)
      generateSchemaForStateVar(transformWithStateVariableInfoOpt.get,
@@ -62,10 +69,13 @@
        .add("value", valueSchema)
        .add("partition_id", IntegerType)
    } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+      // Extract partition key schema using StatePartitionKeyExtractor
+      require(storeMetadata.nonEmpty)
+      val extractor = getExtractor(
+        storeMetadata.head.operatorName, keySchema, sourceOptions.storeName,
+        transformWithStateVariableInfoOpt, stateFormatVersion)
      new StructType()
-        // TODO [SPARK-54443]: change keySchema to a more specific type after we
-        // can extract partition key from keySchema
-        .add("partition_key", keySchema)
+        .add("partition_key", extractor.partitionKeySchema)
        .add("key_bytes", BinaryType)
        .add("value_bytes", BinaryType)
        .add("column_family_name", StringType)
@@ -77,6 +87,33 @@
    }
  }
 
+  /**
+   * Creates a StatePartitionKeyExtractor for the given operator.
+   * This is used to extract partition keys from state store keys for state repartitioning.
+   */
+  def getExtractor(
+      operatorName: String,
+      keySchema: StructType,
+      storeName: String,
+      transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo],
+      stateFormatVersion: Option[Int]): StatePartitionKeyExtractor = {
+    val colFamilyName: String =
+      if (operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME) {
+        SymmetricHashJoinStateManager.allStateStoreNames(LeftSide).head
+      } else {
+        transformWithStateVariableInfoOpt.map(_.stateName)
+          .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
+      }
+    StatePartitionKeyExtractorFactory.create(
+      operatorName,
+      keySchema,
+      storeName = storeName,
+      colFamilyName = colFamilyName,
+      stateFormatVersion = stateFormatVersion,
+      transformWithStateVariableInfoOpt
+    )
+  }
+
  def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow), partition: Int): InternalRow = {
    val row = new GenericInternalRow(3)
    row.update(0, pair._1)
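
The new getExtractor helper chooses which column family to hand to StatePartitionKeyExtractorFactory: the first left-side join store for symmetric hash joins, the state variable's name for transformWithState, and the default column family otherwise. A standalone sketch of just that resolution is below; the constants and the StateVarInfo case class are illustrative stand-ins rather than Spark's actual values.

```scala
// Standalone sketch of the column-family resolution inside SchemaUtil.getExtractor.
// The operator name, store names and default family below are assumed placeholders for
// StatefulOperatorsUtils, SymmetricHashJoinStateManager.allStateStoreNames(LeftSide) and
// StateStore.DEFAULT_COL_FAMILY_NAME; only the branching follows the diff.
object ColFamilyResolutionSketch {
  final case class StateVarInfo(stateName: String)

  val SymmetricHashJoinOpName = "symmetricHashJoin"                               // placeholder
  val LeftSideStoreNames = Seq("left-keyToNumValues", "left-keyWithIndexToValue") // assumed
  val DefaultColFamilyName = "default"                                            // placeholder

  def resolveColFamilyName(
      operatorName: String,
      stateVarInfoOpt: Option[StateVarInfo]): String = {
    if (operatorName == SymmetricHashJoinOpName) {
      // Stream-stream join: read the partition key from the first left-side state store.
      LeftSideStoreNames.head
    } else {
      // transformWithState names the column family after the state variable;
      // every other operator falls back to the default column family.
      stateVarInfoOpt.map(_.stateName).getOrElse(DefaultColFamilyName)
    }
  }

  def main(args: Array[String]): Unit = {
    println(resolveColFamilyName(SymmetricHashJoinOpName, None))                     // left-keyToNumValues
    println(resolveColFamilyName("transformWithState", Some(StateVarInfo("count")))) // count
    println(resolveColFamilyName("dedupe", None))                                    // default
  }
}
```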
