From dedfcfa78467284959dfe1ac0d67cd34785e685b Mon Sep 17 00:00:00 2001 From: Dmytro Fedoriaka Date: Fri, 21 Nov 2025 20:44:04 +0000 Subject: [PATCH 1/7] Add TwsTester framework and processors - Add TwsTester.scala: Main testing framework for TransformWithState - Add InMemoryStatefulProcessorHandleImpl.scala: In-memory implementation for testing - Add TwsTesterSuite.scala: Test suite for TwsTester framework - Add processors: EventTimeWindow, MultiTimer, RunningCount, SessionTimeout, TopK, WordFrequency --- .../InMemoryStatefulProcessorHandleImpl.scala | 366 +++++++++ .../spark/sql/streaming/TwsTester.scala | 324 ++++++++ .../spark/sql/streaming/TwsTesterSuite.scala | 702 ++++++++++++++++++ .../processors/EventTimeWindowProcessor.scala | 85 +++ .../processors/MultiTimerProcessor.scala | 82 ++ .../processors/RunningCountProcessor.scala | 52 ++ .../processors/SessionTimeoutProcessor.scala | 82 ++ .../streaming/processors/TopKProcessor.scala | 61 ++ .../processors/WordFrequencyProcessor.scala | 58 ++ 9 files changed, 1812 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/EventTimeWindowProcessor.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/MultiTimerProcessor.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/RunningCountProcessor.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/SessionTimeoutProcessor.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/WordFrequencyProcessor.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala new file mode 100644 index 000000000000..2a76cbb470c7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.testing + +import java.time.Clock +import java.time.Instant +import java.util.UUID + +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.ImplicitGroupingKeyTracker +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.QueryInfoImpl +import org.apache.spark.sql.streaming.ListState +import org.apache.spark.sql.streaming.MapState +import org.apache.spark.sql.streaming.QueryInfo +import org.apache.spark.sql.streaming.StatefulProcessorHandle +import org.apache.spark.sql.streaming.TimeMode +import org.apache.spark.sql.streaming.TTLConfig +import org.apache.spark.sql.streaming.ValueState + +/** Helper to track expired keys. */ +class TtlTracker(val clock: Clock, ttl: TTLConfig) { + require(!ttl.ttlDuration.isNegative()) + private val keyToLastUpdatedTime = mutable.Map[Any, Instant]() + + def isKeyExpired(): Boolean = { + if (ttl.ttlDuration.isZero()) { + return false + } + val key = ImplicitGroupingKeyTracker.getImplicitKeyOption.get + if (!keyToLastUpdatedTime.contains(key)) { + return false + } + val expiration: Instant = keyToLastUpdatedTime.get(key).get.plus(ttl.ttlDuration) + return expiration.isBefore(clock.instant()) + } + + def onKeyUpdated(): Unit = { + val key = ImplicitGroupingKeyTracker.getImplicitKeyOption.get + keyToLastUpdatedTime.put(key, clock.instant()) + } +} + +class InMemoryValueState[T](clock: Clock, ttl: TTLConfig) extends ValueState[T] { + private val keyToStateValue = mutable.Map[Any, T]() + private val ttlTracker = new TtlTracker(clock, ttl) + + private def getValue: Option[T] = { + if (ttlTracker.isKeyExpired()) { + return None + } + keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) + } + + override def exists(): Boolean = { + getValue.isDefined + } + + override def get(): T = getValue.getOrElse(null.asInstanceOf[T]) + + override def update(newState: T): Unit = { + ttlTracker.onKeyUpdated() + keyToStateValue.put(ImplicitGroupingKeyTracker.getImplicitKeyOption.get, newState) + } + + override def clear(): Unit = { + keyToStateValue.remove(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) + } +} + +class InMemoryListState[T](clock: Clock, ttl: TTLConfig) extends ListState[T] { + private val keyToStateValue = mutable.Map[Any, mutable.ArrayBuffer[T]]() + private val ttlTracker = new TtlTracker(clock, ttl) + + private def getList: Option[mutable.ArrayBuffer[T]] = { + if (ttlTracker.isKeyExpired()) { + return None + } + keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) + } + + override def exists(): Boolean = getList.isDefined + + override def get(): Iterator[T] = { + getList.orElse(Some(mutable.ArrayBuffer.empty[T])).get.iterator + } + + override def put(newState: Array[T]): Unit = { + ttlTracker.onKeyUpdated() + keyToStateValue.put( + ImplicitGroupingKeyTracker.getImplicitKeyOption.get, + mutable.ArrayBuffer.empty[T] ++ newState + ) + } + + override def appendValue(newState: T): Unit = { + ttlTracker.onKeyUpdated() + if (!exists()) { + keyToStateValue.put( + ImplicitGroupingKeyTracker.getImplicitKeyOption.get, + mutable.ArrayBuffer.empty[T] + ) + } + keyToStateValue(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) += newState + } + + override def 
appendList(newState: Array[T]): Unit = { + ttlTracker.onKeyUpdated() + if (!exists()) { + keyToStateValue.put( + ImplicitGroupingKeyTracker.getImplicitKeyOption.get, + mutable.ArrayBuffer.empty[T] + ) + } + keyToStateValue(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) ++= newState + } + + override def clear(): Unit = { + keyToStateValue.remove(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) + } +} + +class InMemoryMapState[K, V](clock: Clock, ttl: TTLConfig) extends MapState[K, V] { + private val keyToStateValue = mutable.Map[Any, mutable.HashMap[K, V]]() + private val ttlTracker = new TtlTracker(clock, ttl) + + private def getMap: Option[mutable.HashMap[K, V]] = { + if (ttlTracker.isKeyExpired()) { + return None + } + keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) + } + + override def exists(): Boolean = getMap.isDefined + + override def getValue(key: K): V = { + getMap + .orElse(Some(mutable.HashMap.empty[K, V])) + .get + .getOrElse(key, null.asInstanceOf[V]) + } + + override def containsKey(key: K): Boolean = { + getMap + .orElse(Some(mutable.HashMap.empty[K, V])) + .get + .contains(key) + } + + override def updateValue(key: K, value: V): Unit = { + ttlTracker.onKeyUpdated() + if (!exists()) { + keyToStateValue.put( + ImplicitGroupingKeyTracker.getImplicitKeyOption.get, + mutable.HashMap.empty[K, V] + ) + } + + keyToStateValue(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) += (key -> value) + } + + override def iterator(): Iterator[(K, V)] = { + getMap + .orElse(Some(mutable.HashMap.empty[K, V])) + .get + .iterator + } + + override def keys(): Iterator[K] = { + getMap + .orElse(Some(mutable.HashMap.empty[K, V])) + .get + .keys + .iterator + } + + override def values(): Iterator[V] = { + getMap + .orElse(Some(mutable.HashMap.empty[K, V])) + .get + .values + .iterator + } + + override def removeKey(key: K): Unit = { + getMap + .orElse(Some(mutable.HashMap.empty[K, V])) + .get + .remove(key) + } + + override def clear(): Unit = { + keyToStateValue.remove(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) + } +} + +class InMemoryTimers { + private val keyToTimers = mutable.Map[Any, mutable.TreeSet[Long]]() + + def registerTimer(expiryTimestampMs: Long): Unit = { + val groupingKey = ImplicitGroupingKeyTracker.getImplicitKeyOption.get + if (!keyToTimers.contains(groupingKey)) { + keyToTimers.put(groupingKey, mutable.TreeSet[Long]()) + } + keyToTimers(groupingKey).add(expiryTimestampMs) + } + + def deleteTimer(expiryTimestampMs: Long): Unit = { + val groupingKey = ImplicitGroupingKeyTracker.getImplicitKeyOption.get + if (keyToTimers.contains(groupingKey)) { + keyToTimers(groupingKey).remove(expiryTimestampMs) + if (keyToTimers(groupingKey).isEmpty) { + keyToTimers.remove(groupingKey) + } + } + } + + def listTimers(): Iterator[Long] = { + val groupingKey = ImplicitGroupingKeyTracker.getImplicitKeyOption.get + keyToTimers.get(groupingKey) match { + case Some(timers) => timers.iterator + case None => Iterator.empty + } + } + + def getAllKeysWithTimers(): Iterator[Any] = { + keyToTimers.keys.iterator + } +} + +class InMemoryStatefulProcessorHandleImpl( + timeMode: TimeMode, + keyExprEnc: ExpressionEncoder[Any], + clock: Clock = Clock.systemUTC()) + extends StatefulProcessorHandle { + + private val states = mutable.Map[String, Any]() + + override def getValueState[T]( + stateName: String, + valEncoder: Encoder[T], + ttlConfig: TTLConfig + ): ValueState[T] = { + require(!states.contains(stateName), s"State $stateName already defined.") + states + 
.getOrElseUpdate(stateName, new InMemoryValueState[T](clock, ttlConfig)) + .asInstanceOf[InMemoryValueState[T]] + } + + override def getValueState[T: Encoder](stateName: String, ttlConfig: TTLConfig): ValueState[T] = { + getValueState(stateName, implicitly[Encoder[T]], ttlConfig) + } + + override def getListState[T]( + stateName: String, + valEncoder: Encoder[T], + ttlConfig: TTLConfig + ): ListState[T] = { + require(!states.contains(stateName), s"State $stateName already defined.") + states + .getOrElseUpdate(stateName, new InMemoryListState[T](clock, ttlConfig)) + .asInstanceOf[InMemoryListState[T]] + } + + override def getListState[T: Encoder](stateName: String, ttlConfig: TTLConfig): ListState[T] = { + getListState(stateName, implicitly[Encoder[T]], ttlConfig) + } + + override def getMapState[K, V]( + stateName: String, + userKeyEnc: Encoder[K], + valEncoder: Encoder[V], + ttlConfig: TTLConfig + ): MapState[K, V] = { + require(!states.contains(stateName), s"State $stateName already defined.") + states + .getOrElseUpdate(stateName, new InMemoryMapState[K, V](clock, ttlConfig)) + .asInstanceOf[InMemoryMapState[K, V]] + } + + override def getMapState[K: Encoder, V: Encoder]( + stateName: String, + ttlConfig: TTLConfig + ): MapState[K, V] = { + getMapState(stateName, implicitly[Encoder[K]], implicitly[Encoder[V]], ttlConfig) + } + + override def getQueryInfo(): QueryInfo = { + new QueryInfoImpl(UUID.randomUUID(), UUID.randomUUID(), 0L) + } + + private val timers = new InMemoryTimers() + + override def registerTimer(expiryTimestampMs: Long): Unit = { + require(timeMode != TimeMode.None, "Timers are not supported with TimeMode.None.") + timers.registerTimer(expiryTimestampMs) + } + + override def deleteTimer(expiryTimestampMs: Long): Unit = { + require(timeMode != TimeMode.None, "Timers are not supported with TimeMode.None.") + timers.deleteTimer(expiryTimestampMs) + } + + override def listTimers(): Iterator[Long] = { + require(timeMode != TimeMode.None, "Timers are not supported with TimeMode.None.") + timers.listTimers() + } + + override def deleteIfExists(stateName: String): Unit = { + states.remove(stateName) + } + + def setValueState[T](stateName: String, value: T): Unit = { + require(states.contains(stateName), s"State $stateName has not been initialized.") + states(stateName).asInstanceOf[InMemoryValueState[T]].update(value) + } + + def peekValueState[T](stateName: String): Option[T] = { + require(states.contains(stateName), s"State $stateName has not been initialized.") + val value: T = states(stateName).asInstanceOf[InMemoryValueState[T]].get() + Option(value) + } + + def setListState[T](stateName: String, value: List[T])(implicit ct: ClassTag[T]): Unit = { + require(states.contains(stateName), s"State $stateName has not been initialized.") + states(stateName).asInstanceOf[InMemoryListState[T]].put(value.toArray) + } + + def peekListState[T](stateName: String): List[T] = { + require(states.contains(stateName), s"State $stateName has not been initialized.") + states(stateName).asInstanceOf[InMemoryListState[T]].get().toList + } + + def setMapState[MK, MV](stateName: String, value: Map[MK, MV]): Unit = { + require(states.contains(stateName), s"State $stateName has not been initialized.") + val mapState = states(stateName).asInstanceOf[InMemoryMapState[MK, MV]] + mapState.clear() + value.foreach { case (k, v) => mapState.updateValue(k, v) } + } + + def peekMapState[MK, MV](stateName: String): Map[MK, MV] = { + require(states.contains(stateName), s"State $stateName has not been 
initialized.") + states(stateName).asInstanceOf[InMemoryMapState[MK, MV]].iterator().toMap + } + + def getAllKeysWithTimers[K](): Iterator[K] = { + timers.getAllKeysWithTimers().map(_.asInstanceOf[K]) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala new file mode 100644 index 000000000000..496c7409eef9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.streaming + +import java.sql.Timestamp +import java.time.{Clock, Duration, Instant, ZoneId} + +import scala.reflect.ClassTag + +import org.apache.spark.sql.catalyst.util.IntervalUtils +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.ImplicitGroupingKeyTracker +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.testing.InMemoryStatefulProcessorHandleImpl +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.{ + ExpiredTimerInfoImpl, + TimerValuesImpl +} + +/** + * Testing utility for transformWithState stateful processors. Provides in-memory state management + * and simplified input processing for unit testing StatefulProcessor implementations. 
+ * + * @param processor the StatefulProcessor to test + * @param clock the clock to use for time-based operations, defaults to system UTC + * @param timeMode time mode that will be passed to transformWithState (defaults to TimeMode.None) + * @param outputMode output mode that will be passed to transformWithState (defaults to + * OutputMode.Append) + * @param initialState initial state for each key + * @tparam K the type of grouping key + * @tparam I the type of input rows + * @tparam O the type of output rows + */ +class TwsTester[K, I, O]( + val processor: StatefulProcessor[K, I, O], + val clock: Clock = Clock.systemUTC(), + val timeMode: TimeMode = TimeMode.None, + val outputMode: OutputMode = OutputMode.Append, + val initialState: List[(K, Any)] = List()) { + private val handle = new InMemoryStatefulProcessorHandleImpl(timeMode, null, clock) + + private var eventTimeFunc: (I => Timestamp) = null + private var delayThresholdMs: Long = 0 + private var currentWatermarkMs: Option[Long] = None + + processor.setHandle(handle) + processor.init(outputMode, timeMode) + + processor match { + case p: StatefulProcessorWithInitialState[K @unchecked, I @unchecked, O @unchecked, s] => + handleInitialState[s]() + case _ => + } + + private def handleInitialState[S](): Unit = { + val timerValues = new TimerValuesImpl(Some(clock.instant().toEpochMilli()), None) + val p = processor.asInstanceOf[StatefulProcessorWithInitialState[K, I, O, S]] + initialState.foreach { + case (key, state) => + ImplicitGroupingKeyTracker.setImplicitKey(key) + p.handleInitialState(key, state.asInstanceOf[S], timerValues) + ImplicitGroupingKeyTracker.removeImplicitKey() + } + } + + /** + * Processes input rows through the stateful processor, grouped by key. + * + * This corresponds to processing one microbatch. {@code handleInputRows} will be called once for + * each key that appears in {@code input}. + * + * @param input list of (key, input row) tuples to process + * @return all output rows produced by the processor + */ + def test(input: List[(K, I)]): List[O] = { + val currentTimeMs: Long = clock.instant().toEpochMilli() + var timerValues = new TimerValuesImpl(Some(currentTimeMs), currentWatermarkMs) + var ans: List[O] = List() + val filteredInput = filterLateEvents(input) + + for ((key, v) <- filteredInput.groupBy(_._1)) { + ImplicitGroupingKeyTracker.setImplicitKey(key) + ans = ans ++ processor.handleInputRows(key, v.map(_._2).iterator, timerValues).toList + ImplicitGroupingKeyTracker.removeImplicitKey() + } + + updateWatermark(input) + timerValues = new TimerValuesImpl(Some(currentTimeMs), currentWatermarkMs) + ans ++ handleExpiredTimers(timerValues) + } + + // Filters late events in EventTime mode. 
+  private def filterLateEvents(input: List[(K, I)]): List[(K, I)] = {
+    if (timeMode != TimeMode.EventTime || !currentWatermarkMs.isDefined) {
+      return input
+    }
+    require(eventTimeFunc != null, "call withWatermark if timeMode is EventTime")
+    input.filter { case (_, row) => eventTimeFunc(row).getTime() >= currentWatermarkMs.get }
+  }
+
+  private def handleExpiredTimers(timerValues: TimerValues): List[O] = {
+    if (timeMode == TimeMode.None) {
+      return List()
+    }
+    val currentTimeMs: Long =
+      if (timeMode == TimeMode.EventTime) timerValues.getCurrentWatermarkInMs()
+      else timerValues.getCurrentProcessingTimeInMs()
+
+    var ans: List[O] = List()
+    for (key <- handle.getAllKeysWithTimers[K]()) {
+      ImplicitGroupingKeyTracker.setImplicitKey(key)
+      val expiredTimers: List[Long] = handle.listTimers().filter(_ <= currentTimeMs).toList
+      for (timerExpiryTimeMs <- expiredTimers) {
+        val expiredTimerInfo = new ExpiredTimerInfoImpl(Some(timerExpiryTimeMs))
+        ans = ans ++ processor.handleExpiredTimer(key, timerValues, expiredTimerInfo).toList
+        handle.deleteTimer(timerExpiryTimeMs)
+      }
+      ImplicitGroupingKeyTracker.removeImplicitKey()
+    }
+    ans
+  }
+
+  private def updateWatermark(input: List[(K, I)]): Unit = {
+    if (timeMode != TimeMode.EventTime || input.isEmpty) {
+      return
+    }
+    require(eventTimeFunc != null, "call withWatermark if timeMode is EventTime")
+    currentWatermarkMs = Some(
+      math.max(
+        currentWatermarkMs.getOrElse(0L),
+        input.map(v => eventTimeFunc(v._2).getTime()).max - delayThresholdMs
+      )
+    )
+  }
+
+  /**
+   * Convenience method to process a single input row for a given key.
+   *
+   * @param key the grouping key
+   * @param inputRow the input row to process
+   * @return all output rows produced by the processor
+   */
+  def testOneRow(key: K, inputRow: I): List[O] = test(List((key, inputRow)))
+
+  /**
+   * Processes input rows through the stateful processor, one by one.
+   *
+   * This corresponds to running a streaming query in real-time mode. {@code handleInputRows} will
+   * be called once for each row in {@code input}.
+   *
+   * @param input list of (key, input row) tuples to process
+   * @return all output rows produced by the processor
+   */
+  def testRowByRow(input: List[(K, I)]): List[O] = {
+    var ans: List[O] = List()
+    for (row <- input) {
+      ans ++= test(List(row))
+    }
+    ans
+  }
+
+  /**
+   * Tests how the value state changes after processing one row.
+   *
+   * @param key the grouping key
+   * @param inputRow the input row to process
+   * @param stateName the name of the value state
+   * @param stateIn the old value of the value state
+   * @tparam S the type of value state
+   * @return output rows produced by the processor and the new value of the value state
+   */
+  def testOneRowWithValueState[S](
+      key: K,
+      inputRow: I,
+      stateName: String,
+      stateIn: S): (List[O], S) = {
+    setValueState[S](stateName, key, stateIn)
+    val outputRows = testOneRow(key, inputRow)
+    (outputRows, peekValueState[S](stateName, key).get)
+  }
+
+  /**
+   * Sets the value state for a given key.
+   *
+   * @param stateName the name of the value state variable
+   * @param key the grouping key
+   * @param value the value to set
+   * @tparam T the type of the state value
+   */
+  def setValueState[T](stateName: String, key: K, value: T): Unit = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    handle.setValueState[T](stateName, value)
+    ImplicitGroupingKeyTracker.removeImplicitKey()
+  }
+
+  /**
+   * Retrieves the value state for a given key without modifying it.
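+   *
+   * A minimal sketch (assuming a Long-valued state named "count", as registered by
+   * RunningCountProcessor in the tests):
+   * {{{
+   *   val count: Option[Long] = tester.peekValueState[Long]("count", "key1")
+   * }}}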
+   *
+   * @param stateName the name of the value state variable
+   * @param key the grouping key
+   * @tparam T the type of the state value
+   * @return Some(value) if state exists for the key, None otherwise
+   */
+  def peekValueState[T](stateName: String, key: K): Option[T] = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    val result: Option[T] = handle.peekValueState[T](stateName)
+    ImplicitGroupingKeyTracker.removeImplicitKey()
+    result
+  }
+
+  /**
+   * Sets the list state for a given key.
+   *
+   * @param stateName the name of the list state variable
+   * @param key the grouping key
+   * @param value the list of values to set
+   * @param ct implicit class tag for type T
+   * @tparam T the type of elements in the list state
+   */
+  def setListState[T](stateName: String, key: K, value: List[T])(implicit ct: ClassTag[T]): Unit = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    handle.setListState[T](stateName, value)
+    ImplicitGroupingKeyTracker.removeImplicitKey()
+  }
+
+  /**
+   * Retrieves the list state for a given key without modifying it.
+   *
+   * @param stateName the name of the list state variable
+   * @param key the grouping key
+   * @tparam T the type of elements in the list state
+   * @return the list of values, or an empty list if no state exists for the key
+   */
+  def peekListState[T](stateName: String, key: K): List[T] = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    val result: List[T] = handle.peekListState[T](stateName)
+    ImplicitGroupingKeyTracker.removeImplicitKey()
+    result
+  }
+
+  /**
+   * Sets the map state for a given key.
+   *
+   * @param stateName the name of the map state variable
+   * @param key the grouping key
+   * @param value the map of key-value pairs to set
+   * @tparam MK the type of keys in the map state
+   * @tparam MV the type of values in the map state
+   */
+  def setMapState[MK, MV](stateName: String, key: K, value: Map[MK, MV]): Unit = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    handle.setMapState[MK, MV](stateName, value)
+    ImplicitGroupingKeyTracker.removeImplicitKey()
+  }
+
+  /**
+   * Retrieves the map state for a given key without modifying it.
+   *
+   * @param stateName the name of the map state variable
+   * @param key the grouping key
+   * @tparam MK the type of keys in the map state
+   * @tparam MV the type of values in the map state
+   * @return the map of key-value pairs, or an empty map if no state exists for the key
+   */
+  def peekMapState[MK, MV](stateName: String, key: K): Map[MK, MV] = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    val result: Map[MK, MV] = handle.peekMapState[MK, MV](stateName)
+    ImplicitGroupingKeyTracker.removeImplicitKey()
+    result
+  }
+
+  /**
+   * Sets the watermark for EventTime mode.
+   *
+   * @param eventTime function used to extract the timestamp column from input rows.
+   * @param delayThreshold a string specifying the minimum delay to wait for data to arrive late,
+   *   relative to the latest record that has been processed, in the form of an interval (e.g.
+   *   "1 minute" or "5 hours")
+   */
+  def withWatermark(eventTime: (I => Timestamp), delayThreshold: String): Unit = {
+    require(timeMode == TimeMode.EventTime, "withWatermark is only usable with TimeMode.EventTime")
+    val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold)
+    require(
+      !IntervalUtils.isNegative(parsedDelay),
+      s"delay threshold ($delayThreshold) should not be negative."
+    )
+    eventTimeFunc = eventTime
+    delayThresholdMs =
+      IntervalUtils.getDuration(parsedDelay, java.util.concurrent.TimeUnit.MILLISECONDS, 30)
+  }
+}
+
+object TwsTester {
+
+  /** Fake implementation of {@code java.time.Clock}, used with TwsTester to simulate time. */
+  class TestClock(
+      var currentInstant: Instant = Instant.EPOCH,
+      zone: ZoneId = ZoneId.systemDefault())
+    extends Clock {
+    override def getZone: ZoneId = zone
+    override def withZone(zone: ZoneId): Clock = new TestClock(currentInstant, zone)
+    override def instant(): Instant = currentInstant
+
+    def setInstant(instant: Instant): Unit = {
+      currentInstant = instant
+    }
+
+    def advanceBy(duration: Duration): Unit = {
+      currentInstant = currentInstant.plus(duration)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala
new file mode 100644
index 000000000000..196b33934e3c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala
@@ -0,0 +1,702 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.streaming
+
+import java.sql.Timestamp
+import java.time.{Duration, Instant}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.processors._
+
+/** Test suite for TwsTester utility class.
*/ +class TwsTesterSuite extends SparkFunSuite { + + test("TwsTester should correctly test RunningCountProcessor") { + val input: List[(String, String)] = List( + ("key1", "a"), + ("key2", "b"), + ("key1", "c"), + ("key2", "b"), + ("key1", "c"), + ("key1", "c"), + ("key3", "q") + ) + val tester = new TwsTester(new RunningCountProcessor[String]()) + val ans1: List[(String, Long)] = tester.test(input) + assert(ans1.sorted == List(("key1", 4L), ("key2", 2L), ("key3", 1L)).sorted) + + assert(tester.peekValueState[Long]("count", "key1").get == 4L) + assert(tester.peekValueState[Long]("count", "key2").get == 2L) + assert(tester.peekValueState[Long]("count", "key3").get == 1L) + assert(tester.peekValueState[Long]("count", "key4").isEmpty) + + val ans2 = tester.testOneRow("key1", "q") + assert(ans2 == List(("key1", 5L))) + assert(tester.peekValueState[Long]("count", "key1").get == 5L) + assert(tester.peekValueState[Long]("count", "key2").get == 2L) + + val ans3 = tester.test(List(("key1", "a"), ("key2", "a"))) + assert(ans3.sorted == List(("key1", 6L), ("key2", 3L))) + } + + test("TwsTester should allow direct access to ValueState") { + val processor = new RunningCountProcessor[String]() + val tester = new TwsTester[String, String, (String, Long)](processor) + tester.setValueState[Long]("count", "foo", 5) + tester.test(List(("foo", "a"))) + assert(tester.peekValueState[Long]("count", "foo").get == 6L) + } + + test("TwsTester should correctly test TopKProcessor") { + val input: List[(String, (String, Double))] = List( + ("key2", ("c", 30.0)), + ("key2", ("d", 40.0)), + ("key1", ("b", 2.0)), + ("key1", ("c", 3.0)), + ("key2", ("a", 10.0)), + ("key2", ("b", 20.0)), + ("key3", ("a", 100.0)), + ("key1", ("a", 1.0)) + ) + val tester = new TwsTester(new TopKProcessor(2)) + val ans1 = tester.test(input) + assert( + ans1.sorted == List( + ("key1", 2.0), + ("key1", 3.0), + ("key2", 30.0), + ("key2", 40.0), + ("key3", 100.0) + ) + ) + assert(tester.peekListState[Double]("topK", "key1") == List(3.0, 2.0)) + assert(tester.peekListState[Double]("topK", "key2") == List(40.0, 30.0)) + assert(tester.peekListState[Double]("topK", "key3") == List(100.0)) + assert(tester.peekListState[Double]("topK", "key4").isEmpty) + + val ans2 = tester.test(List(("key1", ("a", 10.0)))) + assert(ans2.sorted == List(("key1", 3.0), ("key1", 10.0))) + assert(tester.peekListState[Double]("topK", "key1") == List(10.0, 3.0)) + } + + test("TwsTester should allow direct access to ListState") { + val tester = new TwsTester(new TopKProcessor(2)) + tester.setListState("topK", "a", List(6.0, 5.0)) + tester.setListState("topK", "b", List(8.0, 7.0)) + tester.testOneRow("a", ("", 10.0)) + tester.testOneRow("b", ("", 7.5)) + tester.testOneRow("c", ("", 1.0)) + + assert(tester.peekListState[Double]("topK", "a") == List(10.0, 6.0)) + assert(tester.peekListState[Double]("topK", "b") == List(8.0, 7.5)) + assert(tester.peekListState[Double]("topK", "c") == List(1.0)) + assert(tester.peekListState[Double]("topK", "d") == List()) + } + + test("TwsTester should correctly test WordFrequencyProcessor") { + val input: List[(String, (String, String))] = List( + ("user1", ("", "hello")), + ("user1", ("", "world")), + ("user1", ("", "hello")), + ("user2", ("", "hello")), + ("user2", ("", "spark")), + ("user1", ("", "world")) + ) + val tester = new TwsTester(new WordFrequencyProcessor()) + val ans1 = tester.test(input) + + assert( + ans1.sorted == List( + ("user1", "hello", 1L), + ("user1", "hello", 2L), + ("user1", "world", 1L), + ("user1", "world", 2L), + 
("user2", "hello", 1L), + ("user2", "spark", 1L) + ).sorted + ) + + // Check state using peekMapState + assert( + tester.peekMapState[String, Long]("frequencies", "user1") == Map("hello" -> 2L, "world" -> 2L) + ) + assert( + tester.peekMapState[String, Long]("frequencies", "user2") == Map("hello" -> 1L, "spark" -> 1L) + ) + assert(tester.peekMapState[String, Long]("frequencies", "user3") == Map()) + assert(tester.peekMapState[String, Long]("frequencies", "user3").isEmpty) + + // Process more data for user1 + val ans2 = tester.test(List(("user1", ("", "hello")), ("user1", ("", "test")))) + assert(ans2.sorted == List(("user1", "hello", 3L), ("user1", "test", 1L)).sorted) + assert( + tester.peekMapState[String, Long]("frequencies", "user1") == Map( + "hello" -> 3L, + "world" -> 2L, + "test" -> 1L + ) + ) + } + + test("TwsTester should allow direct access to MapState") { + val tester = new TwsTester(new WordFrequencyProcessor()) + + // Set initial state directly + tester.setMapState("frequencies", "user1", Map("hello" -> 5L, "world" -> 3L)) + tester.setMapState("frequencies", "user2", Map("spark" -> 10L)) + + // Process new words + tester.testOneRow("user1", ("", "hello")) + tester.testOneRow("user1", ("", "goodbye")) + tester.testOneRow("user2", ("", "spark")) + tester.testOneRow("user3", ("", "new")) + + // Verify updated state + assert( + tester.peekMapState[String, Long]("frequencies", "user1") == Map( + "hello" -> 6L, + "world" -> 3L, + "goodbye" -> 1L + ) + ) + assert(tester.peekMapState[String, Long]("frequencies", "user2") == Map("spark" -> 11L)) + assert(tester.peekMapState[String, Long]("frequencies", "user3") == Map("new" -> 1L)) + assert(tester.peekMapState[String, Long]("frequencies", "user4") == Map()) + } + + test("TwsTester should expire old value state according to TTL") { + val processor = new RunningCountProcessor[String](TTLConfig(Duration.ofSeconds(100))) + val testClock = new TwsTester.TestClock(Instant.EPOCH) + val tester = new TwsTester(processor, testClock) + + tester.testOneRow("key1", "b") + assert(tester.peekValueState[Long]("count", "key1").get == 1L) + testClock.advanceBy(Duration.ofSeconds(101)) + assert(tester.peekValueState[Long]("count", "key1").isEmpty) + } + + test("TwsTester should expire old list state according to TTL") { + val testClock = new TwsTester.TestClock(Instant.EPOCH) + val ttlConfig = TTLConfig(Duration.ofSeconds(100)) + val processor = new TopKProcessor(2, ttlConfig) + val tester = new TwsTester(processor, testClock) + + tester.testOneRow("key1", ("a", 1.0)) + tester.testOneRow("key2", ("a", 1.5)) + assert(tester.peekListState[Double]("topK", "key1") == List(1.0)) + assert(tester.peekListState[Double]("topK", "key2") == List(1.5)) + + testClock.advanceBy(Duration.ofSeconds(50)) + tester.testOneRow("key2", ("a", 2.0)) + assert(tester.peekListState[Double]("topK", "key1") == List(1.0)) + assert(tester.peekListState[Double]("topK", "key2") == List(2.0, 1.5)) + + testClock.advanceBy(Duration.ofSeconds(51)) + assert(tester.peekListState[Double]("topK", "key1") == List()) + assert(tester.peekListState[Double]("topK", "key2") == List(2.0, 1.5)) + + testClock.advanceBy(Duration.ofSeconds(50)) + assert(tester.peekListState[Double]("topK", "key1") == List()) + assert(tester.peekListState[Double]("topK", "key2") == List()) + } + + test("TwsTester should expire old map state according to TTL") { + val testClock = new TwsTester.TestClock(Instant.EPOCH) + val ttlConfig = TTLConfig(Duration.ofSeconds(100)) + val processor = new 
WordFrequencyProcessor(ttlConfig) + val tester = new TwsTester(processor, testClock) + + tester.testOneRow("key1", ("a", "spark")) + tester.testOneRow("key2", ("a", "beta")) + assert(tester.peekMapState[String, Long]("frequencies", "key1") == Map("spark" -> 1L)) + assert(tester.peekMapState[String, Long]("frequencies", "key2") == Map("beta" -> 1L)) + + testClock.advanceBy(Duration.ofSeconds(50)) + tester.testOneRow("key2", ("a", "spark")) + assert(tester.peekMapState[String, Long]("frequencies", "key1") == Map("spark" -> 1L)) + assert( + tester.peekMapState[String, Long]("frequencies", "key2") == Map("beta" -> 1L, "spark" -> 1L) + ) + + testClock.advanceBy(Duration.ofSeconds(51)) + assert(tester.peekMapState[String, Long]("frequencies", "key1") == Map()) + assert( + tester.peekMapState[String, Long]("frequencies", "key2") == Map("beta" -> 1L, "spark" -> 1L) + ) + + testClock.advanceBy(Duration.ofSeconds(50)) + assert(tester.peekMapState[String, Long]("frequencies", "key1") == Map()) + assert(tester.peekMapState[String, Long]("frequencies", "key2") == Map()) + } + + test("TwsTester should test one row with value state") { + val processor = new RunningCountProcessor[String]() + val tester = new TwsTester(processor) + + val (rows, newState) = tester.testOneRowWithValueState("key1", "a", "count", 10L) + assert(rows == List(("key1", 11L))) + assert(newState == 11L) + } + + test("TwsTester should handle session timeout with timer") { + val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) + val processor = new SessionTimeoutProcessor(timeoutDurationMs = 5000L) + val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) + + // First activity - should register timer at 15000ms + val result1 = tester.test(List(("user1", ("user1", "login")))) + assert(result1 == List(("user1", "ACTIVITY", 1L))) + assert(tester.peekValueState[Long]("activityCount", "user1").get == 1L) + assert(tester.peekValueState[Long]("timer", "user1").get == 15000L) + + // Second activity before timeout - should update timer to 20000ms + testClock.advanceBy(Duration.ofMillis(3000L)) // now at 13000ms + val result2 = tester.test(List(("user1", ("user1", "click")))) + assert(result2 == List(("user1", "ACTIVITY", 2L))) + assert(tester.peekValueState[Long]("activityCount", "user1").get == 2L) + assert(tester.peekValueState[Long]("timer", "user1").get == 18000L) + + // Advance time past timeout - timer should fire + testClock.advanceBy(Duration.ofMillis(6000L)) // now at 19000ms, timer at 18000ms should fire + val result3 = tester.test(List()) // empty input, but timer should fire + assert(result3 == List(("user1", "SESSION_TIMEOUT", 2L))) + assert(tester.peekValueState[Long]("activityCount", "user1").isEmpty) + assert(tester.peekValueState[Long]("timer", "user1").isEmpty) + } + + test("TwsTester should process input before timers") { + val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) + val processor = new SessionTimeoutProcessor(timeoutDurationMs = 5000L) + val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) + + // Register initial activity and timer + tester.test(List(("user1", ("user1", "start")))) + assert(tester.peekValueState[Long]("activityCount", "user1").get == 1L) + + // Advance past timer expiry + testClock.advanceBy(Duration.ofMillis(6000L)) // now at 16000ms, timer at 15000ms expired + + // Process new input - input should be processed BEFORE timer fires + val result = tester.test(List(("user1", ("user1", "new_activity")))) + 
+ // Input is processed first: increments count to 2, deletes old timer, registers new one + // The expired timer at 15000ms is deleted during input processing, so it never fires + assert(result.length == 1) + assert(result(0) == ("user1", "ACTIVITY", 2L)) // Input processed, count incremented from 1 to 2 + + // After processing, should have updated state with new timer + assert(tester.peekValueState[Long]("activityCount", "user1").get == 2L) + assert(tester.peekValueState[Long]("timer", "user1").get == 21000L) // 16000 + 5000 + } + + test("TwsTester should handle multiple timers in same batch") { + val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) + val processor = new MultiTimerProcessor() + val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) + + // Register all three timers (SHORT=11000ms, MEDIUM=13000ms, LONG=15000ms) + val result1 = tester.test(List(("user1", ("user1", "start")))) + assert(result1.isEmpty) + + // Advance to fire SHORT timer only + testClock.advanceBy(Duration.ofMillis(1500L)) // now at 11500ms + val result2 = tester.test(List()) + assert(result2.length == 1) + assert(result2(0) == ("user1", "SHORT", 11000L)) + + // Advance to fire MEDIUM and LONG timers together + testClock.advanceBy(Duration.ofMillis(4000L)) // now at 15500ms + val result3 = tester.test(List()) + + // Timers should fire in order of expiry time + assert(result3.length == 2) + assert(result3(0) == ("user1", "MEDIUM", 13000L)) + assert(result3(1) == ("user1", "LONG", 15000L)) + } + + test("TwsTester should not process timers twice") { + val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) + val processor = new SessionTimeoutProcessor(timeoutDurationMs = 5000L) + val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) + + // Register timer at 15000ms + tester.test(List(("user1", ("user1", "start")))) + assert(tester.peekValueState[Long]("activityCount", "user1").get == 1L) + + // Advance past timer and process - timer fires + testClock.advanceBy(Duration.ofMillis(6000L)) // now at 16000ms + val result1 = tester.test(List()) + assert(result1.length == 1) + assert(result1(0) == ("user1", "SESSION_TIMEOUT", 1L)) + assert(tester.peekValueState[Long]("activityCount", "user1").isEmpty) + + // Process again at same time - timer should NOT fire again + val result2 = tester.test(List()) + assert(result2.isEmpty) + + // Process again with even later time - timer should still not fire + testClock.advanceBy(Duration.ofMillis(10000L)) // now at 26000ms + val result3 = tester.test(List()) + assert(result3.isEmpty) + } + + test("TwsTester should handle timers for multiple keys independently") { + val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) + val processor = new SessionTimeoutProcessor(timeoutDurationMs = 5000L) + val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) + + // Register activities for two users at different times + tester.test(List(("user1", ("user1", "start")))) // timer at 15000ms + testClock.advanceBy(Duration.ofMillis(2000L)) // now at 12000ms + tester.test(List(("user2", ("user2", "start")))) // timer at 17000ms + + // Advance to fire only user1's timer + testClock.advanceBy(Duration.ofMillis(4000L)) // now at 16000ms + val result1 = tester.test(List()) + assert(result1.length == 1) + assert(result1(0) == ("user1", "SESSION_TIMEOUT", 1L)) + assert(tester.peekValueState[Long]("activityCount", "user1").isEmpty) + 
assert(tester.peekValueState[Long]("activityCount", "user2").get == 1L) // user2 still active + + // Advance to fire user2's timer + testClock.advanceBy(Duration.ofMillis(2000L)) // now at 18000ms + val result2 = tester.test(List()) + assert(result2.length == 1) + assert(result2(0) == ("user2", "SESSION_TIMEOUT", 1L)) + assert(tester.peekValueState[Long]("activityCount", "user2").isEmpty) + } + + test("TwsTester should handle timer deletion correctly") { + val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) + val processor = new SessionTimeoutProcessor(timeoutDurationMs = 5000L) + val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) + + // Register initial timer at 15000ms + tester.test(List(("user1", ("user1", "start")))) + assert(tester.peekValueState[Long]("timer", "user1").get == 15000L) + + // New activity deletes old timer and registers new one at 18000ms + testClock.advanceBy(Duration.ofMillis(3000L)) // now at 13000ms + tester.test(List(("user1", ("user1", "activity")))) + assert(tester.peekValueState[Long]("timer", "user1").get == 18000L) + + // Advance past original timer time (15000ms) but before new timer (18000ms) + testClock.advanceBy(Duration.ofMillis(3000L)) // now at 16000ms + val result = tester.test(List()) + // Old timer at 15000ms should NOT fire since it was deleted + assert(result.isEmpty) + assert(tester.peekValueState[Long]("activityCount", "user1").get == 2L) // still active + + // Advance to EXACTLY the new timer time (18000ms). + // This tests that timers fire when current time equals expiry time. + testClock.advanceBy(Duration.ofMillis(2000L)) + val result2 = tester.test(List()) + assert(result2.length == 1) + assert(result2(0) == ("user1", "SESSION_TIMEOUT", 2L)) + } + + test("TwsTester should handle EventTime mode with watermark") { + val processor = new EventTimeWindowProcessor(windowDurationMs = 10000L) + val tester = new TwsTester(processor, timeMode = TimeMode.EventTime) + + // Configure watermark: extract timestamp from input, 2 second delay + // In real Spark, the watermark column must be of Timestamp type + tester.withWatermark( + (input: (String, Timestamp)) => input._2, + "2 seconds" + ) + + // Batch 1: Process events with timestamps 10000, 12000, 15000 + // Max event time = 15000, watermark after batch = 15000 - 2000 = 13000 + val result1 = tester.test( + List( + ("user1", ("event1", new Timestamp(10000L))), + ("user1", ("event2", new Timestamp(12000L))), + ("user1", ("event3", new Timestamp(15000L))) + ) + ) + // First batch: registers timer at 20000 (10000 + 10000), no timers fire yet + assert(result1 == List(("user1", "WINDOW_START", 3L))) + assert(tester.peekValueState[Long]("eventCount", "user1").get == 3L) + assert(tester.peekValueState[Long]("windowEndTime", "user1").get == 20000L) + + // Batch 2: Process more events with timestamps 18000, 20000 + // Watermark before batch = 13000, so timer at 20000 doesn't fire yet + // Max event time = 20000, watermark after batch = 20000 - 2000 = 18000 + // Timer at 20000 still doesn't fire (watermark < timer) + val result2 = tester.test( + List( + ("user1", ("event4", new Timestamp(18000L))), + ("user1", ("event5", new Timestamp(20000L))) + ) + ) + assert(result2 == List(("user1", "WINDOW_CONTINUE", 5L))) + assert(tester.peekValueState[Long]("eventCount", "user1").get == 5L) + + // Batch 3: Process event with timestamp 23000 + // Watermark before batch = 18000 (for filtering late events) + // Max event time = 23000, watermark after batch = 23000 - 2000 = 
21000 + // Timer at 20000 FIRES because updated watermark (21000) > timer (20000) + // Input is processed first, then timer fires + val result3 = tester.test(List(("user1", ("event6", new Timestamp(23000L))))) + // Input is processed first (increments count to 6), then timer fires (outputs final count 6) + assert(result3.length == 2) + assert(result3(0) == ("user1", "WINDOW_CONTINUE", 6L)) // Input processed first + assert(result3(1) == ("user1", "WINDOW_END", 6L)) // Timer fired with count after input + // Timer clears state, so no state should exist after this batch + assert(tester.peekValueState[Long]("eventCount", "user1").isEmpty) + assert(tester.peekValueState[Long]("windowEndTime", "user1").isEmpty) + + // Batch 4: Process another event with timestamp 25000 + // State was cleared in batch 3, so this starts a new window + // Watermark before batch = 21000 + // Max event time = 25000, watermark after batch = 25000 - 2000 = 23000 + val result4 = tester.test(List(("user1", ("event7", new Timestamp(25000L))))) + // Since state was cleared, this starts a new window + assert(result4 == List(("user1", "WINDOW_START", 1L))) + assert(tester.peekValueState[Long]("eventCount", "user1").get == 1L) + assert(tester.peekValueState[Long]("windowEndTime", "user1").get == 35000L) // 25000 + 10000 + } + + test("TwsTester should filter late events based on watermark") { + val processor = new EventTimeWindowProcessor(windowDurationMs = 10000L) + val tester = new TwsTester(processor, timeMode = TimeMode.EventTime) + + // Configure watermark with 2 second delay + tester.withWatermark( + (input: (String, Timestamp)) => input._2, + "2 seconds" + ) + + // Batch 1: Process events with timestamps 10000, 12000, 15000 + // Max event time = 15000, watermark after batch = 15000 - 2000 = 13000 + val result1 = tester.test( + List( + ("user1", ("event1", new Timestamp(10000L))), + ("user1", ("event2", new Timestamp(12000L))), + ("user1", ("event3", new Timestamp(15000L))) + ) + ) + assert(result1 == List(("user1", "WINDOW_START", 3L))) + + // Batch 2: Send mix of late and on-time events + // Watermark is currently 13000 + // Event at 11000 is late (< 13000) and should be filtered + // Event at 14000 is on-time (>= 13000) and should be processed + val result2 = tester.test( + List( + ("user1", ("late_event", new Timestamp(11000L))), // LATE - should be filtered + ("user1", ("on_time_event", new Timestamp(14000L))) // ON-TIME - should be processed + ) + ) + + // Only the on-time event should be processed, so count goes from 3 to 4 + assert(result2 == List(("user1", "WINDOW_CONTINUE", 4L))) + assert(tester.peekValueState[Long]("eventCount", "user1").get == 4L) + } + + test("TwsTester should call handleInitialState") { + val processor = new RunningCountProcessor[String]() + val tester = new TwsTester(processor, initialState = List(("a", 10L), ("b", 20L))) + assert(tester.peekValueState[Long]("count", "a").get == 10L) + assert(tester.peekValueState[Long]("count", "b").get == 20L) + + val ans = tester.test(List(("a", "a"), ("c", "c"))) + assert(ans == List(("a", 11L), ("c", 1L))) + } + + test("TwsTester should test RunningCountProcessor row-by-row") { + val input: List[(String, String)] = List( + ("key1", "a"), + ("key2", "b"), + ("key1", "c"), + ("key2", "b"), + ("key1", "c"), + ("key1", "c"), + ("key3", "q") + ) + val tester = new TwsTester(new RunningCountProcessor[String]()) + val ans: List[(String, Long)] = tester.testRowByRow(input) + assert( + ans == List( + ("key1", 1L), + ("key2", 1L), + ("key1", 2L), + ("key2", 
2L),
+        ("key1", 3L),
+        ("key1", 4L),
+        ("key3", 1L)
+      )
+    )
+  }
+}
+
+/**
+ * Integration test suite that compares TwsTester results with real streaming execution.
+ * Thread auditing is disabled because this suite runs actual streaming queries with RocksDB
+ * and shuffle operations, which spawn daemon threads (e.g., Netty boss/worker threads,
+ * file client threads, ForkJoinPool workers, and cleaner threads) that shut down
+ * asynchronously after SparkContext.stop().
+ */
+class TwsTesterFuzzTestSuite extends StreamTest {
+  import testImplicits._
+
+  // Disable thread auditing for this suite since it runs integration tests with
+  // real streaming queries that create asynchronously-stopped threads
+  override protected val enableAutoThreadAudit = false
+
+  /**
+   * Asserts that {@code tester} is equivalent to the streaming query transforming
+   * {@code inputStream} to {@code result}, when both are fed with data from {@code batches}.
+   */
+  def checkTwsTester[
+      K: org.apache.spark.sql.Encoder,
+      I: org.apache.spark.sql.Encoder,
+      O: org.apache.spark.sql.Encoder](
+      tester: TwsTester[K, I, O],
+      batches: List[List[(K, I)]],
+      inputStream: MemoryStream[(K, I)],
+      result: Dataset[O]): Unit = {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5"
+    ) {
+      val expectedResults: List[List[O]] = batches.map(batch => tester.test(batch)).toList
+      assert(batches.size == expectedResults.size)
+
+      val actions: Seq[StreamAction] = (batches zip expectedResults).flatMap {
+        case (batch, expected) =>
+          Seq(
+            AddData(inputStream, batch: _*),
+            CheckNewAnswer(expected.head, expected.tail: _*)
+          )
+      } :+ StopStream
+      testStream(result, OutputMode.Append())(actions: _*)
+    }
+  }
+
+  /**
+   * Asserts that {@code tester} processes given {@code input} in the same way as a Spark
+   * streaming query with {@code transformWithState} would.
+   *
+   * This is a simplified version of {@code checkTwsTester} for the case where there is only one
+   * batch and no timers (time mode is TimeMode.None).
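+   *
+   * For example, a minimal sketch (using the RunningCountProcessor exercised by the fuzz tests
+   * below):
+   * {{{
+   *   checkTwsTesterOneBatch(new RunningCountProcessor[String](), List(("a", "x"), ("b", "y")))
+   * }}}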
+ */ + def checkTwsTesterOneBatch[ + K: org.apache.spark.sql.Encoder, + I: org.apache.spark.sql.Encoder, + O: org.apache.spark.sql.Encoder]( + processor: StatefulProcessor[K, I, O], + input: List[(K, I)]): Unit = { + implicit val tupleEncoder = org.apache.spark.sql.Encoders.tuple( + implicitly[org.apache.spark.sql.Encoder[K]], + implicitly[org.apache.spark.sql.Encoder[I]] + ) + val inputStream = MemoryStream[(K, I)] + val result = inputStream + .toDS() + .groupByKey(_._1) + .mapValues(_._2) + .transformWithState(processor, TimeMode.None(), OutputMode.Append()) + checkTwsTester(new TwsTester(processor), List(input), inputStream, result) + } + + test("fuzz test with RunningCountProcessor") { + val random = new scala.util.Random(0) + val input = List.fill(1000) { + (s"key${random.nextInt(10)}", random.alphanumeric.take(5).mkString) + } + val processor = new RunningCountProcessor[String]() + checkTwsTesterOneBatch(processor, input) + } + + test("fuzz test with TopKProcessor") { + val random = new scala.util.Random(0) + val input = List.fill(1000) { + ( + s"key${random.nextInt(10)}", + (random.alphanumeric.take(5).mkString, random.nextDouble() * 100) + ) + } + val processor = new TopKProcessor(5) + checkTwsTesterOneBatch(processor, input) + } + + test("fuzz test with WordFrequencyProcessor") { + val random = new scala.util.Random(0) + val words = Array("spark", "scala", "flink", "kafka", "hadoop", "hive", "presto", "trino") + val input = List.fill(1000) { + (s"key${random.nextInt(10)}", ("", words(random.nextInt(words.length)))) + } + val processor = new WordFrequencyProcessor() + checkTwsTesterOneBatch(processor, input) + } + + test("fuzz test with EventTimeWindowProcessor") { + val inputStream = MemoryStream[(String, (String, Timestamp))] + val processor = new EventTimeWindowProcessor(windowDurationMs = 10000L) + val result = inputStream + .toDS() + .select($"_1", $"_2", $"_2._2".as("timestamp")) + .withWatermark("timestamp", "2 seconds") + .as[(String, (String, Timestamp), Timestamp)] + .groupByKey(_._1) + .mapValues(_._2) + .transformWithState(processor, TimeMode.EventTime(), OutputMode.Append()) + + // Generate 10 random batches, each with ~100 events, 10 users, with timestamps increasing. 
+ val random = new scala.util.Random(0) + val numBatches = 10 + val numUsers = 10 + val eventsPerBatch = 100 + val eventGapMs = 1000L + val userIds = (1 to numUsers).map(i => s"user$i").toArray + + var currentTimestamp = 0L + val batches: List[List[(String, (String, Timestamp))]] = + (0 until numBatches).map { batchIdx => + val batchStart = currentTimestamp + val usersInThisBatch = random.shuffle(userIds.toList) + // Assign variable number of events to each user in this batch + val perUserEvents = Array.fill(numUsers)(random.nextInt(5) + 5) // 5-9 events per user + val events = + usersInThisBatch.zip(perUserEvents).flatMap { + case (user, numEventsForUser) => + (0 until numEventsForUser).map { evtIdx => + // Make timestamp within this batch but always increasing overall + val ts = new Timestamp(currentTimestamp) + val evtId = s"event${batchIdx}_${user}_$evtIdx" + currentTimestamp += eventGapMs + (user, (evtId, ts)) + } + } + random + .shuffle(events) + .take(eventsPerBatch) + .toList // Shuffle and trim to ~100 events per batch + }.toList + + val tester = new TwsTester(processor, timeMode = TimeMode.EventTime()) + tester.withWatermark(_._2, "2 seconds") + + checkTwsTester(tester, batches, inputStream, result) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/EventTimeWindowProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/EventTimeWindowProcessor.scala new file mode 100644 index 000000000000..717b6747abb6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/EventTimeWindowProcessor.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.streaming.processors + +import java.sql.Timestamp + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig, ValueState} + +/** + * Event time window processor that demonstrates event time timer usage. 
+ * + * Input: (eventId, eventTime) as (String, Timestamp) + * Output: (userId, status, count) as (String, String, Long) + * + * Behavior: + * - Tracks event count per user window + * - On first event, registers a timer for windowDurationMs from the first event time + * - Accumulates events in the window + * - When timer expires (based on watermark), emits window summary and starts new window + */ +class EventTimeWindowProcessor(val windowDurationMs: Long = 10000L) + extends StatefulProcessor[String, (String, Timestamp), (String, String, Long)] { + + @transient private var eventCountState: ValueState[Long] = _ + @transient private var windowEndTimeState: ValueState[Long] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + eventCountState = + getHandle.getValueState[Long]("eventCount", Encoders.scalaLong, TTLConfig.NONE) + windowEndTimeState = + getHandle.getValueState[Long]("windowEndTime", Encoders.scalaLong, TTLConfig.NONE) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, Timestamp)], + timerValues: TimerValues + ): Iterator[(String, String, Long)] = { + val events = inputRows.toList + val currentCount = if (eventCountState.exists()) eventCountState.get() else 0L + val newCount = currentCount + events.size + eventCountState.update(newCount) + + // If this is the first event in a window, register timer + if (!windowEndTimeState.exists()) { + val firstEventTime = events.head._2.getTime + val windowEnd = firstEventTime + windowDurationMs + getHandle.registerTimer(windowEnd) + windowEndTimeState.update(windowEnd) + Iterator.single((key, "WINDOW_START", newCount)) + } else { + Iterator.single((key, "WINDOW_CONTINUE", newCount)) + } + } + + override def handleExpiredTimer( + key: String, + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo + ): Iterator[(String, String, Long)] = { + val count = if (eventCountState.exists()) eventCountState.get() else 0L + + // Clear window state + eventCountState.clear() + windowEndTimeState.clear() + + Iterator.single((key, "WINDOW_END", count)) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/MultiTimerProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/MultiTimerProcessor.scala new file mode 100644 index 000000000000..b9d3eb04e46d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/MultiTimerProcessor.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.streaming.processors + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.streaming.{ExpiredTimerInfo, MapState, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig} + +/** + * Multi-timer processor that registers multiple timers with different delays. + * + * Input: (userId, command) as (String, String) + * Output: (userId, timerType, timestamp) as (String, String, Long) + * + * On first input, registers three timers: SHORT (1s), MEDIUM (3s), LONG (5s) + * When each timer expires, emits the timer type and timestamp + */ +class MultiTimerProcessor + extends StatefulProcessor[String, (String, String), (String, String, Long)] { + + @transient private var timerTypeMapState: MapState[Long, String] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + timerTypeMapState = getHandle.getMapState[Long, String]( + "timerTypeMap", + Encoders.scalaLong, + Encoders.STRING, + TTLConfig.NONE + ) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, String)], + timerValues: TimerValues + ): Iterator[(String, String, Long)] = { + inputRows.size // consume iterator + + val baseTime = timerValues.getCurrentProcessingTimeInMs() + val shortTimer = baseTime + 1000L + val mediumTimer = baseTime + 3000L + val longTimer = baseTime + 5000L + + getHandle.registerTimer(shortTimer) // SHORT - 1 second + getHandle.registerTimer(mediumTimer) // MEDIUM - 3 seconds + getHandle.registerTimer(longTimer) // LONG - 5 seconds + + timerTypeMapState.updateValue(shortTimer, "SHORT") + timerTypeMapState.updateValue(mediumTimer, "MEDIUM") + timerTypeMapState.updateValue(longTimer, "LONG") + + Iterator.empty + } + + override def handleExpiredTimer( + key: String, + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo + ): Iterator[(String, String, Long)] = { + val expiryTime = expiredTimerInfo.getExpiryTimeInMs() + val timerType = timerTypeMapState.getValue(expiryTime) + + // Clean up the timer type from the map + timerTypeMapState.removeKey(expiryTime) + + Iterator.single((key, timerType, expiryTime)) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/RunningCountProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/RunningCountProcessor.scala new file mode 100644 index 000000000000..c0c2ddfa5679 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/RunningCountProcessor.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.streaming.processors + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessorWithInitialState, TimeMode, TimerValues, TTLConfig, ValueState} + +/** Test StatefulProcessor implementation that maintains a running count. */ +class RunningCountProcessor[T](ttl: TTLConfig = TTLConfig.NONE) + extends StatefulProcessorWithInitialState[String, T, (String, Long), Long] { + + @transient private var countState: ValueState[Long] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + countState = getHandle.getValueState[Long]("count", Encoders.scalaLong, ttl) + } + + override def handleInitialState( + key: String, + initialState: Long, + timerValues: TimerValues + ): Unit = { + countState.update(initialState) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[T], + timerValues: TimerValues + ): Iterator[(String, Long)] = { + val incoming = inputRows.size + val current = countState.get() + val updated = current + incoming + countState.update(updated) + Iterator.single((key, updated)) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/SessionTimeoutProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/SessionTimeoutProcessor.scala new file mode 100644 index 000000000000..06f79720f9f2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/SessionTimeoutProcessor.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.streaming.processors + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig, ValueState} + +/** + * Session timeout processor that demonstrates timer usage. 
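+ *
+ * A hypothetical usage sketch with the TwsTester utility introduced in this
+ * patch (mirrors TwsTesterSuite; illustrative only, names and values are
+ * arbitrary):
+ * {{{
+ *   val clock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L))
+ *   val tester = new TwsTester(new SessionTimeoutProcessor(5000L), clock,
+ *     timeMode = TimeMode.ProcessingTime)
+ *   tester.test(List(("user1", ("user1", "login")))) // ("user1", "ACTIVITY", 1L)
+ *   clock.advanceBy(Duration.ofMillis(6000L))
+ *   tester.test(List()) // timer fires: ("user1", "SESSION_TIMEOUT", 1L)
+ * }}}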
+ * + * Input: (userId, activityType) as (String, String) + * Output: (userId, event, count) as (String, String, Long) + * + * Behavior: + * - Tracks activity count per user session + * - Registers a timeout timer on first activity (5 seconds from current time) + * - Updates timer on each new activity (resets 5-second countdown) + * - When timer expires, emits ("userId", "SESSION_TIMEOUT", activityCount) and clears state + * - Regular activities emit ("userId", "ACTIVITY", currentCount) + */ +class SessionTimeoutProcessor(val timeoutDurationMs: Long = 5000L) + extends StatefulProcessor[String, (String, String), (String, String, Long)] { + + @transient private var activityCountState: ValueState[Long] = _ + @transient private var timerState: ValueState[Long] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + activityCountState = + getHandle.getValueState[Long]("activityCount", Encoders.scalaLong, TTLConfig.NONE) + timerState = getHandle.getValueState[Long]("timer", Encoders.scalaLong, TTLConfig.NONE) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, String)], + timerValues: TimerValues + ): Iterator[(String, String, Long)] = { + val currentCount = if (activityCountState.exists()) activityCountState.get() else 0L + val newCount = currentCount + inputRows.size + activityCountState.update(newCount) + + // Delete old timer if exists and register new one + if (timerState.exists()) { + getHandle.deleteTimer(timerState.get()) + } + + val newTimerMs = timerValues.getCurrentProcessingTimeInMs() + timeoutDurationMs + getHandle.registerTimer(newTimerMs) + timerState.update(newTimerMs) + + Iterator.single((key, "ACTIVITY", newCount)) + } + + override def handleExpiredTimer( + key: String, + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo + ): Iterator[(String, String, Long)] = { + val count = if (activityCountState.exists()) activityCountState.get() else 0L + + // Clear session state + activityCountState.clear() + timerState.clear() + + Iterator.single((key, "SESSION_TIMEOUT", count)) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala new file mode 100644 index 000000000000..48733cce5e5a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.streaming.processors + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.streaming.{ListState, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig} + +// Input: (key, score) as (String, Double) +// Output: (key, score) as (String, Double) for the top K snapshot each batch +class TopKProcessor(k: Int, ttl: TTLConfig = TTLConfig.NONE) + extends StatefulProcessor[String, (String, Double), (String, Double)] { + + @transient private var topKState: ListState[Double] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + topKState = getHandle.getListState[Double]("topK", Encoders.scalaDouble, ttl) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, Double)], + timerValues: TimerValues + ): Iterator[(String, Double)] = { + // Load existing list into a buffer + val current = ArrayBuffer[Double]() + topKState.get().foreach(current += _) + println(s"AAA loaded state for key=$key: $current") + + // Add new values and recompute top-K + inputRows.foreach { + case (_, score) => + current += score + } + val updatedTopK = current.sorted(Ordering[Double].reverse).take(k) + println(s"AAA updatedTopK for key=$key: $updatedTopK") + + // Persist back + topKState.clear() + topKState.put(updatedTopK.toArray) + + // Emit snapshot of top-K for this key + updatedTopK.iterator.map(v => (key, v)) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/WordFrequencyProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/WordFrequencyProcessor.scala new file mode 100644 index 000000000000..02b11b7715e3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/WordFrequencyProcessor.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.streaming.processors + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.streaming.{MapState, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig} + +// Input: (key, word) as (String, String) +// Output: (key, word, count) as (String, String, Long) for each word in the batch +class WordFrequencyProcessor(ttl: TTLConfig = TTLConfig.NONE) + extends StatefulProcessor[String, (String, String), (String, String, Long)] { + + @transient private var freqState: MapState[String, Long] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + freqState = getHandle + .getMapState[String, Long]("frequencies", Encoders.STRING, Encoders.scalaLong, ttl) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[(String, String)], + timerValues: TimerValues + ): Iterator[(String, String, Long)] = { + val results = ArrayBuffer[(String, String, Long)]() + + inputRows.foreach { + case (_, word) => + val currentCount = if (freqState.containsKey(word)) { + freqState.getValue(word) + } else { + 0L + } + val updatedCount = currentCount + 1 + freqState.updateValue(word, updatedCount) + results += ((key, word, updatedCount)) + } + + results.iterator + } +} + From 1e2c8a54e91568881a3b5b95f7cc168423150685 Mon Sep 17 00:00:00 2001 From: Dmytro Fedoriaka Date: Fri, 21 Nov 2025 21:11:27 +0000 Subject: [PATCH 2/7] remove timer functionality --- .../InMemoryStatefulProcessorHandleImpl.scala | 122 +----- .../spark/sql/streaming/TwsTester.scala | 124 +----- .../spark/sql/streaming/TwsTesterSuite.scala | 385 +----------------- .../processors/EventTimeWindowProcessor.scala | 85 ---- .../processors/MultiTimerProcessor.scala | 82 ---- .../processors/SessionTimeoutProcessor.scala | 82 ---- .../streaming/processors/TopKProcessor.scala | 2 - 7 files changed, 23 insertions(+), 859 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/EventTimeWindowProcessor.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/MultiTimerProcessor.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/SessionTimeoutProcessor.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala index 2a76cbb470c7..3da55a58dd8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala @@ -16,67 +16,27 @@ */ package org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.testing -import java.time.Clock -import java.time.Instant import java.util.UUID import scala.collection.mutable import scala.reflect.ClassTag import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.ImplicitGroupingKeyTracker import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.QueryInfoImpl -import org.apache.spark.sql.streaming.ListState -import org.apache.spark.sql.streaming.MapState -import org.apache.spark.sql.streaming.QueryInfo -import org.apache.spark.sql.streaming.StatefulProcessorHandle -import org.apache.spark.sql.streaming.TimeMode -import org.apache.spark.sql.streaming.TTLConfig -import org.apache.spark.sql.streaming.ValueState - -/** Helper to track expired keys. */ -class TtlTracker(val clock: Clock, ttl: TTLConfig) { - require(!ttl.ttlDuration.isNegative()) - private val keyToLastUpdatedTime = mutable.Map[Any, Instant]() - - def isKeyExpired(): Boolean = { - if (ttl.ttlDuration.isZero()) { - return false - } - val key = ImplicitGroupingKeyTracker.getImplicitKeyOption.get - if (!keyToLastUpdatedTime.contains(key)) { - return false - } - val expiration: Instant = keyToLastUpdatedTime.get(key).get.plus(ttl.ttlDuration) - return expiration.isBefore(clock.instant()) - } +import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TTLConfig, ValueState} - def onKeyUpdated(): Unit = { - val key = ImplicitGroupingKeyTracker.getImplicitKeyOption.get - keyToLastUpdatedTime.put(key, clock.instant()) - } -} - -class InMemoryValueState[T](clock: Clock, ttl: TTLConfig) extends ValueState[T] { +class InMemoryValueState[T] extends ValueState[T] { private val keyToStateValue = mutable.Map[Any, T]() - private val ttlTracker = new TtlTracker(clock, ttl) private def getValue: Option[T] = { - if (ttlTracker.isKeyExpired()) { - return None - } keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) } - override def exists(): Boolean = { - getValue.isDefined - } - + override def exists(): Boolean = getValue.isDefined override def get(): T = getValue.getOrElse(null.asInstanceOf[T]) override def update(newState: T): Unit = { - ttlTracker.onKeyUpdated() keyToStateValue.put(ImplicitGroupingKeyTracker.getImplicitKeyOption.get, newState) } @@ -85,14 +45,10 @@ class InMemoryValueState[T](clock: Clock, ttl: TTLConfig) extends ValueState[T] } } -class InMemoryListState[T](clock: Clock, ttl: TTLConfig) extends ListState[T] { +class InMemoryListState[T] extends ListState[T] { private val keyToStateValue = mutable.Map[Any, mutable.ArrayBuffer[T]]() - private val ttlTracker = new TtlTracker(clock, ttl) private def getList: Option[mutable.ArrayBuffer[T]] = { - if (ttlTracker.isKeyExpired()) { - return None - } keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) } @@ -103,7 +59,6 @@ class InMemoryListState[T](clock: Clock, ttl: TTLConfig) extends ListState[T] { } override def put(newState: Array[T]): Unit = { - ttlTracker.onKeyUpdated() keyToStateValue.put( ImplicitGroupingKeyTracker.getImplicitKeyOption.get, mutable.ArrayBuffer.empty[T] ++ newState @@ -111,7 +66,6 @@ class InMemoryListState[T](clock: Clock, ttl: TTLConfig) extends ListState[T] { } override def appendValue(newState: T): Unit = { - ttlTracker.onKeyUpdated() if (!exists()) { keyToStateValue.put( ImplicitGroupingKeyTracker.getImplicitKeyOption.get, @@ -122,7 +76,6 @@ class InMemoryListState[T](clock: Clock, ttl: TTLConfig) extends ListState[T] { } override def appendList(newState: Array[T]): Unit = { - ttlTracker.onKeyUpdated() if (!exists()) { keyToStateValue.put( ImplicitGroupingKeyTracker.getImplicitKeyOption.get, @@ -137,14 +90,10 @@ class InMemoryListState[T](clock: Clock, ttl: TTLConfig) extends ListState[T] { } } -class InMemoryMapState[K, 
V](clock: Clock, ttl: TTLConfig) extends MapState[K, V] { +class InMemoryMapState[K, V] extends MapState[K, V] { private val keyToStateValue = mutable.Map[Any, mutable.HashMap[K, V]]() - private val ttlTracker = new TtlTracker(clock, ttl) private def getMap: Option[mutable.HashMap[K, V]] = { - if (ttlTracker.isKeyExpired()) { - return None - } keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) } @@ -165,7 +114,6 @@ class InMemoryMapState[K, V](clock: Clock, ttl: TTLConfig) extends MapState[K, V } override def updateValue(key: K, value: V): Unit = { - ttlTracker.onKeyUpdated() if (!exists()) { keyToStateValue.put( ImplicitGroupingKeyTracker.getImplicitKeyOption.get, @@ -211,44 +159,7 @@ class InMemoryMapState[K, V](clock: Clock, ttl: TTLConfig) extends MapState[K, V } } -class InMemoryTimers { - private val keyToTimers = mutable.Map[Any, mutable.TreeSet[Long]]() - - def registerTimer(expiryTimestampMs: Long): Unit = { - val groupingKey = ImplicitGroupingKeyTracker.getImplicitKeyOption.get - if (!keyToTimers.contains(groupingKey)) { - keyToTimers.put(groupingKey, mutable.TreeSet[Long]()) - } - keyToTimers(groupingKey).add(expiryTimestampMs) - } - - def deleteTimer(expiryTimestampMs: Long): Unit = { - val groupingKey = ImplicitGroupingKeyTracker.getImplicitKeyOption.get - if (keyToTimers.contains(groupingKey)) { - keyToTimers(groupingKey).remove(expiryTimestampMs) - if (keyToTimers(groupingKey).isEmpty) { - keyToTimers.remove(groupingKey) - } - } - } - - def listTimers(): Iterator[Long] = { - val groupingKey = ImplicitGroupingKeyTracker.getImplicitKeyOption.get - keyToTimers.get(groupingKey) match { - case Some(timers) => timers.iterator - case None => Iterator.empty - } - } - - def getAllKeysWithTimers(): Iterator[Any] = { - keyToTimers.keys.iterator - } -} - -class InMemoryStatefulProcessorHandleImpl( - timeMode: TimeMode, - keyExprEnc: ExpressionEncoder[Any], - clock: Clock = Clock.systemUTC()) +class InMemoryStatefulProcessorHandleImpl() extends StatefulProcessorHandle { private val states = mutable.Map[String, Any]() @@ -260,7 +171,7 @@ class InMemoryStatefulProcessorHandleImpl( ): ValueState[T] = { require(!states.contains(stateName), s"State $stateName already defined.") states - .getOrElseUpdate(stateName, new InMemoryValueState[T](clock, ttlConfig)) + .getOrElseUpdate(stateName, new InMemoryValueState[T]()) .asInstanceOf[InMemoryValueState[T]] } @@ -275,7 +186,7 @@ class InMemoryStatefulProcessorHandleImpl( ): ListState[T] = { require(!states.contains(stateName), s"State $stateName already defined.") states - .getOrElseUpdate(stateName, new InMemoryListState[T](clock, ttlConfig)) + .getOrElseUpdate(stateName, new InMemoryListState[T]()) .asInstanceOf[InMemoryListState[T]] } @@ -291,7 +202,7 @@ class InMemoryStatefulProcessorHandleImpl( ): MapState[K, V] = { require(!states.contains(stateName), s"State $stateName already defined.") states - .getOrElseUpdate(stateName, new InMemoryMapState[K, V](clock, ttlConfig)) + .getOrElseUpdate(stateName, new InMemoryMapState[K, V]()) .asInstanceOf[InMemoryMapState[K, V]] } @@ -306,21 +217,16 @@ class InMemoryStatefulProcessorHandleImpl( new QueryInfoImpl(UUID.randomUUID(), UUID.randomUUID(), 0L) } - private val timers = new InMemoryTimers() - override def registerTimer(expiryTimestampMs: Long): Unit = { - require(timeMode != TimeMode.None, "Timers are not supported with TimeMode.None.") - timers.registerTimer(expiryTimestampMs) + throw new UnsupportedOperationException("Timers are not supported.") } override def 
deleteTimer(expiryTimestampMs: Long): Unit = { - require(timeMode != TimeMode.None, "Timers are not supported with TimeMode.None.") - timers.deleteTimer(expiryTimestampMs) + throw new UnsupportedOperationException("Timers are not supported.") } override def listTimers(): Iterator[Long] = { - require(timeMode != TimeMode.None, "Timers are not supported with TimeMode.None.") - timers.listTimers() + throw new UnsupportedOperationException("Timers are not supported.") } override def deleteIfExists(stateName: String): Unit = { @@ -359,8 +265,4 @@ class InMemoryStatefulProcessorHandleImpl( require(states.contains(stateName), s"State $stateName has not been initialized.") states(stateName).asInstanceOf[InMemoryMapState[MK, MV]].iterator().toMap } - - def getAllKeysWithTimers[K](): Iterator[K] = { - timers.getAllKeysWithTimers().map(_.asInstanceOf[K]) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala index 496c7409eef9..878ea4dc092f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala @@ -16,18 +16,10 @@ */ package org.apache.spark.sql.streaming -import java.sql.Timestamp -import java.time.{Clock, Duration, Instant, ZoneId} - import scala.reflect.ClassTag -import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.ImplicitGroupingKeyTracker import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.testing.InMemoryStatefulProcessorHandleImpl -import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.{ - ExpiredTimerInfoImpl, - TimerValuesImpl -} /** * Testing utility for transformWithState stateful processors. 
Provides in-memory state management @@ -45,19 +37,10 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith */ class TwsTester[K, I, O]( val processor: StatefulProcessor[K, I, O], - val clock: Clock = Clock.systemUTC(), - val timeMode: TimeMode = TimeMode.None, - val outputMode: OutputMode = OutputMode.Append, val initialState: List[(K, Any)] = List()) { - private val handle = new InMemoryStatefulProcessorHandleImpl(timeMode, null, clock) - - private var eventTimeFunc: (I => Timestamp) = null - private var delayThresholdMs: Long = 0 - private var currentWatermarkMs: Option[Long] = None - + private val handle = new InMemoryStatefulProcessorHandleImpl() processor.setHandle(handle) - processor.init(outputMode, timeMode) - + processor.init(OutputMode.Append, TimeMode.None) processor match { case p: StatefulProcessorWithInitialState[K @unchecked, I @unchecked, O @unchecked, s] => handleInitialState[s]() @@ -65,13 +48,11 @@ class TwsTester[K, I, O]( } private def handleInitialState[S](): Unit = { - val timerValues = new TimerValuesImpl(Some(clock.instant().toEpochMilli()), None) val p = processor.asInstanceOf[StatefulProcessorWithInitialState[K, I, O, S]] initialState.foreach { case (key, state) => ImplicitGroupingKeyTracker.setImplicitKey(key) - p.handleInitialState(key, state.asInstanceOf[S], timerValues) - ImplicitGroupingKeyTracker.removeImplicitKey() + p.handleInitialState(key, state.asInstanceOf[S], null) } } @@ -85,66 +66,14 @@ class TwsTester[K, I, O]( * @return all output rows produced by the processor */ def test(input: List[(K, I)]): List[O] = { - val currentTimeMs: Long = clock.instant().toEpochMilli() - var timerValues = new TimerValuesImpl(Some(currentTimeMs), currentWatermarkMs) var ans: List[O] = List() - val filteredInput = filterLateEvents(input) - - for ((key, v) <- filteredInput.groupBy(_._1)) { + for ((key, v) <- input.groupBy(_._1)) { ImplicitGroupingKeyTracker.setImplicitKey(key) - ans = ans ++ processor.handleInputRows(key, v.map(_._2).iterator, timerValues).toList - ImplicitGroupingKeyTracker.removeImplicitKey() - } - - updateWatermark(input) - timerValues = new TimerValuesImpl(Some(currentTimeMs), currentWatermarkMs) - ans ++ handleExpiredTimers(timerValues) - } - - // Filters late events in EventTime mode. 
- private def filterLateEvents(input: List[(K, I)]): List[(K, I)] = { - if (timeMode != TimeMode.EventTime || !currentWatermarkMs.isDefined) { - return input - } - require(eventTimeFunc != null, "call withWatermark if timeMode is EventTime") - input.filter { case (_, row) => eventTimeFunc(row).getTime() >= currentWatermarkMs.get } - } - - private def handleExpiredTimers(timerValues: TimerValues): List[O] = { - if (timeMode == TimeMode.None) { - return List() - } - val currentTimeMs: Long = - if (timeMode == TimeMode.EventTime) timerValues.getCurrentWatermarkInMs() - else timerValues.getCurrentProcessingTimeInMs() - - var ans: List[O] = List() - for (key <- handle.getAllKeysWithTimers[K]()) { - ImplicitGroupingKeyTracker.setImplicitKey(key) - val expiredTimers: List[Long] = handle.listTimers().filter(_ <= currentTimeMs).toList - for (timerExpiryTimeMs <- expiredTimers) { - val expiredTimerInfo = new ExpiredTimerInfoImpl(Some(timerExpiryTimeMs)) - ans = ans ++ processor.handleExpiredTimer(key, timerValues, expiredTimerInfo).toList - handle.deleteTimer(timerExpiryTimeMs) - } - ImplicitGroupingKeyTracker.removeImplicitKey() - } + ans = ans ++ processor.handleInputRows(key, v.map(_._2).iterator, null).toList + } ans } - private def updateWatermark(input: List[(K, I)]): Unit = { - if (timeMode != TimeMode.EventTime || input.isEmpty) { - return - } - require(eventTimeFunc != null, "call withWatermark if timeMode is EventTime") - currentWatermarkMs = Some( - math.max( - currentWatermarkMs.getOrElse(0L), - input.map(v => eventTimeFunc(v._2).getTime()).max - delayThresholdMs - ) - ) - } - /** * Convenience method to process a single input row for a given key. * @@ -280,45 +209,4 @@ class TwsTester[K, I, O]( ImplicitGroupingKeyTracker.removeImplicitKey() return result } - - /** - * Sets watermark for EventTime time mode. - * - * @param eventTime function used to extract timestamp column from input rows. - * @param delayThreshold a string specifying the minimum delay to wait to data to arrive late, - * relative to the latest record that has been processed in the form of an interval (e.g. - * "1 minute" or "5 hours") - */ - def withWatermark(eventTime: (I => Timestamp), delayThreshold: String): Unit = { - require(timeMode == TimeMode.EventTime, "withWatermark is only usable with TimeMode.EventTime") - val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold) - require( - !IntervalUtils.isNegative(parsedDelay), - s"delay threshold ($delayThreshold) should not be negative." - ) - eventTimeFunc = eventTime - delayThresholdMs = - IntervalUtils.getDuration(parsedDelay, java.util.concurrent.TimeUnit.MILLISECONDS, 30) - } -} - -object TwsTester { - - /** Fake implementation of {@code java.time.CLock} to be used with TwsTester to simulate time. 
*/ - class TestClock( - var currentInstant: Instant = Instant.EPOCH, - zone: ZoneId = ZoneId.systemDefault()) - extends Clock { - override def getZone: ZoneId = zone - override def withZone(zone: ZoneId): Clock = new TestClock(currentInstant, zone) - override def instant(): Instant = currentInstant - - def setInstant(instant: Instant): Unit = { - currentInstant = instant - } - - def advanceBy(duration: Duration): Unit = { - currentInstant = currentInstant.plus(duration) - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala index 196b33934e3c..4b5c8a71b63c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala @@ -16,9 +16,6 @@ */ package org.apache.spark.sql.streaming -import java.sql.Timestamp -import java.time.{Duration, Instant} - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.streaming.runtime.MemoryStream @@ -182,71 +179,6 @@ class TwsTesterSuite extends SparkFunSuite { assert(tester.peekMapState[String, Long]("frequencies", "user4") == Map()) } - test("TwsTester should expire old value state according to TTL") { - val processor = new RunningCountProcessor[String](TTLConfig(Duration.ofSeconds(100))) - val testClock = new TwsTester.TestClock(Instant.EPOCH) - val tester = new TwsTester(processor, testClock) - - tester.testOneRow("key1", "b") - assert(tester.peekValueState[Long]("count", "key1").get == 1L) - testClock.advanceBy(Duration.ofSeconds(101)) - assert(tester.peekValueState[Long]("count", "key1").isEmpty) - } - - test("TwsTester should expire old list state according to TTL") { - val testClock = new TwsTester.TestClock(Instant.EPOCH) - val ttlConfig = TTLConfig(Duration.ofSeconds(100)) - val processor = new TopKProcessor(2, ttlConfig) - val tester = new TwsTester(processor, testClock) - - tester.testOneRow("key1", ("a", 1.0)) - tester.testOneRow("key2", ("a", 1.5)) - assert(tester.peekListState[Double]("topK", "key1") == List(1.0)) - assert(tester.peekListState[Double]("topK", "key2") == List(1.5)) - - testClock.advanceBy(Duration.ofSeconds(50)) - tester.testOneRow("key2", ("a", 2.0)) - assert(tester.peekListState[Double]("topK", "key1") == List(1.0)) - assert(tester.peekListState[Double]("topK", "key2") == List(2.0, 1.5)) - - testClock.advanceBy(Duration.ofSeconds(51)) - assert(tester.peekListState[Double]("topK", "key1") == List()) - assert(tester.peekListState[Double]("topK", "key2") == List(2.0, 1.5)) - - testClock.advanceBy(Duration.ofSeconds(50)) - assert(tester.peekListState[Double]("topK", "key1") == List()) - assert(tester.peekListState[Double]("topK", "key2") == List()) - } - - test("TwsTester should expire old map state according to TTL") { - val testClock = new TwsTester.TestClock(Instant.EPOCH) - val ttlConfig = TTLConfig(Duration.ofSeconds(100)) - val processor = new WordFrequencyProcessor(ttlConfig) - val tester = new TwsTester(processor, testClock) - - tester.testOneRow("key1", ("a", "spark")) - tester.testOneRow("key2", ("a", "beta")) - assert(tester.peekMapState[String, Long]("frequencies", "key1") == Map("spark" -> 1L)) - assert(tester.peekMapState[String, Long]("frequencies", "key2") == Map("beta" -> 1L)) - - testClock.advanceBy(Duration.ofSeconds(50)) - tester.testOneRow("key2", ("a", "spark")) - assert(tester.peekMapState[String, Long]("frequencies", "key1") 
== Map("spark" -> 1L)) - assert( - tester.peekMapState[String, Long]("frequencies", "key2") == Map("beta" -> 1L, "spark" -> 1L) - ) - - testClock.advanceBy(Duration.ofSeconds(51)) - assert(tester.peekMapState[String, Long]("frequencies", "key1") == Map()) - assert( - tester.peekMapState[String, Long]("frequencies", "key2") == Map("beta" -> 1L, "spark" -> 1L) - ) - - testClock.advanceBy(Duration.ofSeconds(50)) - assert(tester.peekMapState[String, Long]("frequencies", "key1") == Map()) - assert(tester.peekMapState[String, Long]("frequencies", "key2") == Map()) - } - test("TwsTester should test one row with value state") { val processor = new RunningCountProcessor[String]() val tester = new TwsTester(processor) @@ -256,263 +188,6 @@ class TwsTesterSuite extends SparkFunSuite { assert(newState == 11L) } - test("TwsTester should handle session timeout with timer") { - val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) - val processor = new SessionTimeoutProcessor(timeoutDurationMs = 5000L) - val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) - - // First activity - should register timer at 15000ms - val result1 = tester.test(List(("user1", ("user1", "login")))) - assert(result1 == List(("user1", "ACTIVITY", 1L))) - assert(tester.peekValueState[Long]("activityCount", "user1").get == 1L) - assert(tester.peekValueState[Long]("timer", "user1").get == 15000L) - - // Second activity before timeout - should update timer to 20000ms - testClock.advanceBy(Duration.ofMillis(3000L)) // now at 13000ms - val result2 = tester.test(List(("user1", ("user1", "click")))) - assert(result2 == List(("user1", "ACTIVITY", 2L))) - assert(tester.peekValueState[Long]("activityCount", "user1").get == 2L) - assert(tester.peekValueState[Long]("timer", "user1").get == 18000L) - - // Advance time past timeout - timer should fire - testClock.advanceBy(Duration.ofMillis(6000L)) // now at 19000ms, timer at 18000ms should fire - val result3 = tester.test(List()) // empty input, but timer should fire - assert(result3 == List(("user1", "SESSION_TIMEOUT", 2L))) - assert(tester.peekValueState[Long]("activityCount", "user1").isEmpty) - assert(tester.peekValueState[Long]("timer", "user1").isEmpty) - } - - test("TwsTester should process input before timers") { - val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) - val processor = new SessionTimeoutProcessor(timeoutDurationMs = 5000L) - val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) - - // Register initial activity and timer - tester.test(List(("user1", ("user1", "start")))) - assert(tester.peekValueState[Long]("activityCount", "user1").get == 1L) - - // Advance past timer expiry - testClock.advanceBy(Duration.ofMillis(6000L)) // now at 16000ms, timer at 15000ms expired - - // Process new input - input should be processed BEFORE timer fires - val result = tester.test(List(("user1", ("user1", "new_activity")))) - - // Input is processed first: increments count to 2, deletes old timer, registers new one - // The expired timer at 15000ms is deleted during input processing, so it never fires - assert(result.length == 1) - assert(result(0) == ("user1", "ACTIVITY", 2L)) // Input processed, count incremented from 1 to 2 - - // After processing, should have updated state with new timer - assert(tester.peekValueState[Long]("activityCount", "user1").get == 2L) - assert(tester.peekValueState[Long]("timer", "user1").get == 21000L) // 16000 + 5000 - } - - test("TwsTester should handle 
multiple timers in same batch") { - val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) - val processor = new MultiTimerProcessor() - val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) - - // Register all three timers (SHORT=11000ms, MEDIUM=13000ms, LONG=15000ms) - val result1 = tester.test(List(("user1", ("user1", "start")))) - assert(result1.isEmpty) - - // Advance to fire SHORT timer only - testClock.advanceBy(Duration.ofMillis(1500L)) // now at 11500ms - val result2 = tester.test(List()) - assert(result2.length == 1) - assert(result2(0) == ("user1", "SHORT", 11000L)) - - // Advance to fire MEDIUM and LONG timers together - testClock.advanceBy(Duration.ofMillis(4000L)) // now at 15500ms - val result3 = tester.test(List()) - - // Timers should fire in order of expiry time - assert(result3.length == 2) - assert(result3(0) == ("user1", "MEDIUM", 13000L)) - assert(result3(1) == ("user1", "LONG", 15000L)) - } - - test("TwsTester should not process timers twice") { - val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) - val processor = new SessionTimeoutProcessor(timeoutDurationMs = 5000L) - val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) - - // Register timer at 15000ms - tester.test(List(("user1", ("user1", "start")))) - assert(tester.peekValueState[Long]("activityCount", "user1").get == 1L) - - // Advance past timer and process - timer fires - testClock.advanceBy(Duration.ofMillis(6000L)) // now at 16000ms - val result1 = tester.test(List()) - assert(result1.length == 1) - assert(result1(0) == ("user1", "SESSION_TIMEOUT", 1L)) - assert(tester.peekValueState[Long]("activityCount", "user1").isEmpty) - - // Process again at same time - timer should NOT fire again - val result2 = tester.test(List()) - assert(result2.isEmpty) - - // Process again with even later time - timer should still not fire - testClock.advanceBy(Duration.ofMillis(10000L)) // now at 26000ms - val result3 = tester.test(List()) - assert(result3.isEmpty) - } - - test("TwsTester should handle timers for multiple keys independently") { - val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) - val processor = new SessionTimeoutProcessor(timeoutDurationMs = 5000L) - val tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) - - // Register activities for two users at different times - tester.test(List(("user1", ("user1", "start")))) // timer at 15000ms - testClock.advanceBy(Duration.ofMillis(2000L)) // now at 12000ms - tester.test(List(("user2", ("user2", "start")))) // timer at 17000ms - - // Advance to fire only user1's timer - testClock.advanceBy(Duration.ofMillis(4000L)) // now at 16000ms - val result1 = tester.test(List()) - assert(result1.length == 1) - assert(result1(0) == ("user1", "SESSION_TIMEOUT", 1L)) - assert(tester.peekValueState[Long]("activityCount", "user1").isEmpty) - assert(tester.peekValueState[Long]("activityCount", "user2").get == 1L) // user2 still active - - // Advance to fire user2's timer - testClock.advanceBy(Duration.ofMillis(2000L)) // now at 18000ms - val result2 = tester.test(List()) - assert(result2.length == 1) - assert(result2(0) == ("user2", "SESSION_TIMEOUT", 1L)) - assert(tester.peekValueState[Long]("activityCount", "user2").isEmpty) - } - - test("TwsTester should handle timer deletion correctly") { - val testClock = new TwsTester.TestClock(Instant.ofEpochMilli(10000L)) - val processor = new SessionTimeoutProcessor(timeoutDurationMs = 5000L) - val 
tester = new TwsTester(processor, testClock, timeMode = TimeMode.ProcessingTime) - - // Register initial timer at 15000ms - tester.test(List(("user1", ("user1", "start")))) - assert(tester.peekValueState[Long]("timer", "user1").get == 15000L) - - // New activity deletes old timer and registers new one at 18000ms - testClock.advanceBy(Duration.ofMillis(3000L)) // now at 13000ms - tester.test(List(("user1", ("user1", "activity")))) - assert(tester.peekValueState[Long]("timer", "user1").get == 18000L) - - // Advance past original timer time (15000ms) but before new timer (18000ms) - testClock.advanceBy(Duration.ofMillis(3000L)) // now at 16000ms - val result = tester.test(List()) - // Old timer at 15000ms should NOT fire since it was deleted - assert(result.isEmpty) - assert(tester.peekValueState[Long]("activityCount", "user1").get == 2L) // still active - - // Advance to EXACTLY the new timer time (18000ms). - // This tests that timers fire when current time equals expiry time. - testClock.advanceBy(Duration.ofMillis(2000L)) - val result2 = tester.test(List()) - assert(result2.length == 1) - assert(result2(0) == ("user1", "SESSION_TIMEOUT", 2L)) - } - - test("TwsTester should handle EventTime mode with watermark") { - val processor = new EventTimeWindowProcessor(windowDurationMs = 10000L) - val tester = new TwsTester(processor, timeMode = TimeMode.EventTime) - - // Configure watermark: extract timestamp from input, 2 second delay - // In real Spark, the watermark column must be of Timestamp type - tester.withWatermark( - (input: (String, Timestamp)) => input._2, - "2 seconds" - ) - - // Batch 1: Process events with timestamps 10000, 12000, 15000 - // Max event time = 15000, watermark after batch = 15000 - 2000 = 13000 - val result1 = tester.test( - List( - ("user1", ("event1", new Timestamp(10000L))), - ("user1", ("event2", new Timestamp(12000L))), - ("user1", ("event3", new Timestamp(15000L))) - ) - ) - // First batch: registers timer at 20000 (10000 + 10000), no timers fire yet - assert(result1 == List(("user1", "WINDOW_START", 3L))) - assert(tester.peekValueState[Long]("eventCount", "user1").get == 3L) - assert(tester.peekValueState[Long]("windowEndTime", "user1").get == 20000L) - - // Batch 2: Process more events with timestamps 18000, 20000 - // Watermark before batch = 13000, so timer at 20000 doesn't fire yet - // Max event time = 20000, watermark after batch = 20000 - 2000 = 18000 - // Timer at 20000 still doesn't fire (watermark < timer) - val result2 = tester.test( - List( - ("user1", ("event4", new Timestamp(18000L))), - ("user1", ("event5", new Timestamp(20000L))) - ) - ) - assert(result2 == List(("user1", "WINDOW_CONTINUE", 5L))) - assert(tester.peekValueState[Long]("eventCount", "user1").get == 5L) - - // Batch 3: Process event with timestamp 23000 - // Watermark before batch = 18000 (for filtering late events) - // Max event time = 23000, watermark after batch = 23000 - 2000 = 21000 - // Timer at 20000 FIRES because updated watermark (21000) > timer (20000) - // Input is processed first, then timer fires - val result3 = tester.test(List(("user1", ("event6", new Timestamp(23000L))))) - // Input is processed first (increments count to 6), then timer fires (outputs final count 6) - assert(result3.length == 2) - assert(result3(0) == ("user1", "WINDOW_CONTINUE", 6L)) // Input processed first - assert(result3(1) == ("user1", "WINDOW_END", 6L)) // Timer fired with count after input - // Timer clears state, so no state should exist after this batch - 
assert(tester.peekValueState[Long]("eventCount", "user1").isEmpty) - assert(tester.peekValueState[Long]("windowEndTime", "user1").isEmpty) - - // Batch 4: Process another event with timestamp 25000 - // State was cleared in batch 3, so this starts a new window - // Watermark before batch = 21000 - // Max event time = 25000, watermark after batch = 25000 - 2000 = 23000 - val result4 = tester.test(List(("user1", ("event7", new Timestamp(25000L))))) - // Since state was cleared, this starts a new window - assert(result4 == List(("user1", "WINDOW_START", 1L))) - assert(tester.peekValueState[Long]("eventCount", "user1").get == 1L) - assert(tester.peekValueState[Long]("windowEndTime", "user1").get == 35000L) // 25000 + 10000 - } - - test("TwsTester should filter late events based on watermark") { - val processor = new EventTimeWindowProcessor(windowDurationMs = 10000L) - val tester = new TwsTester(processor, timeMode = TimeMode.EventTime) - - // Configure watermark with 2 second delay - tester.withWatermark( - (input: (String, Timestamp)) => input._2, - "2 seconds" - ) - - // Batch 1: Process events with timestamps 10000, 12000, 15000 - // Max event time = 15000, watermark after batch = 15000 - 2000 = 13000 - val result1 = tester.test( - List( - ("user1", ("event1", new Timestamp(10000L))), - ("user1", ("event2", new Timestamp(12000L))), - ("user1", ("event3", new Timestamp(15000L))) - ) - ) - assert(result1 == List(("user1", "WINDOW_START", 3L))) - - // Batch 2: Send mix of late and on-time events - // Watermark is currently 13000 - // Event at 11000 is late (< 13000) and should be filtered - // Event at 14000 is on-time (>= 13000) and should be processed - val result2 = tester.test( - List( - ("user1", ("late_event", new Timestamp(11000L))), // LATE - should be filtered - ("user1", ("on_time_event", new Timestamp(14000L))) // ON-TIME - should be processed - ) - ) - - // Only the on-time event should be processed, so count goes from 3 to 4 - assert(result2 == List(("user1", "WINDOW_CONTINUE", 4L))) - assert(tester.peekValueState[Long]("eventCount", "user1").get == 4L) - } - test("TwsTester should call handleInitialState") { val processor = new RunningCountProcessor[String]() val tester = new TwsTester(processor, initialState = List(("a", 10L), ("b", 20L))) @@ -564,9 +239,9 @@ class TwsTesterFuzzTestSuite extends StreamTest { override protected val enableAutoThreadAudit = false /** - * Asserts that {@code tester} is equivalent to streaming query transforming {@code inputStream} - * to {@code result}, when both are fed with data from {@code batches}. - */ + * Asserts that {@code tester} is equivalent to streaming query transforming {@code inputStream} + * to {@code result}, when both are fed with data from {@code batches}. + */ def checkTwsTester[ K: org.apache.spark.sql.Encoder, I: org.apache.spark.sql.Encoder, @@ -581,7 +256,7 @@ class TwsTesterFuzzTestSuite extends StreamTest { ) { val expectedResults: List[List[O]] = batches.map(batch => tester.test(batch)).toList assert(batches.size == expectedResults.size) - + val actions: Seq[StreamAction] = (batches zip expectedResults).flatMap { case (batch, expected) => Seq( @@ -596,7 +271,7 @@ class TwsTesterFuzzTestSuite extends StreamTest { /** * Asserts that {@code tester} processes given {@code input} in the same way as Spark streaming * query with {@code transformWithState} would. - * + * * This is simplified version of {@code checkTwsTester} for the case where there is only one batch * and no timers (time mode is TimeMode.None). 
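+   * For example (sketch only; mirrors the single-batch fuzz tests below,
+   * with arbitrary keys and words):
+   * {{{
+   *   checkTwsTesterOneBatch(new WordFrequencyProcessor(),
+   *     List(("k1", ("k1", "spark")), ("k2", ("k2", "spark"))))
+   * }}}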
*/ @@ -649,54 +324,4 @@ class TwsTesterFuzzTestSuite extends StreamTest { val processor = new WordFrequencyProcessor() checkTwsTesterOneBatch(processor, input) } - - test("fuzz test with EventTimeWindowProcessor") { - val inputStream = MemoryStream[(String, (String, Timestamp))] - val processor = new EventTimeWindowProcessor(windowDurationMs = 10000L) - val result = inputStream - .toDS() - .select($"_1", $"_2", $"_2._2".as("timestamp")) - .withWatermark("timestamp", "2 seconds") - .as[(String, (String, Timestamp), Timestamp)] - .groupByKey(_._1) - .mapValues(_._2) - .transformWithState(processor, TimeMode.EventTime(), OutputMode.Append()) - - // Generate 10 random batches, each with ~100 events, 10 users, with timestamps increasing. - val random = new scala.util.Random(0) - val numBatches = 10 - val numUsers = 10 - val eventsPerBatch = 100 - val eventGapMs = 1000L - val userIds = (1 to numUsers).map(i => s"user$i").toArray - - var currentTimestamp = 0L - val batches: List[List[(String, (String, Timestamp))]] = - (0 until numBatches).map { batchIdx => - val batchStart = currentTimestamp - val usersInThisBatch = random.shuffle(userIds.toList) - // Assign variable number of events to each user in this batch - val perUserEvents = Array.fill(numUsers)(random.nextInt(5) + 5) // 5-9 events per user - val events = - usersInThisBatch.zip(perUserEvents).flatMap { - case (user, numEventsForUser) => - (0 until numEventsForUser).map { evtIdx => - // Make timestamp within this batch but always increasing overall - val ts = new Timestamp(currentTimestamp) - val evtId = s"event${batchIdx}_${user}_$evtIdx" - currentTimestamp += eventGapMs - (user, (evtId, ts)) - } - } - random - .shuffle(events) - .take(eventsPerBatch) - .toList // Shuffle and trim to ~100 events per batch - }.toList - - val tester = new TwsTester(processor, timeMode = TimeMode.EventTime()) - tester.withWatermark(_._2, "2 seconds") - - checkTwsTester(tester, batches, inputStream, result) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/EventTimeWindowProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/EventTimeWindowProcessor.scala deleted file mode 100644 index 717b6747abb6..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/EventTimeWindowProcessor.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.streaming.processors - -import java.sql.Timestamp - -import org.apache.spark.sql.Encoders -import org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig, ValueState} - -/** - * Event time window processor that demonstrates event time timer usage. 
- * - * Input: (eventId, eventTime) as (String, Timestamp) - * Output: (userId, status, count) as (String, String, Long) - * - * Behavior: - * - Tracks event count per user window - * - On first event, registers a timer for windowDurationMs from the first event time - * - Accumulates events in the window - * - When timer expires (based on watermark), emits window summary and starts new window - */ -class EventTimeWindowProcessor(val windowDurationMs: Long = 10000L) - extends StatefulProcessor[String, (String, Timestamp), (String, String, Long)] { - - @transient private var eventCountState: ValueState[Long] = _ - @transient private var windowEndTimeState: ValueState[Long] = _ - - override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { - eventCountState = - getHandle.getValueState[Long]("eventCount", Encoders.scalaLong, TTLConfig.NONE) - windowEndTimeState = - getHandle.getValueState[Long]("windowEndTime", Encoders.scalaLong, TTLConfig.NONE) - } - - override def handleInputRows( - key: String, - inputRows: Iterator[(String, Timestamp)], - timerValues: TimerValues - ): Iterator[(String, String, Long)] = { - val events = inputRows.toList - val currentCount = if (eventCountState.exists()) eventCountState.get() else 0L - val newCount = currentCount + events.size - eventCountState.update(newCount) - - // If this is the first event in a window, register timer - if (!windowEndTimeState.exists()) { - val firstEventTime = events.head._2.getTime - val windowEnd = firstEventTime + windowDurationMs - getHandle.registerTimer(windowEnd) - windowEndTimeState.update(windowEnd) - Iterator.single((key, "WINDOW_START", newCount)) - } else { - Iterator.single((key, "WINDOW_CONTINUE", newCount)) - } - } - - override def handleExpiredTimer( - key: String, - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo - ): Iterator[(String, String, Long)] = { - val count = if (eventCountState.exists()) eventCountState.get() else 0L - - // Clear window state - eventCountState.clear() - windowEndTimeState.clear() - - Iterator.single((key, "WINDOW_END", count)) - } -} - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/MultiTimerProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/MultiTimerProcessor.scala deleted file mode 100644 index b9d3eb04e46d..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/MultiTimerProcessor.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.streaming.processors - -import org.apache.spark.sql.Encoders -import org.apache.spark.sql.streaming.{ExpiredTimerInfo, MapState, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig} - -/** - * Multi-timer processor that registers multiple timers with different delays. - * - * Input: (userId, command) as (String, String) - * Output: (userId, timerType, timestamp) as (String, String, Long) - * - * On first input, registers three timers: SHORT (1s), MEDIUM (3s), LONG (5s) - * When each timer expires, emits the timer type and timestamp - */ -class MultiTimerProcessor - extends StatefulProcessor[String, (String, String), (String, String, Long)] { - - @transient private var timerTypeMapState: MapState[Long, String] = _ - - override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { - timerTypeMapState = getHandle.getMapState[Long, String]( - "timerTypeMap", - Encoders.scalaLong, - Encoders.STRING, - TTLConfig.NONE - ) - } - - override def handleInputRows( - key: String, - inputRows: Iterator[(String, String)], - timerValues: TimerValues - ): Iterator[(String, String, Long)] = { - inputRows.size // consume iterator - - val baseTime = timerValues.getCurrentProcessingTimeInMs() - val shortTimer = baseTime + 1000L - val mediumTimer = baseTime + 3000L - val longTimer = baseTime + 5000L - - getHandle.registerTimer(shortTimer) // SHORT - 1 second - getHandle.registerTimer(mediumTimer) // MEDIUM - 3 seconds - getHandle.registerTimer(longTimer) // LONG - 5 seconds - - timerTypeMapState.updateValue(shortTimer, "SHORT") - timerTypeMapState.updateValue(mediumTimer, "MEDIUM") - timerTypeMapState.updateValue(longTimer, "LONG") - - Iterator.empty - } - - override def handleExpiredTimer( - key: String, - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo - ): Iterator[(String, String, Long)] = { - val expiryTime = expiredTimerInfo.getExpiryTimeInMs() - val timerType = timerTypeMapState.getValue(expiryTime) - - // Clean up the timer type from the map - timerTypeMapState.removeKey(expiryTime) - - Iterator.single((key, timerType, expiryTime)) - } -} - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/SessionTimeoutProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/SessionTimeoutProcessor.scala deleted file mode 100644 index 06f79720f9f2..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/SessionTimeoutProcessor.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.streaming.processors - -import org.apache.spark.sql.Encoders -import org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig, ValueState} - -/** - * Session timeout processor that demonstrates timer usage. - * - * Input: (userId, activityType) as (String, String) - * Output: (userId, event, count) as (String, String, Long) - * - * Behavior: - * - Tracks activity count per user session - * - Registers a timeout timer on first activity (5 seconds from current time) - * - Updates timer on each new activity (resets 5-second countdown) - * - When timer expires, emits ("userId", "SESSION_TIMEOUT", activityCount) and clears state - * - Regular activities emit ("userId", "ACTIVITY", currentCount) - */ -class SessionTimeoutProcessor(val timeoutDurationMs: Long = 5000L) - extends StatefulProcessor[String, (String, String), (String, String, Long)] { - - @transient private var activityCountState: ValueState[Long] = _ - @transient private var timerState: ValueState[Long] = _ - - override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { - activityCountState = - getHandle.getValueState[Long]("activityCount", Encoders.scalaLong, TTLConfig.NONE) - timerState = getHandle.getValueState[Long]("timer", Encoders.scalaLong, TTLConfig.NONE) - } - - override def handleInputRows( - key: String, - inputRows: Iterator[(String, String)], - timerValues: TimerValues - ): Iterator[(String, String, Long)] = { - val currentCount = if (activityCountState.exists()) activityCountState.get() else 0L - val newCount = currentCount + inputRows.size - activityCountState.update(newCount) - - // Delete old timer if exists and register new one - if (timerState.exists()) { - getHandle.deleteTimer(timerState.get()) - } - - val newTimerMs = timerValues.getCurrentProcessingTimeInMs() + timeoutDurationMs - getHandle.registerTimer(newTimerMs) - timerState.update(newTimerMs) - - Iterator.single((key, "ACTIVITY", newCount)) - } - - override def handleExpiredTimer( - key: String, - timerValues: TimerValues, - expiredTimerInfo: ExpiredTimerInfo - ): Iterator[(String, String, Long)] = { - val count = if (activityCountState.exists()) activityCountState.get() else 0L - - // Clear session state - activityCountState.clear() - timerState.clear() - - Iterator.single((key, "SESSION_TIMEOUT", count)) - } -} - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala index 48733cce5e5a..0970b5580db7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/TopKProcessor.scala @@ -40,7 +40,6 @@ class TopKProcessor(k: Int, ttl: TTLConfig = TTLConfig.NONE) // Load existing list into a buffer val current = ArrayBuffer[Double]() topKState.get().foreach(current += _) - println(s"AAA loaded state for key=$key: $current") // Add new values and recompute top-K inputRows.foreach { @@ -48,7 +47,6 @@ class TopKProcessor(k: Int, ttl: TTLConfig = TTLConfig.NONE) current += score } val updatedTopK = current.sorted(Ordering[Double].reverse).take(k) - println(s"AAA updatedTopK for key=$key: $updatedTopK") // Persist back topKState.clear() From 8eda5ad0249e908db735f8f24242cfd505a72de2 Mon Sep 17 00:00:00 2001 From: Dmytro Fedoriaka Date: Fri, 21 Nov 2025 21:53:46 +0000 Subject: [PATCH 3/7] Add tests for all state methods --- 
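Note: the new AllMethodsTestProcessor drives each state API through string
commands, so a single test (and, later, a fuzz test) can assert the whole
state surface. A minimal sketch of the pattern, using the command names and
"command:result" outputs defined in the processor added below (illustrative
only):

    val tester = new TwsTester(new AllMethodsTestProcessor())
    // Each command row yields one ("key", "command:result") tuple.
    assert(tester.test(List(("k", "value-exists"))) == List(("k", "value-exists:false")))
    assert(tester.test(List(("k", "value-set"))) == List(("k", "value-set:done")))
    assert(tester.test(List(("k", "value-exists"))) == List(("k", "value-exists:true")))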
.../InMemoryStatefulProcessorHandleImpl.scala | 171 ++++++------------ .../spark/sql/streaming/TwsTester.scala | 80 +------- .../spark/sql/streaming/TwsTesterSuite.scala | 51 ++++++ .../processors/AllMethodsTestProcessor.scala | 96 ++++++++++ 4 files changed, 216 insertions(+), 182 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala index 3da55a58dd8f..bbec635d413b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala @@ -26,142 +26,100 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.QueryInfoImpl import org.apache.spark.sql.streaming.{ListState, MapState, QueryInfo, StatefulProcessorHandle, TTLConfig, ValueState} +/** In-memory implementation of ValueState. */ class InMemoryValueState[T] extends ValueState[T] { private val keyToStateValue = mutable.Map[Any, T]() - private def getValue: Option[T] = { - keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) - } + override def exists(): Boolean = + keyToStateValue.contains(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) - override def exists(): Boolean = getValue.isDefined - override def get(): T = getValue.getOrElse(null.asInstanceOf[T]) + override def get(): T = + keyToStateValue.getOrElse( + ImplicitGroupingKeyTracker.getImplicitKeyOption.get, + null.asInstanceOf[T] + ) - override def update(newState: T): Unit = { + override def update(newState: T): Unit = keyToStateValue.put(ImplicitGroupingKeyTracker.getImplicitKeyOption.get, newState) - } - override def clear(): Unit = { + override def clear(): Unit = keyToStateValue.remove(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) - } } +/** In-memory implementation of ListState. 
*/ class InMemoryListState[T] extends ListState[T] { private val keyToStateValue = mutable.Map[Any, mutable.ArrayBuffer[T]]() - private def getList: Option[mutable.ArrayBuffer[T]] = { - keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) - } - - override def exists(): Boolean = getList.isDefined - - override def get(): Iterator[T] = { - getList.orElse(Some(mutable.ArrayBuffer.empty[T])).get.iterator - } - - override def put(newState: Array[T]): Unit = { - keyToStateValue.put( - ImplicitGroupingKeyTracker.getImplicitKeyOption.get, - mutable.ArrayBuffer.empty[T] ++ newState - ) - } + override def exists(): Boolean = + keyToStateValue.contains(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) - override def appendValue(newState: T): Unit = { + private def getList: mutable.ArrayBuffer[T] = { if (!exists()) { keyToStateValue.put( ImplicitGroupingKeyTracker.getImplicitKeyOption.get, mutable.ArrayBuffer.empty[T] ) } - keyToStateValue(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) += newState + keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get).get } - override def appendList(newState: Array[T]): Unit = { - if (!exists()) { - keyToStateValue.put( - ImplicitGroupingKeyTracker.getImplicitKeyOption.get, - mutable.ArrayBuffer.empty[T] - ) - } - keyToStateValue(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) ++= newState - } + override def get(): Iterator[T] = getList.iterator + + override def put(newState: Array[T]): Unit = + keyToStateValue.put( + ImplicitGroupingKeyTracker.getImplicitKeyOption.get, + mutable.ArrayBuffer.empty[T] ++ newState + ) + + override def appendValue(newState: T): Unit = getList += newState + + override def appendList(newState: Array[T]): Unit = getList ++= newState - override def clear(): Unit = { + override def clear(): Unit = keyToStateValue.remove(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) - } } +/** In-memory implementation of MapState. 
*/ class InMemoryMapState[K, V] extends MapState[K, V] { private val keyToStateValue = mutable.Map[Any, mutable.HashMap[K, V]]() - private def getMap: Option[mutable.HashMap[K, V]] = { - keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) - } - - override def exists(): Boolean = getMap.isDefined - - override def getValue(key: K): V = { - getMap - .orElse(Some(mutable.HashMap.empty[K, V])) - .get - .getOrElse(key, null.asInstanceOf[V]) - } - - override def containsKey(key: K): Boolean = { - getMap - .orElse(Some(mutable.HashMap.empty[K, V])) - .get - .contains(key) - } + override def exists(): Boolean = + keyToStateValue.contains(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) - override def updateValue(key: K, value: V): Unit = { + private def getMap: mutable.HashMap[K, V] = { if (!exists()) { keyToStateValue.put( ImplicitGroupingKeyTracker.getImplicitKeyOption.get, mutable.HashMap.empty[K, V] ) } - - keyToStateValue(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) += (key -> value) + keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get).get } - override def iterator(): Iterator[(K, V)] = { - getMap - .orElse(Some(mutable.HashMap.empty[K, V])) - .get - .iterator - } + override def getValue(key: K): V = getMap.getOrElse(key, null.asInstanceOf[V]) - override def keys(): Iterator[K] = { - getMap - .orElse(Some(mutable.HashMap.empty[K, V])) - .get - .keys - .iterator - } + override def containsKey(key: K): Boolean = getMap.contains(key) - override def values(): Iterator[V] = { - getMap - .orElse(Some(mutable.HashMap.empty[K, V])) - .get - .values - .iterator - } + override def updateValue(key: K, value: V): Unit = getMap.put(key, value) - override def removeKey(key: K): Unit = { - getMap - .orElse(Some(mutable.HashMap.empty[K, V])) - .get - .remove(key) - } + override def iterator(): Iterator[(K, V)] = getMap.iterator + + override def keys(): Iterator[K] = getMap.keys.iterator + + override def values(): Iterator[V] = getMap.values.iterator - override def clear(): Unit = { + override def removeKey(key: K): Unit = getMap.remove(key) + + override def clear(): Unit = keyToStateValue.remove(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) - } } -class InMemoryStatefulProcessorHandleImpl() - extends StatefulProcessorHandle { - +/** + * In-memory implementation of StatefulProcessorHandle. + * + * Doesn't support timers and TTL. 
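+ *
+ * A sketch of how a test can drive state directly (illustrative; it relies on the
+ * getValueState, setValueState and peekValueState methods defined below, and on the
+ * grouping key being set through ImplicitGroupingKeyTracker):
+ * {{{
+ *   val handle = new InMemoryStatefulProcessorHandleImpl()
+ *   handle.getValueState[Long]("count", Encoders.scalaLong, TTLConfig.NONE)
+ *   ImplicitGroupingKeyTracker.setImplicitKey("user1")
+ *   handle.setValueState[Long]("count", 5L)
+ *   handle.peekValueState[Long]("count")  // Some(5L)
+ * }}}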
+ */ +class InMemoryStatefulProcessorHandleImpl() extends StatefulProcessorHandle { private val states = mutable.Map[String, Any]() override def getValueState[T]( @@ -175,9 +133,8 @@ class InMemoryStatefulProcessorHandleImpl() .asInstanceOf[InMemoryValueState[T]] } - override def getValueState[T: Encoder](stateName: String, ttlConfig: TTLConfig): ValueState[T] = { + override def getValueState[T: Encoder](stateName: String, ttlConfig: TTLConfig): ValueState[T] = getValueState(stateName, implicitly[Encoder[T]], ttlConfig) - } override def getListState[T]( stateName: String, @@ -190,9 +147,8 @@ class InMemoryStatefulProcessorHandleImpl() .asInstanceOf[InMemoryListState[T]] } - override def getListState[T: Encoder](stateName: String, ttlConfig: TTLConfig): ListState[T] = { + override def getListState[T: Encoder](stateName: String, ttlConfig: TTLConfig): ListState[T] = getListState(stateName, implicitly[Encoder[T]], ttlConfig) - } override def getMapState[K, V]( stateName: String, @@ -206,32 +162,24 @@ class InMemoryStatefulProcessorHandleImpl() .asInstanceOf[InMemoryMapState[K, V]] } - override def getMapState[K: Encoder, V: Encoder]( + override def getMapState[K: Encoder, V: Encoder]( stateName: String, - ttlConfig: TTLConfig - ): MapState[K, V] = { + ttlConfig: TTLConfig): MapState[K, V] = getMapState(stateName, implicitly[Encoder[K]], implicitly[Encoder[V]], ttlConfig) - } - override def getQueryInfo(): QueryInfo = { + override def getQueryInfo(): QueryInfo = new QueryInfoImpl(UUID.randomUUID(), UUID.randomUUID(), 0L) - } - override def registerTimer(expiryTimestampMs: Long): Unit = { + override def registerTimer(expiryTimestampMs: Long): Unit = throw new UnsupportedOperationException("Timers are not supported.") - } - override def deleteTimer(expiryTimestampMs: Long): Unit = { + override def deleteTimer(expiryTimestampMs: Long): Unit = throw new UnsupportedOperationException("Timers are not supported.") - } - override def listTimers(): Iterator[Long] = { + override def listTimers(): Iterator[Long] = throw new UnsupportedOperationException("Timers are not supported.") - } - override def deleteIfExists(stateName: String): Unit = { - states.remove(stateName) - } + override def deleteIfExists(stateName: String): Unit = states.remove(stateName) def setValueState[T](stateName: String, value: T): Unit = { require(states.contains(stateName), s"State $stateName has not been initialized.") @@ -240,8 +188,7 @@ class InMemoryStatefulProcessorHandleImpl() def peekValueState[T](stateName: String): Option[T] = { require(states.contains(stateName), s"State $stateName has not been initialized.") - val value: T = states(stateName).asInstanceOf[InMemoryValueState[T]].get() - Option(value) + Option(states(stateName).asInstanceOf[InMemoryValueState[T]].get()) } def setListState[T](stateName: String, value: List[T])(implicit ct: ClassTag[T]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala index 878ea4dc092f..1007ffcb1f1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala @@ -74,13 +74,7 @@ class TwsTester[K, I, O]( ans } - /** - * Convenience method to process a single input row for a given key. 
- * - * @param key the grouping key - * @param inputRow the input row to process - * @return all output rows produced by the processor - */ + /** Convenience method to process a single input row for a given key. */ def testOneRow(key: K, inputRow: I): List[O] = test(List((key, inputRow))) /** @@ -120,93 +114,39 @@ class TwsTester[K, I, O]( (outputRows, peekValueState[S](stateName, key).get) } - /** - * Sets the value state for a given key. - * - * @param stateName the name of the value state variable - * @param key the grouping key - * @param value the value to set - * @tparam T the type of the state value - */ + /** Sets the value state for a given key. */ def setValueState[T](stateName: String, key: K, value: T): Unit = { ImplicitGroupingKeyTracker.setImplicitKey(key) handle.setValueState[T](stateName, value) - ImplicitGroupingKeyTracker.removeImplicitKey() } - /** - * Retrieves the value state for a given key without modifying it. - * - * @param stateName the name of the value state variable - * @param key the grouping key - * @tparam T the type of the state value - * @return Some(value) if state exists for the key, None otherwise - */ + /** Retrieves the value state for a given key. */ def peekValueState[T](stateName: String, key: K): Option[T] = { ImplicitGroupingKeyTracker.setImplicitKey(key) - val result: Option[T] = handle.peekValueState[T](stateName) - ImplicitGroupingKeyTracker.removeImplicitKey() - return result + handle.peekValueState[T](stateName) } - /** - * Sets the list state for a given key. - * - * @param stateName the name of the list state variable - * @param key the grouping key - * @param value the list of values to set - * @param ct implicit class tag for type T - * @tparam T the type of elements in the list state - */ + /** Sets the list state for a given key. */ def setListState[T](stateName: String, key: K, value: List[T])(implicit ct: ClassTag[T]): Unit = { ImplicitGroupingKeyTracker.setImplicitKey(key) handle.setListState[T](stateName, value) - ImplicitGroupingKeyTracker.removeImplicitKey() } - /** - * Retrieves the list state for a given key without modifying it. - * - * @param stateName the name of the list state variable - * @param key the grouping key - * @tparam T the type of elements in the list state - * @return the list of values, or an empty list if no state exists for the key - */ + /** Retrieves the list state for a given key. */ def peekListState[T](stateName: String, key: K): List[T] = { ImplicitGroupingKeyTracker.setImplicitKey(key) - val result: List[T] = handle.peekListState[T](stateName) - ImplicitGroupingKeyTracker.removeImplicitKey() - return result + handle.peekListState[T](stateName) } - /** - * Sets the map state for a given key. - * - * @param stateName the name of the map state variable - * @param key the grouping key - * @param value the map of key-value pairs to set - * @tparam MK the type of keys in the map state - * @tparam MV the type of values in the map state - */ + /** Sets the map state for a given key. */ def setMapState[MK, MV](stateName: String, key: K, value: Map[MK, MV]): Unit = { ImplicitGroupingKeyTracker.setImplicitKey(key) handle.setMapState[MK, MV](stateName, value) - ImplicitGroupingKeyTracker.removeImplicitKey() } - /** - * Retrieves the map state for a given key without modifying it. 
- * - * @param stateName the name of the map state variable - * @param key the grouping key - * @tparam MK the type of keys in the map state - * @tparam MV the type of values in the map state - * @return the map of key-value pairs, or an empty map if no state exists for the key - */ + /** Retrieves the map state for a given key. */ def peekMapState[MK, MV](stateName: String, key: K): Map[MK, MV] = { ImplicitGroupingKeyTracker.setImplicitKey(key) - val result: Map[MK, MV] = handle.peekMapState[MK, MV](stateName) - ImplicitGroupingKeyTracker.removeImplicitKey() - return result + handle.peekMapState[MK, MV](stateName) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala index 4b5c8a71b63c..75fc5e4b0dfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala @@ -222,6 +222,57 @@ class TwsTesterSuite extends SparkFunSuite { ) ) } + + test("TwsTester should exercise all state methods") { + val tester = new TwsTester(new AllMethodsTestProcessor()) + val results = tester.test(List( + ("k", "value-exists"), // false + ("k", "value-set"), // set to 42 + ("k", "value-exists"), // true + ("k", "value-clear"), // clear + ("k", "value-exists"), // false again + ("k", "list-exists"), // false + ("k", "list-append"), // append a, b + ("k", "list-exists"), // true + ("k", "list-append-array"), // append c, d + ("k", "list-get"), // a,b,c,d + ("k", "map-exists"), // false + ("k", "map-add"), // add x=1, y=2, z=3 + ("k", "map-exists"), // true + ("k", "map-keys"), // x,y,z + ("k", "map-values"), // 1,2,3 + ("k", "map-iterator"), // x=1,y=2,z=3 + ("k", "map-remove"), // remove y + ("k", "map-keys"), // x,z + ("k", "map-clear"), // clear map + ("k", "map-exists"), // false + ("k", "delete-state") // delete value state + )) + + assert(results == List( + ("k", "value-exists:false"), + ("k", "value-set:done"), + ("k", "value-exists:true"), + ("k", "value-clear:done"), + ("k", "value-exists:false"), + ("k", "list-exists:false"), + ("k", "list-append:done"), + ("k", "list-exists:true"), + ("k", "list-append-array:done"), + ("k", "list-get:a,b,c,d"), + ("k", "map-exists:false"), + ("k", "map-add:done"), + ("k", "map-exists:true"), + ("k", "map-keys:x,y,z"), + ("k", "map-values:1,2,3"), + ("k", "map-iterator:x=1,y=2,z=3"), + ("k", "map-remove:done"), + ("k", "map-keys:x,z"), + ("k", "map-clear:done"), + ("k", "map-exists:false"), + ("k", "delete-state:done") + )) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala new file mode 100644 index 000000000000..69acbcf03763 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.streaming.processors + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.streaming.{ListState, MapState, OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig, ValueState} + +/** Test processor that exercises all state methods for coverage testing. */ +class AllMethodsTestProcessor extends StatefulProcessor[String, String, (String, String)] { + + @transient private var valueState: ValueState[Int] = _ + @transient private var listState: ListState[String] = _ + @transient private var mapState: MapState[String, Int] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + valueState = getHandle.getValueState[Int]("value", Encoders.scalaInt, TTLConfig.NONE) + listState = getHandle.getListState[String]("list", Encoders.STRING, TTLConfig.NONE) + mapState = + getHandle.getMapState[String, Int]("map", Encoders.STRING, Encoders.scalaInt, TTLConfig.NONE) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues + ): Iterator[(String, String)] = { + val results = scala.collection.mutable.ArrayBuffer[(String, String)]() + + inputRows.foreach { cmd => + cmd match { + case "value-exists" => + results += ((key, s"value-exists:${valueState.exists()}")) + case "value-set" => + valueState.update(42) + results += ((key, "value-set:done")) + case "value-clear" => + valueState.clear() + results += ((key, "value-clear:done")) + case "list-exists" => + results += ((key, s"list-exists:${listState.exists()}")) + case "list-append" => + listState.appendValue("a") + listState.appendValue("b") + results += ((key, "list-append:done")) + case "list-append-array" => + listState.appendList(Array("c", "d")) + results += ((key, "list-append-array:done")) + case "list-get" => + val items = listState.get().toList.mkString(",") + results += ((key, s"list-get:$items")) + case "map-exists" => + results += ((key, s"map-exists:${mapState.exists()}")) + case "map-add" => + mapState.updateValue("x", 1) + mapState.updateValue("y", 2) + mapState.updateValue("z", 3) + results += ((key, "map-add:done")) + case "map-keys" => + val keys = mapState.keys().toList.sorted.mkString(",") + results += ((key, s"map-keys:$keys")) + case "map-values" => + val values = mapState.values().toList.sorted.mkString(",") + results += ((key, s"map-values:$values")) + case "map-iterator" => + val pairs = + mapState.iterator().toList.sortBy(_._1).map(p => s"${p._1}=${p._2}").mkString(",") + results += ((key, s"map-iterator:$pairs")) + case "map-remove" => + mapState.removeKey("y") + results += ((key, "map-remove:done")) + case "map-clear" => + mapState.clear() + results += ((key, "map-clear:done")) + case "delete-state" => + getHandle.deleteIfExists("value") + results += ((key, "delete-state:done")) + } + } + + results.iterator + } +} From 25443be5604664bfeb5a16868321978a3666db30 Mon Sep 17 00:00:00 2001 From: Dmytro Fedoriaka Date: Fri, 21 Nov 2025 22:50:51 +0000 Subject: [PATCH 4/7] Add fuzz test for all state methods --- .../InMemoryStatefulProcessorHandleImpl.scala | 28 ++++++++++++------- 
.../spark/sql/streaming/TwsTesterSuite.scala | 23 +++++++++++---- .../processors/AllMethodsTestProcessor.scala | 7 +++-- 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala index bbec635d413b..bdf025b47c73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala @@ -63,7 +63,8 @@ class InMemoryListState[T] extends ListState[T] { keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get).get } - override def get(): Iterator[T] = getList.iterator + override def get(): Iterator[T] = + if (exists()) getList.iterator else Iterator.empty override def put(newState: Array[T]): Unit = keyToStateValue.put( @@ -93,22 +94,29 @@ class InMemoryMapState[K, V] extends MapState[K, V] { mutable.HashMap.empty[K, V] ) } - keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get).get + keyToStateValue(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) + } + + private def getMapIfExists: Option[mutable.HashMap[K, V]] = { + keyToStateValue.get(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) } - override def getValue(key: K): V = getMap.getOrElse(key, null.asInstanceOf[V]) + override def getValue(key: K): V = + getMapIfExists.flatMap(_.get(key)).getOrElse(null.asInstanceOf[V]) - override def containsKey(key: K): Boolean = getMap.contains(key) + override def containsKey(key: K): Boolean = getMapIfExists.exists(_.contains(key)) override def updateValue(key: K, value: V): Unit = getMap.put(key, value) - override def iterator(): Iterator[(K, V)] = getMap.iterator + override def iterator(): Iterator[(K, V)] = + getMapIfExists.map(_.iterator).getOrElse(Iterator.empty) - override def keys(): Iterator[K] = getMap.keys.iterator + override def keys(): Iterator[K] = getMapIfExists.map(_.keys.iterator).getOrElse(Iterator.empty) - override def values(): Iterator[V] = getMap.values.iterator + override def values(): Iterator[V] = + getMapIfExists.map(_.values.iterator).getOrElse(Iterator.empty) - override def removeKey(key: K): Unit = getMap.remove(key) + override def removeKey(key: K): Unit = getMapIfExists.foreach(_.remove(key)) override def clear(): Unit = keyToStateValue.remove(ImplicitGroupingKeyTracker.getImplicitKeyOption.get) @@ -116,8 +124,8 @@ class InMemoryMapState[K, V] extends MapState[K, V] { /** * In-memory implementation of StatefulProcessorHandle. - * - * Doesn't support timers and TTL. + * + * Doesn't support timers and TTL. Support directly accessing state. 
*/ class InMemoryStatefulProcessorHandleImpl() extends StatefulProcessorHandle { private val states = mutable.Map[String, Any]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala index 75fc5e4b0dfe..07b6f6a0b89c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala @@ -226,7 +226,7 @@ class TwsTesterSuite extends SparkFunSuite { test("TwsTester should exercise all state methods") { val tester = new TwsTester(new AllMethodsTestProcessor()) val results = tester.test(List( - ("k", "value-exists"), // false + ("k", "value-exists"), // false ("k", "value-set"), // set to 42 ("k", "value-exists"), // true ("k", "value-clear"), // clear @@ -245,8 +245,7 @@ class TwsTesterSuite extends SparkFunSuite { ("k", "map-remove"), // remove y ("k", "map-keys"), // x,z ("k", "map-clear"), // clear map - ("k", "map-exists"), // false - ("k", "delete-state") // delete value state + ("k", "map-exists") // false )) assert(results == List( @@ -269,8 +268,7 @@ class TwsTesterSuite extends SparkFunSuite { ("k", "map-remove:done"), ("k", "map-keys:x,z"), ("k", "map-clear:done"), - ("k", "map-exists:false"), - ("k", "delete-state:done") + ("k", "map-exists:false") )) } } @@ -375,4 +373,19 @@ class TwsTesterFuzzTestSuite extends StreamTest { val processor = new WordFrequencyProcessor() checkTwsTesterOneBatch(processor, input) } + + test("fuzz test for AllMethodsTestProcessor") { + val random = new scala.util.Random(0) + val commands = Array( + "value-exists", "value-set", "value-clear", + "list-exists", "list-append", "list-append-array", "list-get", + "map-exists", "map-add", "map-keys", "map-values", "map-iterator", + "map-remove", "map-clear" + ) + val input = List.fill(500) { + (s"key${random.nextInt(5)}", commands(random.nextInt(commands.length))) + } + val processor = new AllMethodsTestProcessor() + checkTwsTesterOneBatch(processor, input) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala index 69acbcf03763..9c0424cd68ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala @@ -27,6 +27,7 @@ class AllMethodsTestProcessor extends StatefulProcessor[String, String, (String, @transient private var mapState: MapState[String, Int] = _ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + getHandle.deleteIfExists("value") valueState = getHandle.getValueState[Int]("value", Encoders.scalaInt, TTLConfig.NONE) listState = getHandle.getListState[String]("list", Encoders.STRING, TTLConfig.NONE) mapState = @@ -85,9 +86,9 @@ class AllMethodsTestProcessor extends StatefulProcessor[String, String, (String, case "map-clear" => mapState.clear() results += ((key, "map-clear:done")) - case "delete-state" => - getHandle.deleteIfExists("value") - results += ((key, "delete-state:done")) + //case "delete-state" => + // getHandle.deleteIfExists("value") + // results += ((key, "delete-state:done")) } } From 3eb74ac6525d000bd3ab0cf65bd491974cdda754 Mon Sep 17 00:00:00 2001 From: Dmytro Fedoriaka Date: Fri, 21 Nov 2025 23:55:40 +0000 Subject: [PATCH 5/7] Remove redundant 
public methods --- .../spark/sql/streaming/TwsTester.scala | 40 ---- .../spark/sql/streaming/TwsTesterSuite.scala | 182 ++++++++++-------- .../processors/AllMethodsTestProcessor.scala | 4 - 3 files changed, 107 insertions(+), 119 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala index 1007ffcb1f1c..94800c87ab3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala @@ -74,46 +74,6 @@ class TwsTester[K, I, O]( ans } - /** Convenience method to process a single input row for a given key. */ - def testOneRow(key: K, inputRow: I): List[O] = test(List((key, inputRow))) - - /** - * Processes input rows through the stateful processor, one by one. - * - * This corresponds to running streaming query in real-time mode. {@code handleInputRows} will be - * called once for each row in {@code input}. - * - * @param input list of (key, input row) tuples to process - * @return all output rows produced by the processor - */ - def testRowByRow(input: List[(K, I)]): List[O] = { - var ans: List[O] = List() - for (row <- input) { - ans ++= test(List(row)) - } - ans - } - - /** - * Tests how value state is changed after processing one row. - * - * @param key the grouping key - * @param inputRow the input row to process - * @param stateName the name os value state - * @param stateIn the old value of the value state - * @tparam S the type of value state - * @return output rows produced by the processor and new value of the value state - */ - def testOneRowWithValueState[S]( - key: K, - inputRow: I, - stateName: String, - stateIn: S): (List[O], S) = { - setValueState[S](stateName, key, stateIn) - val outputRows = testOneRow(key, inputRow) - (outputRows, peekValueState[S](stateName, key).get) - } - /** Sets the value state for a given key. 
*/
   def setValueState[T](stateName: String, key: K, value: T): Unit = {
     ImplicitGroupingKeyTracker.setImplicitKey(key)
     handle.setValueState[T](stateName, value)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala
index 07b6f6a0b89c..37f2ec270784 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TwsTesterSuite.scala
@@ -45,7 +45,7 @@ class TwsTesterSuite extends SparkFunSuite {
     assert(tester.peekValueState[Long]("count", "key3").get == 1L)
     assert(tester.peekValueState[Long]("count", "key4").isEmpty)
 
-    val ans2 = tester.testOneRow("key1", "q")
+    val ans2 = tester.test(List(("key1", "q")))
     assert(ans2 == List(("key1", 5L)))
     assert(tester.peekValueState[Long]("count", "key1").get == 5L)
     assert(tester.peekValueState[Long]("count", "key2").get == 2L)
@@ -98,9 +98,9 @@ class TwsTesterSuite extends SparkFunSuite {
     val tester = new TwsTester(new TopKProcessor(2))
     tester.setListState("topK", "a", List(6.0, 5.0))
     tester.setListState("topK", "b", List(8.0, 7.0))
-    tester.testOneRow("a", ("", 10.0))
-    tester.testOneRow("b", ("", 7.5))
-    tester.testOneRow("c", ("", 1.0))
+    tester.test(List(("a", ("", 10.0))))
+    tester.test(List(("b", ("", 7.5))))
+    tester.test(List(("c", ("", 1.0))))
 
     assert(tester.peekListState[Double]("topK", "a") == List(10.0, 6.0))
     assert(tester.peekListState[Double]("topK", "b") == List(8.0, 7.5))
@@ -161,10 +161,10 @@ class TwsTesterSuite extends SparkFunSuite {
     tester.setMapState("frequencies", "user2", Map("spark" -> 10L))
 
     // Process new words
-    tester.testOneRow("user1", ("", "hello"))
-    tester.testOneRow("user1", ("", "goodbye"))
-    tester.testOneRow("user2", ("", "spark"))
-    tester.testOneRow("user3", ("", "new"))
+    tester.test(List(("user1", ("", "hello"))))
+    tester.test(List(("user1", ("", "goodbye"))))
+    tester.test(List(("user2", ("", "spark"))))
+    tester.test(List(("user3", ("", "new"))))
 
     // Verify updated state
     assert(
@@ -179,13 +179,19 @@ class TwsTesterSuite extends SparkFunSuite {
     assert(tester.peekMapState[String, Long]("frequencies", "user4") == Map())
   }
 
-  test("TwsTester should test one row with value state") {
+  test("TwsTester can be used to test a step function") {
     val processor = new RunningCountProcessor[String]()
     val tester = new TwsTester(processor)
 
-    val (rows, newState) = tester.testOneRowWithValueState("key1", "a", "count", 10L)
-    assert(rows == List(("key1", 11L)))
-    assert(newState == 11L)
+    // Example of a helper function that uses TwsTester to inspect how processing a single row
+    // changes state.
+    def testStepFunction(key: String, inputRow: String, stateIn: Long): Long = {
+      tester.setValueState[Long]("count", key, stateIn)
+      tester.test(List((key, inputRow)))
+      tester.peekValueState("count", key).get
+    }
+
+    assert(testStepFunction("key1", "a", 10L) == 11L)
   }
 
   test("TwsTester should call handleInitialState") {
@@ -199,6 +205,14 @@ class TwsTesterSuite extends SparkFunSuite {
   }
 
   test("TwsTester should test RunningCountProcessor row-by-row") {
+    val tester = new TwsTester(new RunningCountProcessor[String]())
+
+    // Example of a helper function that tests how TransformWithState processes rows one by one,
+    // which can be used to simulate real-time mode.
+ def testRowByRow(input: List[(String, String)]): List[(String, Long)] = { + input.flatMap(row => tester.test(List(row))) + } + val input: List[(String, String)] = List( ("key1", "a"), ("key2", "b"), @@ -208,8 +222,7 @@ class TwsTesterSuite extends SparkFunSuite { ("key1", "c"), ("key3", "q") ) - val tester = new TwsTester(new RunningCountProcessor[String]()) - val ans: List[(String, Long)] = tester.testRowByRow(input) + val ans: List[(String, Long)] = testRowByRow(input) assert( ans == List( ("key1", 1L), @@ -225,51 +238,55 @@ class TwsTesterSuite extends SparkFunSuite { test("TwsTester should exercise all state methods") { val tester = new TwsTester(new AllMethodsTestProcessor()) - val results = tester.test(List( - ("k", "value-exists"), // false - ("k", "value-set"), // set to 42 - ("k", "value-exists"), // true - ("k", "value-clear"), // clear - ("k", "value-exists"), // false again - ("k", "list-exists"), // false - ("k", "list-append"), // append a, b - ("k", "list-exists"), // true - ("k", "list-append-array"), // append c, d - ("k", "list-get"), // a,b,c,d - ("k", "map-exists"), // false - ("k", "map-add"), // add x=1, y=2, z=3 - ("k", "map-exists"), // true - ("k", "map-keys"), // x,y,z - ("k", "map-values"), // 1,2,3 - ("k", "map-iterator"), // x=1,y=2,z=3 - ("k", "map-remove"), // remove y - ("k", "map-keys"), // x,z - ("k", "map-clear"), // clear map - ("k", "map-exists") // false - )) - - assert(results == List( - ("k", "value-exists:false"), - ("k", "value-set:done"), - ("k", "value-exists:true"), - ("k", "value-clear:done"), - ("k", "value-exists:false"), - ("k", "list-exists:false"), - ("k", "list-append:done"), - ("k", "list-exists:true"), - ("k", "list-append-array:done"), - ("k", "list-get:a,b,c,d"), - ("k", "map-exists:false"), - ("k", "map-add:done"), - ("k", "map-exists:true"), - ("k", "map-keys:x,y,z"), - ("k", "map-values:1,2,3"), - ("k", "map-iterator:x=1,y=2,z=3"), - ("k", "map-remove:done"), - ("k", "map-keys:x,z"), - ("k", "map-clear:done"), - ("k", "map-exists:false") - )) + val results = tester.test( + List( + ("k", "value-exists"), // false + ("k", "value-set"), // set to 42 + ("k", "value-exists"), // true + ("k", "value-clear"), // clear + ("k", "value-exists"), // false again + ("k", "list-exists"), // false + ("k", "list-append"), // append a, b + ("k", "list-exists"), // true + ("k", "list-append-array"), // append c, d + ("k", "list-get"), // a,b,c,d + ("k", "map-exists"), // false + ("k", "map-add"), // add x=1, y=2, z=3 + ("k", "map-exists"), // true + ("k", "map-keys"), // x,y,z + ("k", "map-values"), // 1,2,3 + ("k", "map-iterator"), // x=1,y=2,z=3 + ("k", "map-remove"), // remove y + ("k", "map-keys"), // x,z + ("k", "map-clear"), // clear map + ("k", "map-exists") // false + ) + ) + + assert( + results == List( + ("k", "value-exists:false"), + ("k", "value-set:done"), + ("k", "value-exists:true"), + ("k", "value-clear:done"), + ("k", "value-exists:false"), + ("k", "list-exists:false"), + ("k", "list-append:done"), + ("k", "list-exists:true"), + ("k", "list-append-array:done"), + ("k", "list-get:a,b,c,d"), + ("k", "map-exists:false"), + ("k", "map-add:done"), + ("k", "map-exists:true"), + ("k", "map-keys:x,y,z"), + ("k", "map-values:1,2,3"), + ("k", "map-iterator:x=1,y=2,z=3"), + ("k", "map-remove:done"), + ("k", "map-keys:x,z"), + ("k", "map-clear:done"), + ("k", "map-exists:false") + ) + ) } } @@ -291,7 +308,7 @@ class TwsTesterFuzzTestSuite extends StreamTest { * Asserts that {@code tester} is equivalent to streaming query transforming {@code 
inputStream}
   * to {@code result}, when both are fed with data from {@code batches}.
   */
-  def checkTwsTester[
+  private def checkTwsTesterEndToEnd[
      K: org.apache.spark.sql.Encoder,
      I: org.apache.spark.sql.Encoder,
      O: org.apache.spark.sql.Encoder](
@@ -318,18 +335,15 @@
   }
 
   /**
-   * Asserts that {@code tester} processes given {@code input} in the same way as Spark streaming
+   * Asserts that {@code tester} processes the given {@code batches} in the same way as a Spark streaming
    * query with {@code transformWithState} would.
-   *
-   * This is simplified version of {@code checkTwsTester} for the case where there is only one batch
-   * and no timers (time mode is TimeMode.None).
    */
-  def checkTwsTesterOneBatch[
+  private def checkTwsTester[
      K: org.apache.spark.sql.Encoder,
      I: org.apache.spark.sql.Encoder,
      O: org.apache.spark.sql.Encoder](
      processor: StatefulProcessor[K, I, O],
-      input: List[(K, I)]): Unit = {
+      batches: List[List[(K, I)]]): Unit = {
    implicit val tupleEncoder = org.apache.spark.sql.Encoders.tuple(
      implicitly[org.apache.spark.sql.Encoder[K]],
      implicitly[org.apache.spark.sql.Encoder[I]]
    )
@@ -340,7 +354,15 @@
      .groupByKey(_._1)
      .mapValues(_._2)
      .transformWithState(processor, TimeMode.None(), OutputMode.Append())
-    checkTwsTester(new TwsTester(processor), List(input), inputStream, result)
+    checkTwsTesterEndToEnd(new TwsTester(processor), batches, inputStream, result)
+  }
+
+  private def split[T](xs: List[T], numParts: Int): List[List[T]] = {
+    require(numParts > 0 && xs.size % numParts == 0)
+    val partSize = xs.size / numParts
+    (0 until numParts).map { i =>
+      xs.slice(i * partSize, (i + 1) * partSize)
+    }.toList
   }
 
   test("fuzz test with RunningCountProcessor") {
@@ -349,7 +371,7 @@
     (s"key${random.nextInt(10)}", random.alphanumeric.take(5).mkString)
     }
     val processor = new RunningCountProcessor[String]()
-    checkTwsTesterOneBatch(processor, input)
+    checkTwsTester(processor, split(input, 2))
   }
 
   test("fuzz test with TopKProcessor") {
@@ -361,7 +383,7 @@
     )
     }
     val processor = new TopKProcessor(5)
-    checkTwsTesterOneBatch(processor, input)
+    checkTwsTester(processor, split(input, 2))
   }
 
   test("fuzz test with WordFrequencyProcessor") {
@@ -371,21 +393,31 @@
     (s"key${random.nextInt(10)}", ("", words(random.nextInt(words.length))))
     }
     val processor = new WordFrequencyProcessor()
-    checkTwsTesterOneBatch(processor, input)
+    checkTwsTester(processor, split(input, 2))
   }
 
   test("fuzz test for AllMethodsTestProcessor") {
     val random = new scala.util.Random(0)
     val commands = Array(
-      "value-exists", "value-set", "value-clear",
-      "list-exists", "list-append", "list-append-array", "list-get",
-      "map-exists", "map-add", "map-keys", "map-values", "map-iterator",
-      "map-remove", "map-clear"
+      "value-exists",
+      "value-set",
+      "value-clear",
+      "list-exists",
+      "list-append",
+      "list-append-array",
+      "list-get",
+      "map-exists",
+      "map-add",
+      "map-keys",
+      "map-values",
+      "map-iterator",
+      "map-remove",
+      "map-clear"
    )
    val input = List.fill(500) {
      (s"key${random.nextInt(5)}", commands(random.nextInt(commands.length)))
    }
    val processor = new AllMethodsTestProcessor()
-    checkTwsTesterOneBatch(processor, input)
+    checkTwsTester(processor, split(input, 2))
  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala index 9c0424cd68ac..ca3c77aa8ded 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/processors/AllMethodsTestProcessor.scala @@ -27,7 +27,6 @@ class AllMethodsTestProcessor extends StatefulProcessor[String, String, (String, @transient private var mapState: MapState[String, Int] = _ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { - getHandle.deleteIfExists("value") valueState = getHandle.getValueState[Int]("value", Encoders.scalaInt, TTLConfig.NONE) listState = getHandle.getListState[String]("list", Encoders.STRING, TTLConfig.NONE) mapState = @@ -86,9 +85,6 @@ class AllMethodsTestProcessor extends StatefulProcessor[String, String, (String, case "map-clear" => mapState.clear() results += ((key, "map-clear:done")) - //case "delete-state" => - // getHandle.deleteIfExists("value") - // results += ((key, "delete-state:done")) } } From a5e1bf6f947ab70f33b94d6f8058768d64dea7eb Mon Sep 17 00:00:00 2001 From: Dmytro Fedoriaka Date: Sat, 22 Nov 2025 00:14:33 +0000 Subject: [PATCH 6/7] rename InMemoryStatefulProcessorHandle --- ...HandleImpl.scala => InMemoryStatefulProcessorHandle.scala} | 4 ++-- .../main/scala/org/apache/spark/sql/streaming/TwsTester.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/{InMemoryStatefulProcessorHandleImpl.scala => InMemoryStatefulProcessorHandle.scala} (98%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandle.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandle.scala index bdf025b47c73..723b51e11f0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandle.scala @@ -125,9 +125,9 @@ class InMemoryMapState[K, V] extends MapState[K, V] { /** * In-memory implementation of StatefulProcessorHandle. * - * Doesn't support timers and TTL. Support directly accessing state. + * Doesn't support timers and TTL. Supports directly accessing state. 
*/ -class InMemoryStatefulProcessorHandleImpl() extends StatefulProcessorHandle { +class InMemoryStatefulProcessorHandle() extends StatefulProcessorHandle { private val states = mutable.Map[String, Any]() override def getValueState[T]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala index 94800c87ab3e..8b692e59d23d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import scala.reflect.ClassTag import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.ImplicitGroupingKeyTracker -import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.testing.InMemoryStatefulProcessorHandleImpl +import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.testing.InMemoryStatefulProcessorHandle /** * Testing utility for transformWithState stateful processors. Provides in-memory state management @@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith class TwsTester[K, I, O]( val processor: StatefulProcessor[K, I, O], val initialState: List[(K, Any)] = List()) { - private val handle = new InMemoryStatefulProcessorHandleImpl() + private val handle = new InMemoryStatefulProcessorHandle() processor.setHandle(handle) processor.init(OutputMode.Append, TimeMode.None) processor match { From 8da18fd95fc8eed5135c05790f30edb034c1e639 Mon Sep 17 00:00:00 2001 From: Dmytro Fedoriaka Date: Sat, 22 Nov 2025 03:25:05 +0000 Subject: [PATCH 7/7] documentation --- .../InMemoryStatefulProcessorHandle.scala | 3 -- .../spark/sql/streaming/TwsTester.scala | 33 +++++++++++++++---- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandle.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandle.scala index 723b51e11f0f..5088865c7497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandle.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/testing/InMemoryStatefulProcessorHandle.scala @@ -135,7 +135,6 @@ class InMemoryStatefulProcessorHandle() extends StatefulProcessorHandle { valEncoder: Encoder[T], ttlConfig: TTLConfig ): ValueState[T] = { - require(!states.contains(stateName), s"State $stateName already defined.") states .getOrElseUpdate(stateName, new InMemoryValueState[T]()) .asInstanceOf[InMemoryValueState[T]] @@ -149,7 +148,6 @@ class InMemoryStatefulProcessorHandle() extends StatefulProcessorHandle { valEncoder: Encoder[T], ttlConfig: TTLConfig ): ListState[T] = { - require(!states.contains(stateName), s"State $stateName already defined.") states .getOrElseUpdate(stateName, new InMemoryListState[T]()) .asInstanceOf[InMemoryListState[T]] @@ -164,7 +162,6 @@ class InMemoryStatefulProcessorHandle() extends StatefulProcessorHandle { valEncoder: Encoder[V], ttlConfig: TTLConfig ): MapState[K, V] = { - require(!states.contains(stateName), s"State $stateName already defined.") states .getOrElseUpdate(stateName, new 
InMemoryMapState[K, V]())
      .asInstanceOf[InMemoryMapState[K, V]]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala
index 8b692e59d23d..7bd64910a471 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala
@@ -22,15 +22,31 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.testing.InMemoryStatefulProcessorHandle
 
 /**
- * Testing utility for transformWithState stateful processors. Provides in-memory state management
- * and simplified input processing for unit testing StatefulProcessor implementations.
+ * Testing utility for transformWithState stateful processors.
+ *
+ * This class enables unit testing of StatefulProcessor business logic by simulating the
+ * behavior of transformWithState. It processes input rows and returns output rows equivalent
+ * to those that the processor would produce in an actual Spark streaming query.
+ *
+ * '''Supported:'''
+ * - Processing input rows and producing output rows via `test()`.
+ * - Initial state setup via the constructor parameter.
+ * - Direct state manipulation via `setValueState`, `setListState`, `setMapState`.
+ * - Direct state inspection via `peekValueState`, `peekListState`, `peekMapState`.
+ *
+ * '''Not Supported:'''
+ * - '''Timers''': Only TimeMode.None is supported. If the processor attempts to register,
+ * delete, or list timers (as it would under TimeMode.EventTime or TimeMode.ProcessingTime),
+ * the in-memory handle throws an UnsupportedOperationException.
+ * - '''TTL''': State TTL configurations are ignored. All state persists indefinitely.
+ *
+ * '''Use Cases:'''
+ * - '''Primary''': Unit testing business logic in `handleInputRows` implementations.
+ * - '''Not recommended''': End-to-end testing or performance testing; use actual Spark
+ * streaming queries for those scenarios.
 *
 * @param processor the StatefulProcessor to test
- * @param clock the clock to use for time-based operations, defaults to system UTC
- * @param timeMode time mode that will be passed to transformWithState (defaults to TimeMode.None)
- * @param outputMode output mode that will be passed to transformWithState (defaults to
- *   OutputMode.Append)
- * @param initialState initial state for each key
+ * @param initialState initial state for each key as a list of (key, state) tuples
 * @tparam K the type of grouping key
 * @tparam I the type of input rows
 * @tparam O the type of output rows
@@ -61,6 +77,9 @@ class TwsTester[K, I, O](
    *
    * This corresponds to processing one microbatch. {@code handleInputRows} will be called once for
    * each key that appears in {@code input}.
+   *
+   * To simulate real-time mode, call this method repeatedly in a loop, passing a list with a
+   * single (key, input row) tuple per call.
    *
    * @param input list of (key, input row) tuples to process
    * @return all output rows produced by the processor
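+   *
+   * A minimal usage sketch (illustrative; RunningCountProcessor is the per-key counting
+   * processor from this patch series' test suite, which keeps its count in a value state
+   * named "count"):
+   * {{{
+   *   val tester = new TwsTester(new RunningCountProcessor[String]())
+   *   tester.test(List(("key1", "a")))              // List(("key1", 1L))
+   *   tester.test(List(("key1", "b")))              // List(("key1", 2L))
+   *   tester.peekValueState[Long]("count", "key1")  // Some(2L)
+   * }}}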