diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandle.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandle.scala
new file mode 100644
index 000000000000..5088865c7497
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandle.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.testing
+
+import java.util.UUID
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.ImplicitGroupingKeyTracker
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.QueryInfoImpl
+import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TTLConfig, ValueState}
+
+/** In-memory implementation of ValueState. */
+class InMemoryValueState[T] extends ValueState[T] {
+  private val keyToStateValue = mutable.Map[Any, T]()
+
+  override def exists(): Boolean =
+    keyToStateValue.contains(ImplicitGroupingKeyTracker.getImplicitKeyOption.get)
+
+  override def get(): T =
+    keyToStateValue.getOrElse(
+      ImplicitGroupingKeyTracker.getImplicitKeyOption.get,
+      null.asInstanceOf[T]
+    )
+
+  override def update(newState: T): Unit =
+    keyToStateValue.put(ImplicitGroupingKeyTracker.getImplicitKeyOption.get, newState)
+
+  override def clear(): Unit =
+    keyToStateValue.remove(ImplicitGroupingKeyTracker.getImplicitKeyOption.get)
+}
+
+/**
+ * In-memory implementation of ListState.
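+ * State is tracked per grouping key: every call reads the current key from
+ * ImplicitGroupingKeyTracker, so the implicit grouping key must be set before use.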
+ */
+class InMemoryListState[T] extends ListState[T] {
+  private val keyToStateValue = mutable.Map[Any, mutable.ArrayBuffer[T]]()
+
+  override def exists(): Boolean =
+    keyToStateValue.contains(ImplicitGroupingKeyTracker.getImplicitKeyOption.get)
+
+  private def getList: mutable.ArrayBuffer[T] = {
+    if (!exists()) {
+      keyToStateValue.put(
+        ImplicitGroupingKeyTracker.getImplicitKeyOption.get,
+        mutable.ArrayBuffer.empty[T]
+      )
+    }
+    keyToStateValue(ImplicitGroupingKeyTracker.getImplicitKeyOption.get)
+  }
+
+  override def get(): Iterator[T] =
+    if (exists()) getList.iterator else Iterator.empty
+
+  override def put(newState: Array[T]): Unit =
+    keyToStateValue.put(
+      ImplicitGroupingKeyTracker.getImplicitKeyOption.get,
+      mutable.ArrayBuffer.empty[T] ++ newState
+    )
+
+  override def appendValue(newState: T): Unit = getList += newState
+
+  override def appendList(newState: Array[T]): Unit = getList ++= newState
+
+  override def clear(): Unit =
+    keyToStateValue.remove(ImplicitGroupingKeyTracker.getImplicitKeyOption.get)
+}
+
+/** In-memory implementation of MapState. */
+class InMemoryMapState[K, V] extends MapState[K, V] {
+  private val keyToStateValue = mutable.Map[Any, mutable.HashMap[K, V]]()
+
+  override def exists(): Boolean =
+    keyToStateValue.contains(ImplicitGroupingKeyTracker.getImplicitKeyOption.get)
+
+  private def getMap: mutable.HashMap[K, V] = {
+    if (!exists()) {
+      keyToStateValue.put(
+        ImplicitGroupingKeyTracker.getImplicitKeyOption.get,
+        mutable.HashMap.empty[K, V]
+      )
+    }
+    keyToStateValue(ImplicitGroupingKeyTracker.getImplicitKeyOption.get)
+  }
+
+  private def getMapIfExists: Option[mutable.HashMap[K, V]] = {
+    keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get)
+  }
+
+  override def getValue(key: K): V =
+    getMapIfExists.flatMap(_.get(key)).getOrElse(null.asInstanceOf[V])
+
+  override def containsKey(key: K): Boolean = getMapIfExists.exists(_.contains(key))
+
+  override def updateValue(key: K, value: V): Unit = getMap.put(key, value)
+
+  override def iterator(): Iterator[(K, V)] =
+    getMapIfExists.map(_.iterator).getOrElse(Iterator.empty)
+
+  override def keys(): Iterator[K] = getMapIfExists.map(_.keys.iterator).getOrElse(Iterator.empty)
+
+  override def values(): Iterator[V] =
+    getMapIfExists.map(_.values.iterator).getOrElse(Iterator.empty)
+
+  override def removeKey(key: K): Unit = getMapIfExists.foreach(_.remove(key))
+
+  override def clear(): Unit =
+    keyToStateValue.remove(ImplicitGroupingKeyTracker.getImplicitKeyOption.get)
+}
+
+/**
+ * In-memory implementation of StatefulProcessorHandle.
+ *
+ * Does not support timers or TTL. Supports direct state access for testing.
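+ *
+ * A minimal sketch of standalone use (most tests should go through TwsTester, which performs
+ * this wiring internally; CountingProcessor is a hypothetical processor whose init() registers
+ * a ValueState named "count"):
+ * {{{
+ *   val handle = new InMemoryStatefulProcessorHandle()
+ *   val processor = new CountingProcessor()
+ *   processor.setHandle(handle)
+ *   processor.init(OutputMode.Append, TimeMode.None)   // registers the processor's states
+ *   ImplicitGroupingKeyTracker.setImplicitKey("key1")  // state calls resolve against this key
+ *   handle.setValueState[Long]("count", 5L)
+ * }}}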
+ */
+class InMemoryStatefulProcessorHandle() extends StatefulProcessorHandle {
+  private val states = mutable.Map[String, Any]()
+
+  override def getValueState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig
+  ): ValueState[T] = {
+    states
+      .getOrElseUpdate(stateName, new InMemoryValueState[T]())
+      .asInstanceOf[InMemoryValueState[T]]
+  }
+
+  override def getValueState[T: Encoder](stateName: String, ttlConfig: TTLConfig): ValueState[T] =
+    getValueState(stateName, implicitly[Encoder[T]], ttlConfig)
+
+  override def getListState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig
+  ): ListState[T] = {
+    states
+      .getOrElseUpdate(stateName, new InMemoryListState[T]())
+      .asInstanceOf[InMemoryListState[T]]
+  }
+
+  override def getListState[T: Encoder](stateName: String, ttlConfig: TTLConfig): ListState[T] =
+    getListState(stateName, implicitly[Encoder[T]], ttlConfig)
+
+  override def getMapState[K, V](
+      stateName: String,
+      userKeyEnc: Encoder[K],
+      valEncoder: Encoder[V],
+      ttlConfig: TTLConfig
+  ): MapState[K, V] = {
+    states
+      .getOrElseUpdate(stateName, new InMemoryMapState[K, V]())
+      .asInstanceOf[InMemoryMapState[K, V]]
+  }
+
+  override def getMapState[K: Encoder, V: Encoder](
+      stateName: String,
+      ttlConfig: TTLConfig): MapState[K, V] =
+    getMapState(stateName, implicitly[Encoder[K]], implicitly[Encoder[V]], ttlConfig)
+
+  override def getQueryInfo(): QueryInfo =
+    new QueryInfoImpl(UUID.randomUUID(), UUID.randomUUID(), 0L)
+
+  override def registerTimer(expiryTimestampMs: Long): Unit =
+    throw new UnsupportedOperationException("Timers are not supported.")
+
+  override def deleteTimer(expiryTimestampMs: Long): Unit =
+    throw new UnsupportedOperationException("Timers are not supported.")
+
+  override def listTimers(): Iterator[Long] =
+    throw new UnsupportedOperationException("Timers are not supported.")
+
+  override def deleteIfExists(stateName: String): Unit = states.remove(stateName)
+
+  def setValueState[T](stateName: String, value: T): Unit = {
+    require(states.contains(stateName), s"State $stateName has not been initialized.")
+    states(stateName).asInstanceOf[InMemoryValueState[T]].update(value)
+  }
+
+  def peekValueState[T](stateName: String): Option[T] = {
+    require(states.contains(stateName), s"State $stateName has not been initialized.")
+    Option(states(stateName).asInstanceOf[InMemoryValueState[T]].get())
+  }
+
+  def setListState[T](stateName: String, value: List[T])(implicit ct: ClassTag[T]): Unit = {
+    require(states.contains(stateName), s"State $stateName has not been initialized.")
+    states(stateName).asInstanceOf[InMemoryListState[T]].put(value.toArray)
+  }
+
+  def peekListState[T](stateName: String): List[T] = {
+    require(states.contains(stateName), s"State $stateName has not been initialized.")
+    states(stateName).asInstanceOf[InMemoryListState[T]].get().toList
+  }
+
+  def setMapState[MK, MV](stateName: String, value: Map[MK, MV]): Unit = {
+    require(states.contains(stateName), s"State $stateName has not been initialized.")
+    val mapState = states(stateName).asInstanceOf[InMemoryMapState[MK, MV]]
+    mapState.clear()
+    value.foreach { case (k, v) => mapState.updateValue(k, v) }
+  }
+
+  def peekMapState[MK, MV](stateName: String): Map[MK, MV] = {
+    require(states.contains(stateName), s"State $stateName has not been initialized.")
+    states(stateName).asInstanceOf[InMemoryMapState[MK, MV]].iterator().toMap
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala
new file mode 100644
index 000000000000..7bd64910a471
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.streaming
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.ImplicitGroupingKeyTracker
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.testing.InMemoryStatefulProcessorHandle
+
+/**
+ * Testing utility for transformWithState stateful processors.
+ *
+ * This class enables unit testing of StatefulProcessor business logic by simulating the
+ * behavior of transformWithState. It processes input rows and returns output rows equivalent
+ * to those that would be produced by the processor in an actual Spark streaming query.
+ *
+ * '''Supported:'''
+ *  - Processing input rows and producing output rows via `test()`.
+ *  - Initial state setup via constructor parameter.
+ *  - Direct state manipulation via `setValueState`, `setListState`, `setMapState`.
+ *  - Direct state inspection via `peekValueState`, `peekListState`, `peekMapState`.
+ *
+ * '''Not Supported:'''
+ *  - '''Timers''': Only TimeMode.None is supported. If the processor attempts to register,
+ *    list, or delete timers (as it would in TimeMode.EventTime or TimeMode.ProcessingTime),
+ *    an UnsupportedOperationException is thrown; the TimerValues passed to handleInputRows
+ *    is null.
+ *  - '''TTL''': State TTL configurations are ignored. All state persists indefinitely.
+ *
+ * '''Use Cases:'''
+ *  - '''Primary''': Unit testing business logic in `handleInputRows` implementations.
+ *  - '''Not recommended''': End-to-end testing or performance testing; use actual Spark
+ *    streaming queries for those scenarios.
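+ *
+ * A minimal usage sketch (CountingProcessor is a hypothetical
+ * StatefulProcessor[String, String, (String, Long)] that keeps a ValueState named "count"):
+ * {{{
+ *   val tester = new TwsTester(new CountingProcessor())
+ *   val output = tester.test(List(("key1", "a"), ("key1", "b")))
+ *   assert(tester.peekValueState[Long]("count", "key1").contains(2L))
+ * }}}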
+ *
+ * @param processor the StatefulProcessor to test
+ * @param initialState initial state for each key as a list of (key, state) tuples
+ * @tparam K the type of grouping key
+ * @tparam I the type of input rows
+ * @tparam O the type of output rows
+ */
+class TwsTester[K, I, O](
+    val processor: StatefulProcessor[K, I, O],
+    val initialState: List[(K, Any)] = List()) {
+  private val handle = new InMemoryStatefulProcessorHandle()
+  processor.setHandle(handle)
+  processor.init(OutputMode.Append, TimeMode.None)
+  processor match {
+    case p: StatefulProcessorWithInitialState[K @unchecked, I @unchecked, O @unchecked, s] =>
+      handleInitialState[s]()
+    case _ =>
+  }
+
+  private def handleInitialState[S](): Unit = {
+    val p = processor.asInstanceOf[StatefulProcessorWithInitialState[K, I, O, S]]
+    initialState.foreach {
+      case (key, state) =>
+        ImplicitGroupingKeyTracker.setImplicitKey(key)
+        p.handleInitialState(key, state.asInstanceOf[S], null)
+    }
+  }
+
+  /**
+   * Processes input rows through the stateful processor, grouped by key.
+   *
+   * This corresponds to processing one microbatch. {@code handleInputRows} will be called once
+   * for each key that appears in {@code input}.
+   *
+   * To simulate real-time mode, call this method repeatedly in a loop, passing a list with a
+   * single (key, input row) tuple per call.
+   *
+   * @param input list of (key, input row) tuples to process
+   * @return all output rows produced by the processor
+   */
+  def test(input: List[(K, I)]): List[O] = {
+    var ans: List[O] = List()
+    for ((key, v) <- input.groupBy(_._1)) {
+      ImplicitGroupingKeyTracker.setImplicitKey(key)
+      ans = ans ++ processor.handleInputRows(key, v.map(_._2).iterator, null).toList
+    }
+    ans
+  }
+
+  /** Sets the value state for a given key. */
+  def setValueState[T](stateName: String, key: K, value: T): Unit = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    handle.setValueState[T](stateName, value)
+  }
+
+  /** Retrieves the value state for a given key. */
+  def peekValueState[T](stateName: String, key: K): Option[T] = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    handle.peekValueState[T](stateName)
+  }
+
+  /** Sets the list state for a given key. */
+  def setListState[T](stateName: String, key: K, value: List[T])(implicit ct: ClassTag[T]): Unit = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    handle.setListState[T](stateName, value)
+  }
+
+  /** Retrieves the list state for a given key. */
+  def peekListState[T](stateName: String, key: K): List[T] = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    handle.peekListState[T](stateName)
+  }
+
+  /** Sets the map state for a given key. */
+  def setMapState[MK, MV](stateName: String, key: K, value: Map[MK, MV]): Unit = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    handle.setMapState[MK, MV](stateName, value)
+  }
+
+  /** Retrieves the map state for a given key. */
+  def peekMapState[MK, MV](stateName: String, key: K): Map[MK, MV] = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    handle.peekMapState[MK, MV](stateName)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala
new file mode 100644
index 000000000000..37f2ec270784
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala
@@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.processors._
+
+/** Test suite for the TwsTester utility class. */
+class TwsTesterSuite extends SparkFunSuite {
+
+  test("TwsTester should correctly test RunningCountProcessor") {
+    val input: List[(String, String)] = List(
+      ("key1", "a"),
+      ("key2", "b"),
+      ("key1", "c"),
+      ("key2", "b"),
+      ("key1", "c"),
+      ("key1", "c"),
+      ("key3", "q")
+    )
+    val tester = new TwsTester(new RunningCountProcessor[String]())
+    val ans1: List[(String, Long)] = tester.test(input)
+    assert(ans1.sorted == List(("key1", 4L), ("key2", 2L), ("key3", 1L)).sorted)
+
+    assert(tester.peekValueState[Long]("count", "key1").get == 4L)
+    assert(tester.peekValueState[Long]("count", "key2").get == 2L)
+    assert(tester.peekValueState[Long]("count", "key3").get == 1L)
+    assert(tester.peekValueState[Long]("count", "key4").isEmpty)
+
+    val ans2 = tester.test(List(("key1", "q")))
+    assert(ans2 == List(("key1", 5L)))
+    assert(tester.peekValueState[Long]("count", "key1").get == 5L)
+    assert(tester.peekValueState[Long]("count", "key2").get == 2L)
+
+    val ans3 = tester.test(List(("key1", "a"), ("key2", "a")))
+    assert(ans3.sorted == List(("key1", 6L), ("key2", 3L)))
+  }
+
+  test("TwsTester should allow direct access to ValueState") {
+    val processor = new RunningCountProcessor[String]()
+    val tester = new TwsTester[String, String, (String, Long)](processor)
+    tester.setValueState[Long]("count", "foo", 5)
+    tester.test(List(("foo", "a")))
+    assert(tester.peekValueState[Long]("count", "foo").get == 6L)
+  }
+
+  test("TwsTester should correctly test TopKProcessor") {
+    val input: List[(String, (String, Double))] = List(
+      ("key2", ("c", 30.0)),
+      ("key2", ("d", 40.0)),
+      ("key1", ("b", 2.0)),
+      ("key1", ("c", 3.0)),
+      ("key2", ("a", 10.0)),
+      ("key2", ("b", 20.0)),
+      ("key3", ("a", 100.0)),
+      ("key1", ("a", 1.0))
+    )
+    val tester = new TwsTester(new TopKProcessor(2))
+    val ans1 = tester.test(input)
+    assert(
+      ans1.sorted == List(
+        ("key1", 2.0),
+        ("key1", 3.0),
+        ("key2", 30.0),
+        ("key2", 40.0),
+        ("key3", 100.0)
+      )
+    )
+    assert(tester.peekListState[Double]("topK", "key1") == List(3.0, 2.0))
+    assert(tester.peekListState[Double]("topK", "key2") == List(40.0, 30.0))
+    assert(tester.peekListState[Double]("topK", "key3") == List(100.0))
+    assert(tester.peekListState[Double]("topK", "key4").isEmpty)
+
+    val ans2 = tester.test(List(("key1", ("a", 10.0))))
+    assert(ans2.sorted == List(("key1", 3.0), ("key1", 10.0)))
+    assert(tester.peekListState[Double]("topK", "key1") == List(10.0, 3.0))
+  }
+
test("TwsTester should allow direct access to ListState") { + val tester = new TwsTester(new TopKProcessor(2)) + tester.setListState("topK", "a", List(6.0, 5.0)) + tester.setListState("topK", "b", List(8.0, 7.0)) + tester.test(List(("a", ("", 10.0)))) + tester.test(List(("b", ("", 7.5)))) + tester.test(List(("c", ("", 1.0)))) + + assert(tester.peekListState[Double]("topK", "a") == List(10.0, 6.0)) + assert(tester.peekListState[Double]("topK", "b") == List(8.0, 7.5)) + assert(tester.peekListState[Double]("topK", "c") == List(1.0)) + assert(tester.peekListState[Double]("topK", "d") == List()) + } + + test("TwsTester should correctly test WordFrequencyProcessor") { + val input: List[(String, (String, String))] = List( + ("user1", ("", "hello")), + ("user1", ("", "world")), + ("user1", ("", "hello")), + ("user2", ("", "hello")), + ("user2", ("", "spark")), + ("user1", ("", "world")) + ) + val tester = new TwsTester(new WordFrequencyProcessor()) + val ans1 = tester.test(input) + + assert( + ans1.sorted == List( + ("user1", "hello", 1L), + ("user1", "hello", 2L), + ("user1", "world", 1L), + ("user1", "world", 2L), + ("user2", "hello", 1L), + ("user2", "spark", 1L) + ).sorted + ) + + // Check state using peekMapState + assert( + tester.peekMapState[String, Long]("frequencies", "user1") == Map("hello" -> 2L, "world" -> 2L) + ) + assert( + tester.peekMapState[String, Long]("frequencies", "user2") == Map("hello" -> 1L, "spark" -> 1L) + ) + assert(tester.peekMapState[String, Long]("frequencies", "user3") == Map()) + assert(tester.peekMapState[String, Long]("frequencies", "user3").isEmpty) + + // Process more data for user1 + val ans2 = tester.test(List(("user1", ("", "hello")), ("user1", ("", "test")))) + assert(ans2.sorted == List(("user1", "hello", 3L), ("user1", "test", 1L)).sorted) + assert( + tester.peekMapState[String, Long]("frequencies", "user1") == Map( + "hello" -> 3L, + "world" -> 2L, + "test" -> 1L + ) + ) + } + + test("TwsTester should allow direct access to MapState") { + val tester = new TwsTester(new WordFrequencyProcessor()) + + // Set initial state directly + tester.setMapState("frequencies", "user1", Map("hello" -> 5L, "world" -> 3L)) + tester.setMapState("frequencies", "user2", Map("spark" -> 10L)) + + // Process new words + tester.test(List(("user1", ("", "hello")))) + tester.test(List(("user1", ("", "goodbye")))) + tester.test(List(("user2", ("", "spark")))) + tester.test(List(("user3", ("", "new")))) + + // Verify updated state + assert( + tester.peekMapState[String, Long]("frequencies", "user1") == Map( + "hello" -> 6L, + "world" -> 3L, + "goodbye" -> 1L + ) + ) + assert(tester.peekMapState[String, Long]("frequencies", "user2") == Map("spark" -> 11L)) + assert(tester.peekMapState[String, Long]("frequencies", "user3") == Map("new" -> 1L)) + assert(tester.peekMapState[String, Long]("frequencies", "user4") == Map()) + } + + test("TwsTester can be used to test step function") { + val processor = new RunningCountProcessor[String]() + val tester = new TwsTester(processor) + + // Example of helper function using TwsTester to inspect how processing a single row changes + // state. 
+    def testStepFunction(key: String, inputRow: String, stateIn: Long): Long = {
+      tester.setValueState[Long]("count", key, stateIn)
+      tester.test(List((key, inputRow)))
+      tester.peekValueState("count", key).get
+    }
+
+    assert(testStepFunction("key1", "a", 10L) == 11L)
+  }
+
+  test("TwsTester should call handleInitialState") {
+    val processor = new RunningCountProcessor[String]()
+    val tester = new TwsTester(processor, initialState = List(("a", 10L), ("b", 20L)))
+    assert(tester.peekValueState[Long]("count", "a").get == 10L)
+    assert(tester.peekValueState[Long]("count", "b").get == 20L)
+
+    val ans = tester.test(List(("a", "a"), ("c", "c")))
+    assert(ans == List(("a", 11L), ("c", 1L)))
+  }
+
+  test("TwsTester should test RunningCountProcessor row-by-row") {
+    val tester = new TwsTester(new RunningCountProcessor[String]())
+
+    // Example of a helper function that tests how transformWithState processes rows one-by-one,
+    // which can be used to simulate real-time mode.
+    def testRowByRow(input: List[(String, String)]): List[(String, Long)] = {
+      input.flatMap(row => tester.test(List(row)))
+    }
+
+    val input: List[(String, String)] = List(
+      ("key1", "a"),
+      ("key2", "b"),
+      ("key1", "c"),
+      ("key2", "b"),
+      ("key1", "c"),
+      ("key1", "c"),
+      ("key3", "q")
+    )
+    val ans: List[(String, Long)] = testRowByRow(input)
+    assert(
+      ans == List(
+        ("key1", 1L),
+        ("key2", 1L),
+        ("key1", 2L),
+        ("key2", 2L),
+        ("key1", 3L),
+        ("key1", 4L),
+        ("key3", 1L)
+      )
+    )
+  }
+
+  test("TwsTester should exercise all state methods") {
+    val tester = new TwsTester(new AllMethodsTestProcessor())
+    val results = tester.test(
+      List(
+        ("k", "value-exists"), // false
+        ("k", "value-set"), // set to 42
+        ("k", "value-exists"), // true
+        ("k", "value-clear"), // clear
+        ("k", "value-exists"), // false again
+        ("k", "list-exists"), // false
+        ("k", "list-append"), // append a, b
+        ("k", "list-exists"), // true
+        ("k", "list-append-array"), // append c, d
+        ("k", "list-get"), // a,b,c,d
+        ("k", "map-exists"), // false
+        ("k", "map-add"), // add x=1, y=2, z=3
+        ("k", "map-exists"), // true
+        ("k", "map-keys"), // x,y,z
+        ("k", "map-values"), // 1,2,3
+        ("k", "map-iterator"), // x=1,y=2,z=3
+        ("k", "map-remove"), // remove y
+        ("k", "map-keys"), // x,z
+        ("k", "map-clear"), // clear map
+        ("k", "map-exists") // false
+      )
+    )
+
+    assert(
+      results == List(
+        ("k", "value-exists:false"),
+        ("k", "value-set:done"),
+        ("k", "value-exists:true"),
+        ("k", "value-clear:done"),
+        ("k", "value-exists:false"),
+        ("k", "list-exists:false"),
+        ("k", "list-append:done"),
+        ("k", "list-exists:true"),
+        ("k", "list-append-array:done"),
+        ("k", "list-get:a,b,c,d"),
+        ("k", "map-exists:false"),
+        ("k", "map-add:done"),
+        ("k", "map-exists:true"),
+        ("k", "map-keys:x,y,z"),
+        ("k", "map-values:1,2,3"),
+        ("k", "map-iterator:x=1,y=2,z=3"),
+        ("k", "map-remove:done"),
+        ("k", "map-keys:x,z"),
+        ("k", "map-clear:done"),
+        ("k", "map-exists:false")
+      )
+    )
+  }
+}
+
+/**
+ * Integration test suite that compares TwsTester results with real streaming execution.
+ * Thread auditing is disabled because this suite runs actual streaming queries with RocksDB
+ * and shuffle operations, which spawn daemon threads (e.g., Netty boss/worker threads,
+ * file client threads, ForkJoinPool workers, and cleaner threads) that shut down
+ * asynchronously after SparkContext.stop().
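+ *
+ * The comparison pattern: each batch is first run through TwsTester to compute the expected
+ * output, then the same batches are replayed through a real transformWithState query via
+ * testStream, asserting that every microbatch produces the same rows.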
+ */
+class TwsTesterFuzzTestSuite extends StreamTest {
+  import testImplicits._
+
+  // Disable thread auditing for this suite since it runs integration tests with
+  // real streaming queries that create asynchronously-stopped threads.
+  override protected val enableAutoThreadAudit = false
+
+  /**
+   * Asserts that {@code tester} is equivalent to the streaming query transforming
+   * {@code inputStream} into {@code result}, when both are fed the data in {@code batches}.
+   */
+  private def checkTwsTesterEndToEnd[
+      K: org.apache.spark.sql.Encoder,
+      I: org.apache.spark.sql.Encoder,
+      O: org.apache.spark.sql.Encoder](
+      tester: TwsTester[K, I, O],
+      batches: List[List[(K, I)]],
+      inputStream: MemoryStream[(K, I)],
+      result: Dataset[O]): Unit = {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5"
+    ) {
+      val expectedResults: List[List[O]] = batches.map(batch => tester.test(batch))
+      assert(batches.size == expectedResults.size)
+
+      val actions: Seq[StreamAction] = (batches zip expectedResults).flatMap {
+        case (batch, expected) =>
+          Seq(
+            AddData(inputStream, batch: _*),
+            CheckNewAnswer(expected.head, expected.tail: _*)
+          )
+      } :+ StopStream
+      testStream(result, OutputMode.Append())(actions: _*)
+    }
+  }
+
+  /**
+   * Asserts that {@code tester} processes the given {@code batches} in the same way as a Spark
+   * streaming query with {@code transformWithState} would.
+   */
+  private def checkTwsTester[
+      K: org.apache.spark.sql.Encoder,
+      I: org.apache.spark.sql.Encoder,
+      O: org.apache.spark.sql.Encoder](
+      processor: StatefulProcessor[K, I, O],
+      batches: List[List[(K, I)]]): Unit = {
+    implicit val tupleEncoder = org.apache.spark.sql.Encoders.tuple(
+      implicitly[org.apache.spark.sql.Encoder[K]],
+      implicitly[org.apache.spark.sql.Encoder[I]]
+    )
+    val inputStream = MemoryStream[(K, I)]
+    val result = inputStream
+      .toDS()
+      .groupByKey(_._1)
+      .mapValues(_._2)
+      .transformWithState(processor, TimeMode.None(), OutputMode.Append())
+    checkTwsTesterEndToEnd(new TwsTester(processor), batches, inputStream, result)
+  }
+
+  private def split[T](xs: List[T], numParts: Int): List[List[T]] = {
+    require(numParts > 0 && xs.size % numParts == 0)
+    val partSize = xs.size / numParts
+    (0 until numParts).map { i =>
+      xs.slice(i * partSize, (i + 1) * partSize)
+    }.toList
+  }
+
+  test("fuzz test with RunningCountProcessor") {
+    val random = new scala.util.Random(0)
+    val input = List.fill(1000) {
+      (s"key${random.nextInt(10)}", random.alphanumeric.take(5).mkString)
+    }
+    val processor = new RunningCountProcessor[String]()
+    checkTwsTester(processor, split(input, 2))
+  }
+
+  test("fuzz test with TopKProcessor") {
+    val random = new scala.util.Random(0)
+    val input = List.fill(1000) {
+      (
+        s"key${random.nextInt(10)}",
+        (random.alphanumeric.take(5).mkString, random.nextDouble() * 100)
+      )
+    }
+    val processor = new TopKProcessor(5)
+    checkTwsTester(processor, split(input, 2))
+  }
+
+  test("fuzz test with WordFrequencyProcessor") {
+    val random = new scala.util.Random(0)
+    val words = Array("spark", "scala", "flink", "kafka", "hadoop", "hive", "presto", "trino")
+    val input = List.fill(1000) {
+      (s"key${random.nextInt(10)}", ("", words(random.nextInt(words.length))))
+    }
+    val processor = new WordFrequencyProcessor()
+    checkTwsTester(processor, split(input, 2))
+  }
+
+  test("fuzz test for AllMethodsTestProcessor") {
+    val random = new scala.util.Random(0)
+    val commands = Array(
+      "value-exists",
+      "value-set",
+      "value-clear",
"list-exists", + "list-append", + "list-append-array", + "list-get", + "map-exists", + "map-add", + "map-keys", + "map-values", + "map-iterator", + "map-remove", + "map-clear" + ) + val input = List.fill(500) { + (s"key${random.nextInt(5)}", commands(random.nextInt(commands.length))) + } + val processor = new AllMethodsTestProcessor() + checkTwsTester(processor, split(input, 2)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala new file mode 100644 index 000000000000..ca3c77aa8ded --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.streaming.processors + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.streaming.{ListState, MapState, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig, ValueState} + +/** Test processor that exercises all state methods for coverage testing. 
+ */
+class AllMethodsTestProcessor extends StatefulProcessor[String, String, (String, String)] {
+
+  @transient private var valueState: ValueState[Int] = _
+  @transient private var listState: ListState[String] = _
+  @transient private var mapState: MapState[String, Int] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    valueState = getHandle.getValueState[Int]("value", Encoders.scalaInt, TTLConfig.NONE)
+    listState = getHandle.getListState[String]("list", Encoders.STRING, TTLConfig.NONE)
+    mapState =
+      getHandle.getMapState[String, Int]("map", Encoders.STRING, Encoders.scalaInt, TTLConfig.NONE)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[String],
+      timerValues: TimerValues
+  ): Iterator[(String, String)] = {
+    val results = scala.collection.mutable.ArrayBuffer[(String, String)]()
+
+    inputRows.foreach { cmd =>
+      cmd match {
+        case "value-exists" =>
+          results += ((key, s"value-exists:${valueState.exists()}"))
+        case "value-set" =>
+          valueState.update(42)
+          results += ((key, "value-set:done"))
+        case "value-clear" =>
+          valueState.clear()
+          results += ((key, "value-clear:done"))
+        case "list-exists" =>
+          results += ((key, s"list-exists:${listState.exists()}"))
+        case "list-append" =>
+          listState.appendValue("a")
+          listState.appendValue("b")
+          results += ((key, "list-append:done"))
+        case "list-append-array" =>
+          listState.appendList(Array("c", "d"))
+          results += ((key, "list-append-array:done"))
+        case "list-get" =>
+          val items = listState.get().toList.mkString(",")
+          results += ((key, s"list-get:$items"))
+        case "map-exists" =>
+          results += ((key, s"map-exists:${mapState.exists()}"))
+        case "map-add" =>
+          mapState.updateValue("x", 1)
+          mapState.updateValue("y", 2)
+          mapState.updateValue("z", 3)
+          results += ((key, "map-add:done"))
+        case "map-keys" =>
+          val keys = mapState.keys().toList.sorted.mkString(",")
+          results += ((key, s"map-keys:$keys"))
+        case "map-values" =>
+          val values = mapState.values().toList.sorted.mkString(",")
+          results += ((key, s"map-values:$values"))
+        case "map-iterator" =>
+          val pairs =
+            mapState.iterator().toList.sortBy(_._1).map(p => s"${p._1}=${p._2}").mkString(",")
+          results += ((key, s"map-iterator:$pairs"))
+        case "map-remove" =>
+          mapState.removeKey("y")
+          results += ((key, "map-remove:done"))
+        case "map-clear" =>
+          mapState.clear()
+          results += ((key, "map-clear:done"))
+      }
+    }
+
+    results.iterator
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/RunningCountProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/RunningCountProcessor.scala
new file mode 100644
index 000000000000..c0c2ddfa5679
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/RunningCountProcessor.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.streaming.processors
+
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessorWithInitialState, TimeMode, TimerValues, TTLConfig, ValueState}
+
+/** Test StatefulProcessor implementation that maintains a running count. */
+class RunningCountProcessor[T](ttl: TTLConfig = TTLConfig.NONE)
+  extends StatefulProcessorWithInitialState[String, T, (String, Long), Long] {
+
+  @transient private var countState: ValueState[Long] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    countState = getHandle.getValueState[Long]("count", Encoders.scalaLong, ttl)
+  }
+
+  override def handleInitialState(
+      key: String,
+      initialState: Long,
+      timerValues: TimerValues
+  ): Unit = {
+    countState.update(initialState)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[T],
+      timerValues: TimerValues
+  ): Iterator[(String, Long)] = {
+    val incoming = inputRows.size
+    val current = countState.get()
+    val updated = current + incoming
+    countState.update(updated)
+    Iterator.single((key, updated))
+  }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala
new file mode 100644
index 000000000000..0970b5580db7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.streaming.processors
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.streaming.{ListState, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig}
+
+// Input: (key, score) as (String, Double)
+// Output: (key, score) as (String, Double) for the top K snapshot each batch
+class TopKProcessor(k: Int, ttl: TTLConfig = TTLConfig.NONE)
+  extends StatefulProcessor[String, (String, Double), (String, Double)] {
+
+  @transient private var topKState: ListState[Double] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    topKState = getHandle.getListState[Double]("topK", Encoders.scalaDouble, ttl)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, Double)],
+      timerValues: TimerValues
+  ): Iterator[(String, Double)] = {
+    // Load existing list into a buffer
+    val current = ArrayBuffer[Double]()
+    topKState.get().foreach(current += _)
+
+    // Add new values and recompute top-K
+    inputRows.foreach {
+      case (_, score) =>
+        current += score
+    }
+    val updatedTopK = current.sorted(Ordering[Double].reverse).take(k)
+
+    // Persist back
+    topKState.clear()
+    topKState.put(updatedTopK.toArray)
+
+    // Emit snapshot of top-K for this key
+    updatedTopK.iterator.map(v => (key, v))
+  }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/WordFrequencyProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/WordFrequencyProcessor.scala
new file mode 100644
index 000000000000..02b11b7715e3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/WordFrequencyProcessor.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.streaming.processors
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.streaming.{MapState, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig}
+
+// Input: (key, word) as (String, String)
+// Output: (key, word, count) as (String, String, Long) for each word in the batch
+class WordFrequencyProcessor(ttl: TTLConfig = TTLConfig.NONE)
+  extends StatefulProcessor[String, (String, String), (String, String, Long)] {
+
+  @transient private var freqState: MapState[String, Long] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    freqState = getHandle
+      .getMapState[String, Long]("frequencies", Encoders.STRING, Encoders.scalaLong, ttl)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, String)],
+      timerValues: TimerValues
+  ): Iterator[(String, String, Long)] = {
+    val results = ArrayBuffer[(String, String, Long)]()
+
+    inputRows.foreach {
+      case (_, word) =>
+        val currentCount = if (freqState.containsKey(word)) {
+          freqState.getValue(word)
+        } else {
+          0L
+        }
+        val updatedCount = currentCount + 1
+        freqState.updateValue(word, updatedCount)
+        results += ((key, word, updatedCount))
+    }
+
+    results.iterator
+  }
+}