Commit c2ef0f2

initial commit

1 parent c474783

5 files changed: +97 additions, −29 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 36 additions & 9 deletions
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpoint
 import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors
 import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.streaming.TimeMode
 import org.apache.spark.sql.types.StructType
@@ -75,9 +76,18 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
         sourceOptions.resolvedCpLocation,
         stateConf.providerClass)
     }
-    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
+    val (stateStoreReaderInfo, storeMetadata) = getStoreMetadataAndRunChecks(
       sourceOptions)
 
+    // Resolve stateFormatVersion from the session config for the SYMMETRIC_HASH_JOIN operator
+    val isJoin = storeMetadata.nonEmpty &&
+      storeMetadata.head.operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
+    val stateFormatVersion: Int = if (isJoin) {
+      session.conf.get(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION)
+    } else {
+      1
+    }
+
     // The key state encoder spec should be available for all operators except stream-stream joins
     val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
       stateStoreReaderInfo.keyStateEncoderSpecOpt.get
@@ -91,17 +101,26 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
       stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
       stateStoreReaderInfo.stateSchemaProviderOpt,
       stateStoreReaderInfo.joinColFamilyOpt,
-      Option(stateStoreReaderInfo.allColumnFamiliesReaderInfo))
+      Option(stateStoreReaderInfo.allColumnFamiliesReaderInfo),
+      Option(stateFormatVersion))
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, options))
 
-    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
-      sourceOptions)
+    val (stateStoreReaderInfo, storeMetadata) = getStoreMetadataAndRunChecks(sourceOptions)
     val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
 
+    // Resolve stateFormatVersion from the session config for the SYMMETRIC_HASH_JOIN operator
+    val stateFormatVersion = if (storeMetadata.nonEmpty &&
+      (storeMetadata.head.operatorName ==
+        StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME)) {
+      Some(session.conf.get(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION))
+    } else {
+      None
+    }
+
     val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
     try {
       val (keySchema, valueSchema) = sourceOptions.joinSide match {
@@ -120,10 +139,18 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
           (resultSchema.keySchema, resultSchema.valueSchema)
       }
 
+      val stateVarInfo: Option[TransformWithStateVariableInfo] = if (
+        sourceOptions.internalOnlyReadAllColumnFamilies) {
+        stateStoreReaderInfo.allColumnFamiliesReaderInfo.stateVariableInfos.headOption
+      } else {
+        stateStoreReaderInfo.transformWithStateVariableInfoOpt
+      }
       SchemaUtil.getSourceSchema(sourceOptions, keySchema,
         valueSchema,
-        stateStoreReaderInfo.transformWithStateVariableInfoOpt,
-        stateStoreReaderInfo.stateStoreColFamilySchemaOpt)
+        stateVarInfo,
+        stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
+        storeMetadata,
+        stateFormatVersion)
     } catch {
       case NonFatal(e) =>
         throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
@@ -257,7 +284,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
   }
 
   private def getStoreMetadataAndRunChecks(sourceOptions: StateSourceOptions):
-    StateStoreReaderInfo = {
+    (StateStoreReaderInfo, Array[StateMetadataTableEntry]) = {
     val storeMetadata = StateDataSource.getStateStoreMetadata(sourceOptions, hadoopConf)
     if (!sourceOptions.internalOnlyReadAllColumnFamilies) {
       // skipping runStateVarChecks for StatePartitionAllColumnFamiliesReader because
@@ -362,14 +389,14 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
       }
     }
 
-    StateStoreReaderInfo(
+    (StateStoreReaderInfo(
       keyStateEncoderSpecOpt,
       stateStoreColFamilySchemaOpt,
       transformWithStateVariableInfoOpt,
       stateSchemaProvider,
       joinColFamilyOpt,
       AllColumnFamiliesReaderInfo(stateStoreColFamilySchemas, stateVariableInfos)
-    )
+    ), storeMetadata)
   }
 
   private def getKeyStateEncoderSpec(

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 5 additions & 3 deletions
@@ -49,7 +49,8 @@ class StatePartitionReaderFactory(
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
     stateSchemaProviderOpt: Option[StateSchemaProvider],
     joinColFamilyOpt: Option[String],
-    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo])
+    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
+    stateFormatVersion: Option[Int])
   extends PartitionReaderFactory {
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
@@ -58,7 +59,7 @@ class StatePartitionReaderFactory(
       require(allColumnFamiliesReaderInfo.isDefined)
       new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt,
-        stateSchemaProviderOpt, allColumnFamiliesReaderInfo.get)
+        stateSchemaProviderOpt, allColumnFamiliesReaderInfo.get, stateFormatVersion)
     } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
       new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
@@ -268,7 +269,8 @@ class StatePartitionAllColumnFamiliesReader(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
     stateSchemaProviderOpt: Option[StateSchemaProvider],
-    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
+    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo,
+    stateFormatVersion: Option[Int])
   extends StatePartitionReaderBase(
     storeConf,
     hadoopConf, partition, schema,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala

Lines changed: 6 additions & 4 deletions
@@ -47,10 +47,11 @@ class StateScanBuilder(
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
     stateSchemaProviderOpt: Option[StateSchemaProvider],
     joinColFamilyOpt: Option[String],
-    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) extends ScanBuilder {
+    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
+    stateFormatVersion: Option[Int]) extends ScanBuilder {
   override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
     keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
-    joinColFamilyOpt, allColumnFamiliesReaderInfo)
+    joinColFamilyOpt, allColumnFamiliesReaderInfo, stateFormatVersion)
 }
 
 /** An implementation of [[InputPartition]] for State Store data source. */
@@ -70,7 +71,8 @@ class StateScan(
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
     stateSchemaProviderOpt: Option[StateSchemaProvider],
     joinColFamilyOpt: Option[String],
-    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo])
+    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo],
+    stateFormatVersion: Option[Int])
   extends Scan with Batch {
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
@@ -146,7 +148,7 @@ class StateScan(
       case JoinSideValues.none =>
         new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema,
           keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt,
-          stateSchemaProviderOpt, joinColFamilyOpt, allColumnFamiliesReaderInfo)
+          stateSchemaProviderOpt, joinColFamilyOpt, allColumnFamiliesReaderInfo, stateFormatVersion)
     }
 
   override def toBatch: Batch = this

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala

Lines changed: 3 additions & 2 deletions
@@ -46,7 +46,8 @@ class StateTable(
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
     stateSchemaProviderOpt: Option[StateSchemaProvider],
     joinColFamilyOpt: Option[String],
-    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] = None)
+    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo] = None,
+    stateFormatVersion: Option[Int] = None)
   extends Table with SupportsRead with SupportsMetadataColumns {
 
   import StateTable._
@@ -88,7 +89,7 @@ class StateTable(
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
     new StateScanBuilder(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
       stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
-      joinColFamilyOpt, allColumnFamiliesReaderInfo)
+      joinColFamilyOpt, allColumnFamiliesReaderInfo, stateFormatVersion)
 
   override def properties(): util.Map[String, String] = Map.empty[String, String].asJava
 

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala

Lines changed: 47 additions & 11 deletions
@@ -25,9 +25,13 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions}
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
+import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorsUtils, StatePartitionKeyExtractorFactory}
+import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.LeftSide
+import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType._
-import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStoreColFamilySchema, UnsafeRowPair}
+import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StatePartitionKeyExtractor, StateStore, StateStoreColFamilySchema, UnsafeRowPair}
 import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType, LongType, MapType, StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
@@ -49,8 +53,21 @@ object SchemaUtil {
       keySchema: StructType,
       valueSchema: StructType,
       transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo],
-      stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]): StructType = {
-    if (transformWithStateVariableInfoOpt.isDefined) {
+      stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+      storeMetadata: Array[StateMetadataTableEntry],
+      stateFormatVersion: Option[Int] = None): StructType = {
+    if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+      // Extract the partition key schema using a StatePartitionKeyExtractor
+      require(storeMetadata.nonEmpty)
+      val extractor = getExtractor(
+        storeMetadata.head.operatorName, keySchema, sourceOptions.storeName,
+        transformWithStateVariableInfoOpt, stateFormatVersion)
+      new StructType()
+        .add("partition_key", extractor.partitionKeySchema)
+        .add("key_bytes", BinaryType)
+        .add("value_bytes", BinaryType)
+        .add("column_family_name", StringType)
+    } else if (transformWithStateVariableInfoOpt.isDefined) {
       require(stateStoreColFamilySchemaOpt.isDefined)
       generateSchemaForStateVar(transformWithStateVariableInfoOpt.get,
         stateStoreColFamilySchemaOpt.get, sourceOptions)
@@ -61,14 +78,6 @@ object SchemaUtil {
         .add("key", keySchema)
         .add("value", valueSchema)
         .add("partition_id", IntegerType)
-    } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
-      new StructType()
-        // TODO [SPARK-54443]: change keySchema to a more specific type after we
-        // can extract partition key from keySchema
-        .add("partition_key", keySchema)
-        .add("key_bytes", BinaryType)
-        .add("value_bytes", BinaryType)
-        .add("column_family_name", StringType)
     } else {
       new StructType()
         .add("key", keySchema)
@@ -77,6 +86,33 @@ object SchemaUtil {
     }
   }
 
+  /**
+   * Creates a StatePartitionKeyExtractor for the given operator.
+   * This is used to extract partition keys from state store keys for state repartitioning.
+   */
+  def getExtractor(
+      operatorName: String,
+      keySchema: StructType,
+      storeName: String,
+      transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo],
+      stateFormatVersion: Option[Int]): StatePartitionKeyExtractor = {
+    val colFamilyName: String =
+      if (operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME) {
+        SymmetricHashJoinStateManager.allStateStoreNames(LeftSide).head
+      } else {
+        transformWithStateVariableInfoOpt.map(_.stateName)
+          .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
+      }
+    StatePartitionKeyExtractorFactory.create(
+      operatorName,
+      keySchema,
+      storeName = storeName,
+      colFamilyName = colFamilyName,
+      stateFormatVersion = stateFormatVersion,
+      transformWithStateVariableInfoOpt
+    )
+  }
+
   def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow), partition: Int): InternalRow = {
     val row = new GenericInternalRow(3)
     row.update(0, pair._1)
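A hedged usage sketch of the new helper: for a stream-stream join the column family comes from the left side's first state store name, otherwise from the state variable (or the default column family). The store name and format version below are assumptions for illustration, not values taken from this diff:

  import org.apache.spark.sql.types.{LongType, StringType, StructType}

  // Hypothetical join key schema; grouping columns lead the join state key.
  val keySchema = new StructType().add("userId", StringType).add("eventTime", LongType)

  val extractor = SchemaUtil.getExtractor(
    operatorName = StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME,
    keySchema = keySchema,
    storeName = StateStoreId.DEFAULT_STORE_NAME, // assumption, not from this diff
    transformWithStateVariableInfoOpt = None,
    stateFormatVersion = Some(2))                // assumption: join state format v2

  // The narrowed schema that getSourceSchema exposes as "partition_key" above.
  val partitionKeySchema: StructType = extractor.partitionKeySchema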
