@@ -20,17 +20,18 @@ import java.util.UUID
2020
2121import scala .collection .MapView
2222import scala .collection .immutable .HashMap
23- import scala .collection .mutable .HashSet
2423
2524import org .apache .hadoop .conf .Configuration
2625import org .apache .hadoop .fs .Path
2726
2827import org .apache .spark .sql .catalyst .InternalRow
2928import org .apache .spark .sql .catalyst .expressions .UnsafeRow
29+ import org .apache .spark .sql .execution .streaming .operators .stateful .transformwithstate .StateStoreColumnFamilySchemaUtils
3030import org .apache .spark .sql .execution .streaming .runtime .StreamingCheckpointConstants .DIR_NAME_STATE
3131
3232case class StatePartitionWriterColumnFamilyInfo (
3333 schema : StateStoreColFamilySchema ,
34+ // set this to true if state variable is ListType in TransformWithState
3435 useMultipleValuesPerKey : Boolean = false )
3536/**
3637 * A writer that can directly write binary data to the streaming state store.
@@ -52,24 +53,34 @@ class StatePartitionAllColumnFamiliesWriter(
5253 storeName : String ,
5354 currentBatchId : Long ,
5455 columnFamilyToSchemaMap : HashMap [String , StatePartitionWriterColumnFamilyInfo ]) {
56+ private val dummySchema : StructType =
57+ StructType (Array (StructField (" __dummy__" , NullType )))
5558 private val defaultSchema = {
56- columnFamilyToSchemaMap.getOrElse(
57- StateStore .DEFAULT_COL_FAMILY_NAME ,
58- columnFamilyToSchemaMap.head._2 // join V3 doesn't have default col family
59- ).schema
59+ columnFamilyToSchemaMap.get(StateStore .DEFAULT_COL_FAMILY_NAME ) match {
60+ case Some (info) => info.schema
61+ case None =>
62+ // Return a dummy StateStoreColFamilySchema if not found
63+ StateStoreColFamilySchema (
64+ colFamilyName = " __dummy__" ,
65+ keySchemaId = 0 ,
66+ keySchema = dummySchema,
67+ valueSchemaId = 0 ,
68+ valueSchema = dummySchema,
69+ keyStateEncoderSpec = Option (NoPrefixKeyStateEncoderSpec (dummySchema)))
70+ }
6071 }
6172
6273 private val columnFamilyToKeySchemaLenMap : MapView [String , Int ] =
6374 columnFamilyToSchemaMap.view.mapValues(_.schema.keySchema.length)
6475 private val columnFamilyToValueSchemaLenMap : MapView [String , Int ] =
6576 columnFamilyToSchemaMap.view.mapValues(_.schema.valueSchema.length)
66- private val colFamilyHasWritten : HashSet [String ] = HashSet [String ]()
6777
6878 protected lazy val provider : StateStoreProvider = {
6979 val stateCheckpointLocation = new Path (targetCpLocation, DIR_NAME_STATE ).toString
7080 val stateStoreId = StateStoreId (stateCheckpointLocation,
7181 operatorId, partitionId, storeName)
7282 val stateStoreProviderId = StateStoreProviderId (stateStoreId, UUID .randomUUID())
83+
7384 val useColumnFamilies = columnFamilyToSchemaMap.size > 1
7485 val provider = StateStoreProvider .createAndInit(
7586 stateStoreProviderId, defaultSchema.keySchema, defaultSchema.valueSchema,
@@ -98,7 +109,7 @@ class StatePartitionAllColumnFamiliesWriter(
98109 colFamilyName match {
99110 case StateStore .DEFAULT_COL_FAMILY_NAME => // createAndInit has registered default
100111 case _ =>
101- val isInternal = colFamilyName.startsWith( " $ " )
112+ val isInternal = StateStoreColumnFamilySchemaUtils .isInternalColFamily(colFamilyName )
102113
103114 require(cfSchema.keyStateEncoderSpec.isDefined,
104115 s " keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}" )
@@ -124,7 +135,6 @@ class StatePartitionAllColumnFamiliesWriter(
124135 try {
125136 rows.foreach(row => writeRow(row))
126137 stateStore.commit()
127- colFamilyHasWritten.clear()
128138 } finally {
129139 if (! stateStore.hasCommitted) {
130140 stateStore.abort()
@@ -150,12 +160,12 @@ class StatePartitionAllColumnFamiliesWriter(
150160 val valueRow = new UnsafeRow (columnFamilyToValueSchemaLenMap(colFamilyName))
151161 valueRow.pointTo(valueBytes, valueBytes.length)
152162
153- if (columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey
154- && colFamilyHasWritten(colFamilyName)) {
163+ if (columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey) {
164+ // if a column family useMultipleValuesPerKey (e.g. ListType), we will
165+ // write with 1 put followed by merge
155166 stateStore.merge(keyRow, valueRow, colFamilyName)
156167 } else {
157168 stateStore.put(keyRow, valueRow, colFamilyName)
158- colFamilyHasWritten.add(colFamilyName)
159169 }
160170 }
161171}
0 commit comments