Skip to content

Commit 8218b56

Browse files
committed
address comment
1 parent 370a89c commit 8218b56

File tree

9 files changed

+459
-638
lines changed

9 files changed

+459
-638
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,11 +1256,11 @@ object SymmetricHashJoinStateManager {
12561256

12571257
private[streaming] sealed trait StateStoreType
12581258

1259-
private[streaming] case object KeyToNumValuesType extends StateStoreType {
1259+
private[sql] case object KeyToNumValuesType extends StateStoreType {
12601260
override def toString(): String = "keyToNumValues"
12611261
}
12621262

1263-
private[streaming] case object KeyWithIndexToValueType extends StateStoreType {
1263+
private[sql] case object KeyWithIndexToValueType extends StateStoreType {
12641264
override def toString(): String = "keyWithIndexToValue"
12651265
}
12661266

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,19 @@ class StatePartitionAllColumnFamiliesWriter(
5757
colFamilyToWriterInfoMap: Map[String, StatePartitionWriterColumnFamilyInfo],
5858
operatorName: String,
5959
schemaProviderOpt: Option[StateSchemaProvider],
60-
sqlConf: Map[String, String]) {
60+
sqlConf: SQLConf) {
6161

6262
private def isJoinV3Operator(
63-
operatorName: String, sqlConf: Map[String, String]): Boolean = {
63+
operatorName: String, sqlConf: SQLConf): Boolean = {
6464
operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME &&
65-
sqlConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key) == "3"
65+
sqlConf.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION) == 3
6666
}
6767

6868
private val defaultSchema = {
6969
colFamilyToWriterInfoMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) match {
7070
case Some(info) => info.schema
7171
case None =>
72+
// The join V3 operator does not have a schema for the default column family.
7273
assert(isJoinV3Operator(operatorName, sqlConf),
7374
s"Please provide the schema of 'default' column family in StateStoreColFamilySchema" +
7475
s"for operator $operatorName")
@@ -173,14 +174,11 @@ class StatePartitionAllColumnFamiliesWriter(
173174
val valueRow = new UnsafeRow(columnFamilyToValueSchemaLenMap(colFamilyName))
174175
valueRow.pointTo(valueBytes, valueBytes.length)
175176

176-
if (colFamilyToWriterInfoMap(colFamilyName).useMultipleValuesPerKey) {
177-
// if a column family useMultipleValuesPerKey (e.g. ListType), we will
178-
// write with 1 put followed by merge
179-
if (stateStore.keyExists(keyRow, colFamilyName)) {
180-
stateStore.merge(keyRow, valueRow, colFamilyName)
181-
} else {
182-
stateStore.put(keyRow, valueRow, colFamilyName)
183-
}
177+
// If a column family has useMultipleValuesPerKey set (e.g. ListType), the first
178+
// value for a key is written with put and every later value with merge.
179+
if (colFamilyToWriterInfoMap(colFamilyName).useMultipleValuesPerKey &&
180+
stateStore.keyExists(keyRow, colFamilyName)) {
181+
stateStore.merge(keyRow, valueRow, colFamilyName)
184182
} else {
185183
stateStore.put(keyRow, valueRow, colFamilyName)
186184
}

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala

Lines changed: 21 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.datasources.v2.state
1919
import java.sql.Timestamp
2020

2121
import org.apache.spark.sql.{DataFrame, Dataset}
22+
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
23+
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
2224
import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
23-
import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, StateStore}
25+
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore}
2426
import org.apache.spark.sql.functions._
2527
import org.apache.spark.sql.internal.SQLConf
2628
import org.apache.spark.sql.streaming._
27-
import org.apache.spark.sql.streaming.util.StreamManualClock
29+
import org.apache.spark.sql.streaming.util.{ColumnFamilyMetadata, StreamManualClock}
2830
import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, NullType, StringType, StructField, StructType, TimestampType}
2931

3032
trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest {
@@ -459,12 +461,6 @@ case class SessionUpdate(
459461
numEvents: Int,
460462
expired: Boolean)
461463

462-
case class ColumnFamilyMetadata(
463-
keySchema: StructType,
464-
valueSchema: StructType,
465-
encoderSpec: KeyStateEncoderSpec,
466-
useMultipleValuePerKey: Boolean = false)
467-
468464
// Utility for runCompositeKeyStreamingAggregationQuery
469465
// TODO: Move runCompositeKeyStreamingAggregationQuery to this object
470466
object CompositeKeyAggregationTestUtils {
@@ -568,10 +564,6 @@ object SimpleAggregationTestUtils {
568564
(metadata.keySchema, metadata.valueSchema)
569565
}
570566

571-
/**
572-
* @param stateVersion The state format version:
573-
* @return ColumnFamilyMetadata including schema and KeyEncoderSpec
574-
*/
575567
def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
576568
val keySchema = StructType(Array(
577569
StructField("groupKey", IntegerType, nullable = false)
@@ -606,19 +598,11 @@ object SimpleAggregationTestUtils {
606598
*/
607599
object FlatMapGroupsWithStateTestUtils {
608600

609-
/**
610-
* @param stateVersion The state format version:
611-
* @return A tuple of (keySchema, valueSchema)
612-
*/
613601
def getSchemas(stateVersion: Int): (StructType, StructType) = {
614602
val metadata = getSchemasWithMetadata(stateVersion)
615603
(metadata.keySchema, metadata.valueSchema)
616604
}
617605

618-
/**
619-
* @param stateVersion The state format version
620-
* @return ColumnFamilyMetadata with schema and KeyEncoderSpec
621-
*/
622606
def getSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
623607
val keySchema = StructType(Array(
624608
StructField("value", StringType, nullable = true)
@@ -654,17 +638,11 @@ object FlatMapGroupsWithStateTestUtils {
654638
*/
655639
object SessionWindowTestUtils {
656640

657-
/**
658-
* @return A tuple of (keySchema, valueSchema)
659-
*/
660641
def getSchemas(): (StructType, StructType) = {
661642
val metadata = getSchemasWithMetadata()
662643
(metadata.keySchema, metadata.valueSchema)
663644
}
664645

665-
/**
666-
* @return ColumnFamilyMetadata with schema and KeyEncoderSpec
667-
*/
668646
def getSchemasWithMetadata(): ColumnFamilyMetadata = {
669647
val keySchema = StructType(Array(
670648
StructField("sessionId", StringType, nullable = false),
@@ -687,33 +665,23 @@ object SessionWindowTestUtils {
687665
* Test utility object providing schema definitions and constants for runStreamStreamJoinQuery
688666
*/
689667
object StreamStreamJoinTestUtils {
690-
// Column family names for keyToNumValues stores
691-
val KEY_TO_NUM_VALUES_LEFT = "left-keyToNumValues"
692-
val KEY_TO_NUM_VALUES_RIGHT = "right-keyToNumValues"
693-
val KEY_TO_NUM_VALUES_ALL: Seq[String] = Seq(
694-
KEY_TO_NUM_VALUES_LEFT,
695-
KEY_TO_NUM_VALUES_RIGHT
696-
)
697-
698-
// Column family names for keyWithIndexToValue stores
699-
val KEY_WITH_INDEX_LEFT = "left-keyWithIndexToValue"
700-
val KEY_WITH_INDEX_RIGHT = "right-keyWithIndexToValue"
701-
val KEY_WITH_INDEX_ALL: Seq[String] = Seq(
702-
KEY_WITH_INDEX_LEFT,
703-
KEY_WITH_INDEX_RIGHT
704-
)
705-
706-
/**
707-
* @return A tuple of (keySchema, valueSchema)
708-
*/
668+
// All state store names from SymmetricHashJoinStateManager
669+
private val allStoreNames: Seq[String] =
670+
SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
671+
672+
// Column family names for keyToNumValues stores (derived from allStateStoreNames)
673+
val KEY_TO_NUM_VALUES_ALL: Seq[String] =
674+
allStoreNames.filter(_.endsWith(SymmetricHashJoinStateManager.KeyToNumValuesType.toString))
675+
676+
// Column family names for keyWithIndexToValue stores (derived from allStateStoreNames)
677+
val KEY_WITH_INDEX_ALL: Seq[String] =
678+
allStoreNames.filter(_.endsWith(SymmetricHashJoinStateManager.KeyWithIndexToValueType.toString))
679+
709680
def getKeyToNumValuesSchemas(): (StructType, StructType) = {
710681
val metadata = getKeyToNumValuesSchemasWithMetadata()
711682
(metadata.keySchema, metadata.valueSchema)
712683
}
713684

714-
/**
715-
* @return ColumnFamilyMetadata with schema and KeyEncoderSpec
716-
*/
717685
def getKeyToNumValuesSchemasWithMetadata(): ColumnFamilyMetadata = {
718686
val keySchema = StructType(Array(
719687
StructField("key", IntegerType)
@@ -725,19 +693,11 @@ object StreamStreamJoinTestUtils {
725693
ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
726694
}
727695

728-
/**
729-
* @param stateVersion The join state format version:
730-
* @return A tuple of (keySchema, valueSchema)
731-
*/
732696
def getKeyWithIndexToValueSchemas(stateVersion: Int): (StructType, StructType) = {
733697
val metadata = getKeyWithIndexToValueSchemasWithMetadata(stateVersion)
734698
(metadata.keySchema, metadata.valueSchema)
735699
}
736700

737-
/**
738-
* @param stateVersion The state format version
739-
* @return ColumnFamilyMetadata with schema and KeyEncoderSpec
740-
*/
741701
def getKeyWithIndexToValueSchemasWithMetadata(stateVersion: Int): ColumnFamilyMetadata = {
742702
val keySchema = StructType(Array(
743703
StructField("key", IntegerType, nullable = false),
@@ -761,37 +721,23 @@ object StreamStreamJoinTestUtils {
761721
ColumnFamilyMetadata(keySchema, valueSchema, encoderSpec)
762722
}
763723

764-
/**
765-
* Returns all schemas for stream-stream join V3 (multi-column family) in legacy 2-tuple format.
766-
* V3 uses a single state store with multiple column families instead of separate stores.
767-
*
768-
* @return Map of column family name to (keySchema, valueSchema)
769-
*/
770724
def getJoinV3ColumnSchemaMap(): Map[String, (StructType, StructType)] = {
771725
getJoinV3ColumnSchemaMapWithMetadata().view.mapValues { metadata =>
772726
(metadata.keySchema, metadata.valueSchema)
773727
}.toMap
774728
}
775729

776-
/**
777-
* @return Map of column family name to ColumnFamilyMetadata
778-
*/
779730
def getJoinV3ColumnSchemaMapWithMetadata(): Map[String, ColumnFamilyMetadata] = {
780731
val (keyToNumKeySchema, keyToNumValueSchema) = getKeyToNumValuesSchemas()
781732
val (keyWithIndexKeySchema, keyWithIndexValueSchema) = getKeyWithIndexToValueSchemas(3)
782733

783734
val keyToNumEncoderSpec = NoPrefixKeyStateEncoderSpec(keyToNumKeySchema)
784735
val keyWithIndexEncoderSpec = NoPrefixKeyStateEncoderSpec(keyWithIndexKeySchema)
785736

786-
Map(
787-
KEY_TO_NUM_VALUES_LEFT -> ColumnFamilyMetadata(
788-
keyToNumKeySchema, keyToNumValueSchema, keyToNumEncoderSpec),
789-
KEY_TO_NUM_VALUES_RIGHT -> ColumnFamilyMetadata(
790-
keyToNumKeySchema, keyToNumValueSchema, keyToNumEncoderSpec),
791-
KEY_WITH_INDEX_LEFT -> ColumnFamilyMetadata(
792-
keyWithIndexKeySchema, keyWithIndexValueSchema, keyWithIndexEncoderSpec),
793-
KEY_WITH_INDEX_RIGHT -> ColumnFamilyMetadata(
794-
keyWithIndexKeySchema, keyWithIndexValueSchema, keyWithIndexEncoderSpec)
795-
)
737+
KEY_TO_NUM_VALUES_ALL.map(name => name -> ColumnFamilyMetadata(
738+
keyToNumKeySchema, keyToNumValueSchema, keyToNumEncoderSpec)).toMap ++
739+
KEY_WITH_INDEX_ALL.map(name => name -> ColumnFamilyMetadata(
740+
keyWithIndexKeySchema, keyWithIndexValueSchema, keyWithIndexEncoderSpec)).toMap
741+
796742
}
797743
}

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ import org.apache.spark.sql.{DataFrame, Row}
2525
import org.apache.spark.sql.catalyst.CatalystTypeConverters
2626
import org.apache.spark.sql.catalyst.InternalRow
2727
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
28-
import org.apache.spark.sql.execution.datasources.v2.state.utils.{EventTimeTimerProcessor, MultiStateVarProcessor, MultiStateVarProcessorTestUtils, TimerTestUtils}
2928
import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
3029
import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateRepartitionUnsupportedProviderError, StateStore}
3130
import org.apache.spark.sql.functions.{col, count, sum, timestamp_seconds}
3231
import org.apache.spark.sql.internal.SQLConf
3332
import org.apache.spark.sql.streaming.{InputEvent, ListStateTTLProcessor, MapInputEvent, MapStateTTLProcessor, OutputMode, RunningCountStatefulProcessorWithProcTimeTimer, TimeMode, Trigger, TTLConfig, ValueStateTTLProcessor}
3433
import org.apache.spark.sql.streaming.util.{StreamManualClock, TTLProcessorUtils}
34+
import org.apache.spark.sql.streaming.util.{EventTimeTimerProcessor, MultiStateVarProcessor, MultiStateVarProcessorTestUtils, TimerTestUtils}
3535
import org.apache.spark.sql.types.{DataType, NullType, StructField, StructType}
3636

3737
/**

0 commit comments

Comments
 (0)