
[SPARK-20057][SS] Renamed KeyedState to GroupState in mapGroupsWithState
## What changes were proposed in this pull request?

Since the state is tied to a "group" in the "mapGroupsWithState" operations, it is better to call the state "GroupState" rather than naming it after a key. This also keeps the name general if the operation is later extended to RelationalGroupedDataset and the Python APIs.
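
For illustration only — a minimal sketch of the user-facing effect of the rename, assuming a hypothetical `events: Dataset[(String, Long)]` and `import spark.implicits._` in scope (neither is part of this patch); only the state type's name changes:

```scala
import org.apache.spark.sql.streaming.GroupState // was: KeyedState

// Hypothetical running count per key.
val counts = events
  .groupByKey(_._1)
  .mapGroupsWithState[Long, (String, Long)] {
    (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) =>
      val newCount = state.getOption.getOrElse(0L) + values.size
      state.update(newCount)
      (key, newCount)
  }
```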

## How was this patch tested?
Existing unit tests.

Author: Tathagata Das <[email protected]>

Closes #17385 from tdas/SPARK-20057.
tdas committed Mar 22, 2017
1 parent 80fd070 commit 82b598b
Showing 13 changed files with 172 additions and 167 deletions.
@@ -24,31 +24,31 @@
/**
* Represents the type of timeouts possible for the Dataset operations
* `mapGroupsWithState` and `flatMapGroupsWithState`. See documentation on
* `KeyedState` for more details.
* `GroupState` for more details.
*
* @since 2.2.0
*/
@Experimental
@InterfaceStability.Evolving
public class KeyedStateTimeout {
public class GroupStateTimeout {

/**
* Timeout based on processing time. The duration of timeout can be set for each group in
* `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutDuration()`. See documentation
* on `KeyedState` for more details.
* `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutDuration()`. See documentation
* on `GroupState` for more details.
*/
public static KeyedStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; }
public static GroupStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; }

/**
* Timeout based on event-time. The event-time timestamp for timeout can be set for each
* group in `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutTimestamp()`.
* group in `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutTimestamp()`.
* In addition, you have to define the watermark in the query using `Dataset.withWatermark`.
* When the watermark advances beyond the set timestamp of a group and the group has not
* received any data, then the group times out. See documentation on
* `KeyedState` for more details.
* `GroupState` for more details.
*/
public static KeyedStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; }
public static GroupStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; }

/** No timeout. */
public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; }
public static GroupStateTimeout NoTimeout() { return NoTimeout$.MODULE$; }
}
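
For context, a hedged sketch of how a caller would pick one of these timeouts after the rename. The session classes (`Event`, `SessionInfo`, `SessionUpdate`) and the dataset `sessions` are illustrative assumptions, not code from this patch:

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

// Hypothetical types for the sketch.
case class Event(sessionId: String, timestamp: java.sql.Timestamp)
case class SessionInfo(numEvents: Int)
case class SessionUpdate(id: String, expired: Boolean)

// Assumes `sessions: Dataset[Event]` and `import spark.implicits._` in scope.
val updates = sessions
  .groupByKey(_.sessionId)
  .mapGroupsWithState[SessionInfo, SessionUpdate](
      GroupStateTimeout.ProcessingTimeTimeout) {
    (sessionId: String, events: Iterator[Event], state: GroupState[SessionInfo]) =>
      if (state.hasTimedOut) {
        // No data arrived for this group within the configured duration.
        state.remove()
        SessionUpdate(sessionId, expired = true)
      } else {
        val info = SessionInfo(state.getOption.map(_.numEvents).getOrElse(0) + events.size)
        state.update(info)
        state.setTimeoutDuration("10 minutes") // per-group duration, as documented above
        SessionUpdate(sessionId, expired = false)
      }
  }
```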
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode }
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode }
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -351,22 +351,22 @@ case class MapGroups(
child: LogicalPlan) extends UnaryNode with ObjectProducer

/** Internal class representing State */
trait LogicalKeyedState[S]
trait LogicalGroupState[S]

/** Types of timeouts used in FlatMapGroupsWithState */
case object NoTimeout extends KeyedStateTimeout
case object ProcessingTimeTimeout extends KeyedStateTimeout
case object EventTimeTimeout extends KeyedStateTimeout
case object NoTimeout extends GroupStateTimeout
case object ProcessingTimeTimeout extends GroupStateTimeout
case object EventTimeTimeout extends GroupStateTimeout

/** Factory for constructing new `MapGroupsWithState` nodes. */
object FlatMapGroupsWithState {
def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputMode: OutputMode,
isMapGroupsWithState: Boolean,
timeout: KeyedStateTimeout,
timeout: GroupStateTimeout,
child: LogicalPlan): LogicalPlan = {
val encoder = encoderFor[S]

@@ -404,7 +404,7 @@ object FlatMapGroupsWithState {
* @param timeout used to timeout groups that have not received data in a while
*/
case class FlatMapGroupsWithState(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
@@ -413,7 +413,7 @@ case class FlatMapGroupsWithState(
stateEncoder: ExpressionEncoder[Any],
outputMode: OutputMode,
isMapGroupsWithState: Boolean = false,
timeout: KeyedStateTimeout,
timeout: GroupStateTimeout,
child: LogicalPlan) extends UnaryNode with ObjectProducer {

if (isMapGroupsWithState) {
@@ -17,13 +17,17 @@

package org.apache.spark.sql.streaming;

import org.apache.spark.sql.catalyst.plans.logical.EventTimeTimeout$;
import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$;
import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$;
import org.junit.Test;

public class JavaKeyedStateTimeoutSuite {
public class JavaGroupStateTimeoutSuite {

@Test
public void testTimeouts() {
assert(KeyedStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$);
assert (GroupStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$);
assert (GroupStateTimeout.EventTimeTimeout() == EventTimeTimeout$.MODULE$);
assert (GroupStateTimeout.NoTimeout() == NoTimeout$.MODULE$);
}
}
@@ -22,7 +22,7 @@

import org.apache.spark.annotation.Experimental;
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.streaming.KeyedState;
import org.apache.spark.sql.streaming.GroupState;

/**
* ::Experimental::
@@ -35,5 +35,5 @@
@Experimental
@InterfaceStability.Evolving
public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
Iterator<R> call(K key, Iterator<V> values, GroupState<S> state) throws Exception;
}
@@ -22,7 +22,7 @@

import org.apache.spark.annotation.Experimental;
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.streaming.KeyedState;
import org.apache.spark.sql.streaming.GroupState;

/**
* ::Experimental::
@@ -34,5 +34,5 @@
@Experimental
@InterfaceStability.Evolving
public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
R call(K key, Iterator<V> values, GroupState<S> state) throws Exception;
}
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout, OutputMode}
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

/**
* :: Experimental ::
@@ -228,7 +228,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[org.apache.spark.sql.streaming.KeyedState]] for more details.
* See [[org.apache.spark.sql.streaming.GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -240,17 +240,17 @@ class KeyValueGroupedDataset[K, V] private[sql](
@Experimental
@InterfaceStability.Evolving
def mapGroupsWithState[S: Encoder, U: Encoder](
func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
Dataset[U](
sparkSession,
FlatMapGroupsWithState[K, V, S, U](
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
OutputMode.Update,
isMapGroupsWithState = true,
KeyedStateTimeout.NoTimeout,
GroupStateTimeout.NoTimeout,
child = logicalPlan))
}

@@ -262,7 +262,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[org.apache.spark.sql.streaming.KeyedState]] for more details.
* See [[org.apache.spark.sql.streaming.GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -275,13 +275,13 @@ class KeyValueGroupedDataset[K, V] private[sql](
@Experimental
@InterfaceStability.Evolving
def mapGroupsWithState[S: Encoder, U: Encoder](
timeoutConf: KeyedStateTimeout)(
func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))
timeoutConf: GroupStateTimeout)(
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
Dataset[U](
sparkSession,
FlatMapGroupsWithState[K, V, S, U](
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
OutputMode.Update,
@@ -298,7 +298,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
* See [[GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -316,7 +316,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
stateEncoder: Encoder[S],
outputEncoder: Encoder[U]): Dataset[U] = {
mapGroupsWithState[S, U](
(key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s)
(key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
)(stateEncoder, outputEncoder)
}

@@ -328,7 +328,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
* See [[GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -346,9 +346,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
func: MapGroupsWithStateFunction[K, V, S, U],
stateEncoder: Encoder[S],
outputEncoder: Encoder[U],
timeoutConf: KeyedStateTimeout): Dataset[U] = {
timeoutConf: GroupStateTimeout): Dataset[U] = {
mapGroupsWithState[S, U](
(key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s)
(key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
)(stateEncoder, outputEncoder)
}

@@ -360,7 +360,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
* See [[GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -375,15 +375,15 @@ class KeyValueGroupedDataset[K, V] private[sql](
@InterfaceStability.Evolving
def flatMapGroupsWithState[S: Encoder, U: Encoder](
outputMode: OutputMode,
timeoutConf: KeyedStateTimeout)(
func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
timeoutConf: GroupStateTimeout)(
func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
throw new IllegalArgumentException("The output mode of function should be append or update")
}
Dataset[U](
sparkSession,
FlatMapGroupsWithState[K, V, S, U](
func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
func.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
outputMode,
@@ -400,7 +400,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* For a static batch Dataset, the function will be invoked once per group. For a streaming
* Dataset, the function will be invoked for each group repeatedly in every trigger, and
* updates to each group's state will be saved across invocations.
* See [[KeyedState]] for more details.
* See [[GroupState]] for more details.
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
@@ -420,8 +420,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
outputMode: OutputMode,
stateEncoder: Encoder[S],
outputEncoder: Encoder[U],
timeoutConf: KeyedStateTimeout): Dataset[U] = {
val f = (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
timeoutConf: GroupStateTimeout): Dataset[U] = {
val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala
flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
}

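
And a sketch of the event-time variant documented above — `Click`, `Visit`, and `clicks` are assumptions for illustration; as the updated docs state, `EventTimeTimeout` requires a watermark on the query:

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

// Hypothetical types for the sketch.
case class Click(userId: String, eventTime: java.sql.Timestamp)
case class Visit(userId: String, clicks: Int)

// Assumes `clicks: Dataset[Click]` and `import spark.implicits._` in scope.
val visits = clicks
  .withWatermark("eventTime", "10 minutes") // required before using EventTimeTimeout
  .groupByKey(_.userId)
  .flatMapGroupsWithState[Int, Visit](
      OutputMode.Update, GroupStateTimeout.EventTimeTimeout) {
    (userId: String, events: Iterator[Click], state: GroupState[Int]) =>
      if (state.hasTimedOut) {
        // The watermark passed the set timestamp with no new data for this group.
        val result = Iterator(Visit(userId, state.get))
        state.remove()
        result
      } else {
        val batch = events.toSeq // the function is invoked with data when not timed out
        state.update(state.getOption.getOrElse(0) + batch.size)
        // Expire the group 30 event-time minutes after its latest event.
        state.setTimeoutTimestamp(batch.map(_.eventTime.getTime).max, "30 minutes")
        Iterator.empty
      }
  }
```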
@@ -31,8 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
import org.apache.spark.sql.execution.streaming.KeyedStateImpl
import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
import org.apache.spark.sql.execution.streaming.GroupStateImpl
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -355,14 +355,14 @@

object MapGroupsExec {
def apply(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any],
func: (Any, Iterator[Any], LogicalGroupState[Any]) => TraversableOnce[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
child: SparkPlan): MapGroupsExec = {
val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None))
val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None))
new MapGroupsExec(f, keyDeserializer, valueDeserializer,
groupingAttributes, dataAttributes, outputObjAttr, child)
}
@@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.streaming.KeyedStateImpl.NO_TIMESTAMP
import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.CompletionIterator

@@ -44,7 +44,7 @@ import org.apache.spark.util.CompletionIterator
* @param batchTimestampMs processing timestamp of the current batch.
*/
case class FlatMapGroupsWithStateExec(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
@@ -53,13 +53,13 @@ case class FlatMapGroupsWithStateExec(
stateId: Option[OperatorStateId],
stateEncoder: ExpressionEncoder[Any],
outputMode: OutputMode,
timeoutConf: KeyedStateTimeout,
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
override val eventTimeWatermark: Option[Long],
child: SparkPlan
) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {

import KeyedStateImpl._
import GroupStateImpl._

private val isTimeoutEnabled = timeoutConf != NoTimeout
private val timestampTimeoutAttribute =
@@ -147,7 +147,7 @@ case class FlatMapGroupsWithStateExec(
private val stateSerializer = {
val encoderSerializer = stateEncoder.namedExpressions
if (isTimeoutEnabled) {
encoderSerializer :+ Literal(KeyedStateImpl.NO_TIMESTAMP)
encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
} else {
encoderSerializer
}
@@ -211,7 +211,7 @@ case class FlatMapGroupsWithStateExec(
val keyObj = getKeyObj(keyRow) // convert key to objects
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
val stateObjOption = getStateObj(prevStateRowOption)
val keyedState = new KeyedStateImpl(
val keyedState = new GroupStateImpl(
stateObjOption,
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
@@ -247,7 +247,7 @@

if (shouldWriteState) {
if (stateRowToWrite == null) {
// This should never happen because checks in KeyedStateImpl should avoid cases
// This should never happen because checks in GroupStateImpl should avoid cases
// where empty state would need to be written
throw new IllegalStateException("Attempting to write empty state")
}