Skip to content

Commit 370a89c

Browse files
committed
address comment
address comment 2
1 parent 29827e7 commit 370a89c

File tree

9 files changed

+1653
-1070
lines changed

9 files changed

+1653
-1070
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,6 @@ object SchemaUtil {
4444
}
4545
}
4646

47-
def getScanAllColumnFamiliesSchema(keySchema: StructType): StructType = {
48-
new StructType()
49-
// todo [SPARK-54443]: change keySchema to a more specific type after we
50-
// can extract partition key from keySchema
51-
.add("partition_key", keySchema)
52-
.add("key_bytes", BinaryType)
53-
.add("value_bytes", BinaryType)
54-
.add("column_family_name", StringType)
55-
}
56-
5747
def getSourceSchema(
5848
sourceOptions: StateSourceOptions,
5949
keySchema: StructType,
@@ -72,7 +62,13 @@ object SchemaUtil {
7262
.add("value", valueSchema)
7363
.add("partition_id", IntegerType)
7464
} else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
75-
getScanAllColumnFamiliesSchema(keySchema)
65+
new StructType()
66+
// TODO [SPARK-54443]: change keySchema to a more specific type after we
67+
// can extract partition key from keySchema
68+
.add("partition_key", keySchema)
69+
.add("key_bytes", BinaryType)
70+
.add("value_bytes", BinaryType)
71+
.add("column_family_name", StringType)
7672
} else {
7773
new StructType()
7874
.add("key", keySchema)

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,4 @@ object StateStoreColumnFamilySchemaUtils {
264264
valSchema,
265265
Some(RangeKeyScanStateEncoderSpec(keySchema, Seq(0))))
266266
}
267-
268-
def isInternalColFamily(name: String): Boolean = {
269-
name.startsWith("$")
270-
}
271267
}

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,22 @@ package org.apache.spark.sql.execution.streaming.state
1919
import java.util.UUID
2020

2121
import scala.collection.MapView
22-
import scala.collection.immutable.HashMap
2322

2423
import org.apache.hadoop.conf.Configuration
2524
import org.apache.hadoop.fs.Path
2625

2726
import org.apache.spark.sql.catalyst.InternalRow
2827
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
28+
import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
2929
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateStoreColumnFamilySchemaUtils
3030
import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
31+
import org.apache.spark.sql.internal.SQLConf
3132

3233
case class StatePartitionWriterColumnFamilyInfo(
3334
schema: StateStoreColFamilySchema,
3435
// set this to true if state variable is ListType in TransformWithState
3536
useMultipleValuesPerKey: Boolean = false)
37+
3638
/**
3739
* A writer that can directly write binary data to the streaming state store.
3840
*
@@ -52,13 +54,26 @@ class StatePartitionAllColumnFamiliesWriter(
5254
operatorId: Int,
5355
storeName: String,
5456
currentBatchId: Long,
55-
columnFamilyToSchemaMap: HashMap[String, StatePartitionWriterColumnFamilyInfo]) {
57+
colFamilyToWriterInfoMap: Map[String, StatePartitionWriterColumnFamilyInfo],
58+
operatorName: String,
59+
schemaProviderOpt: Option[StateSchemaProvider],
60+
sqlConf: Map[String, String]) {
61+
62+
private def isJoinV3Operator(
63+
operatorName: String, sqlConf: Map[String, String]): Boolean = {
64+
operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME &&
65+
sqlConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key) == "3"
66+
}
67+
5668
private val defaultSchema = {
57-
columnFamilyToSchemaMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) match {
69+
colFamilyToWriterInfoMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) match {
5870
case Some(info) => info.schema
5971
case None =>
72+
assert(isJoinV3Operator(operatorName, sqlConf),
73+
s"Please provide the schema of 'default' column family in StateStoreColFamilySchema" +
74+
s"for operator $operatorName")
6075
// Return a dummy StateStoreColFamilySchema if not found
61-
val placeholderSchema = columnFamilyToSchemaMap.head._2.schema
76+
val placeholderSchema = colFamilyToWriterInfoMap.head._2.schema
6277
StateStoreColFamilySchema(
6378
colFamilyName = "__dummy__",
6479
keySchemaId = 0,
@@ -69,23 +84,23 @@ class StatePartitionAllColumnFamiliesWriter(
6984
}
7085
}
7186

87+
private val useColumnFamilies = colFamilyToWriterInfoMap.size > 1
7288
private val columnFamilyToKeySchemaLenMap: MapView[String, Int] =
73-
columnFamilyToSchemaMap.view.mapValues(_.schema.keySchema.length)
89+
colFamilyToWriterInfoMap.view.mapValues(_.schema.keySchema.length)
7490
private val columnFamilyToValueSchemaLenMap: MapView[String, Int] =
75-
columnFamilyToSchemaMap.view.mapValues(_.schema.valueSchema.length)
91+
colFamilyToWriterInfoMap.view.mapValues(_.schema.valueSchema.length)
7692

7793
protected lazy val provider: StateStoreProvider = {
7894
val stateCheckpointLocation = new Path(targetCpLocation, DIR_NAME_STATE).toString
7995
val stateStoreId = StateStoreId(stateCheckpointLocation,
8096
operatorId, partitionId, storeName)
8197
val stateStoreProviderId = StateStoreProviderId(stateStoreId, UUID.randomUUID())
8298

83-
val useColumnFamilies = columnFamilyToSchemaMap.size > 1
8499
val provider = StateStoreProvider.createAndInit(
85100
stateStoreProviderId, defaultSchema.keySchema, defaultSchema.valueSchema,
86101
defaultSchema.keyStateEncoderSpec.get,
87102
useColumnFamilies = useColumnFamilies, storeConf, hadoopConf,
88-
useMultipleValuesPerKey = false, stateSchemaProvider = None)
103+
useMultipleValuesPerKey = false, stateSchemaProvider = schemaProviderOpt)
89104
provider
90105
}
91106

@@ -101,23 +116,22 @@ class StatePartitionAllColumnFamiliesWriter(
101116
stateStoreCkptId = None,
102117
loadEmpty = true
103118
)
104-
if (columnFamilyToSchemaMap.size > 1) {
105-
columnFamilyToSchemaMap.foreach { pair =>
119+
if (useColumnFamilies) {
120+
colFamilyToWriterInfoMap.foreach { pair =>
106121
val colFamilyName = pair._1
107122
val cfSchema = pair._2.schema
108123
colFamilyName match {
109124
case StateStore.DEFAULT_COL_FAMILY_NAME => // createAndInit has registered default
110125
case _ =>
111-
val isInternal = StateStoreColumnFamilySchemaUtils.isInternalColFamily(colFamilyName)
112-
113126
require(cfSchema.keyStateEncoderSpec.isDefined,
114127
s"keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}")
128+
val isInternal = StateStoreColumnFamilySchemaUtils.isInternalColFamily(colFamilyName)
115129
store.createColFamilyIfAbsent(
116130
colFamilyName,
117131
cfSchema.keySchema,
118132
cfSchema.valueSchema,
119133
cfSchema.keyStateEncoderSpec.get,
120-
columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey,
134+
pair._2.useMultipleValuesPerKey,
121135
isInternal)
122136
}
123137
}
@@ -159,10 +173,14 @@ class StatePartitionAllColumnFamiliesWriter(
159173
val valueRow = new UnsafeRow(columnFamilyToValueSchemaLenMap(colFamilyName))
160174
valueRow.pointTo(valueBytes, valueBytes.length)
161175

162-
if (columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey) {
176+
if (colFamilyToWriterInfoMap(colFamilyName).useMultipleValuesPerKey) {
163177
// if a column family useMultipleValuesPerKey (e.g. ListType), we will
164178
// write with 1 put followed by merge
165-
stateStore.merge(keyRow, valueRow, colFamilyName)
179+
if (stateStore.keyExists(keyRow, colFamilyName)) {
180+
stateStore.merge(keyRow, valueRow, colFamilyName)
181+
} else {
182+
stateStore.put(keyRow, valueRow, colFamilyName)
183+
}
166184
} else {
167185
stateStore.put(keyRow, valueRow, colFamilyName)
168186
}

0 commit comments

Comments
 (0)