Commit 815cfd1: init
ericm-db committed May 17, 2024
1 parent 57948c8
Showing 5 changed files with 148 additions and 5 deletions.
@@ -74,6 +74,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
}
}

/**
 * Hook that lets physical operators publish operator-specific information
 * once a batch completes. The default implementation just recurses into
 * the children; see TransformWithStateExec for an override.
 */
def writeArbitraryInfos(): Unit = {
children.foreach(_.writeArbitraryInfos())
}

val id: Int = SparkPlan.newPlanId()

/**
@@ -269,6 +269,7 @@ abstract class ProgressContext(
currentTriggerStartOffsets != null && currentTriggerEndOffsets != null &&
currentTriggerLatestOffsets != null
)
lastExecution.executedPlan.writeArbitraryInfos()
currentTriggerEndTimestamp = triggerClock.getTimeMillis()
val processingTimeMills = currentTriggerEndTimestamp - currentTriggerStartTimestamp
assert(lastExecution != null, "executed batch should provide the information for execution.")
@@ -84,7 +84,8 @@ class StatefulProcessorHandleImpl(
timeMode: TimeMode,
isStreaming: Boolean = true,
batchTimestampMs: Option[Long] = None,
-    metrics: Map[String, SQLMetric] = Map.empty)
+    metrics: Map[String, SQLMetric] = Map.empty,
+    arbInfos: Map[String, ArbitraryInfo] = Map.empty)
extends StatefulProcessorHandle with Logging {
import StatefulProcessorHandleState._

@@ -130,6 +131,7 @@ class StatefulProcessorHandleImpl(
valEncoder: Encoder[T]): ValueState[T] = {
verifyStateVarOperations("get_value_state")
incrementMetric("numValueStateVars")
arbInfos.get("arbInfo1").foreach(_.add("valueState")) // record the state-var creation in the accumulator, if registered
val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder)
resultState
}
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.streaming
import java.util.UUID
import java.util.concurrent.TimeUnit.NANOSECONDS

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -32,8 +34,64 @@ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
-import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils}
+import org.apache.spark.util.{AccumulatorV2, CompletionIterator, SerializableConfiguration, Utils}

class ColFamilyMetadata
class ArbitraryInfo(value: String) extends AccumulatorV2[String, String] {

private var _value: String = value

/**
 * Returns whether this accumulator holds its zero value, e.g. 0 for a
 * counter accumulator or Nil for a list accumulator.
 */
override def isZero: Boolean = {
_value.isEmpty
}

/**
* Creates a new copy of this accumulator.
*/
override def copy(): AccumulatorV2[String, String] = {
new ArbitraryInfo(_value)
}

/**
 * Resets this accumulator to its zero value, i.e. a subsequent call to
 * `isZero` must return true.
 */
override def reset(): Unit = {
_value = ""
}

/**
 * Takes the input and accumulates it. Here the new value simply
 * overwrites the current one.
 */
override def add(v: String): Unit = {
_value = v
}

/**
 * Merges another same-type accumulator into this one and updates its state,
 * i.e. this should be merge-in-place. Here the merge is last-write-wins:
 * the other accumulator's value replaces this one's.
 */
override def merge(other: AccumulatorV2[String, String]): Unit = {
_value = other.value
}

/**
 * Returns the current value of this accumulator.
 */
override def value: String = _value
}

object ArbitraryInfo {

def create(sc: SparkContext, key: String, value: String): ArbitraryInfo = {
val arbInfo = new ArbitraryInfo(value)
arbInfo.register(sc, Some(key))
arbInfo
}
}
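
For reference, the AccumulatorV2 round trip this prototype relies on, as a minimal standalone sketch (not part of the commit; assumes a local master): the accumulator is registered on the driver, add() runs inside tasks on the executors, and the merged value is read back on the driver. A single partition keeps the last-write-wins merge deterministic.

val sc = new SparkContext("local[2]", "arb-info-demo")
val info = ArbitraryInfo.create(sc, "demo-key", "")
// add() runs inside the task; the task-local copy is merged back on completion
sc.parallelize(Seq("valueState"), numSlices = 1).foreach(v => info.add(v))
assert(info.value == "valueState") // merged result, visible on the driver
sc.stop()
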
/**
* Physical operator for executing `TransformWithState`
*
@@ -73,10 +131,19 @@ case class TransformWithStateExec(
initialStateDataAttrs: Seq[Attribute],
initialStateDeserializer: Expression,
initialState: SparkPlan)
-  extends BinaryExecNode with StateStoreWriter with WatermarkSupport with ObjectProducerExec {
+  extends BinaryExecNode
+  with StateStoreWriter
+  with WatermarkSupport
+  with ObjectProducerExec
+  with Logging {

override def shortName: String = "transformWithStateExec"

// Named accumulators registered on the driver; executor-side copies are
// merged back into these when tasks complete.
val arbInfos: Map[String, ArbitraryInfo] = Map(
"arbInfo1" -> ArbitraryInfo.create(sparkContext, "key1", "value1"),
"arbInfo2" -> ArbitraryInfo.create(sparkContext, "key2", "value2")
)

override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
if (timeMode == ProcessingTime) {
// TODO: check if we can return true only if actual timers are registered, or there is
@@ -338,6 +405,13 @@ case class TransformWithStateExec(
)
}

// Called on the driver once a batch completes (see ProgressContext), after
// executor-side updates have been merged into the accumulators.
override def writeArbitraryInfos(): Unit = {
arbInfos.foreach { case (name, arbInfo) =>
logError(s"### arbInfo: $name ${arbInfo.value}")
}
}

override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver

@@ -457,7 +531,7 @@ case class TransformWithStateExec(
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
val processorHandle = new StatefulProcessorHandleImpl(
store, getStateInfo.queryRunId, keyEncoder, timeMode,
-      isStreaming, batchTimestampMs, metrics)
+      isStreaming, batchTimestampMs, metrics, arbInfos)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.setHandle(processorHandle)
statefulProcessor.init(outputMode, timeMode)
@@ -471,7 +545,7 @@ case class TransformWithStateExec(
initStateIterator: Iterator[InternalRow]):
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
-      keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
+      keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics, arbInfos)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.setHandle(processorHandle)
statefulProcessor.init(outputMode, timeMode)
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.apache.spark.sql.streaming

import org.apache.spark.LocalSparkContext
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
import org.apache.spark.sql.internal.SQLConf

class TransformWithStateReaderSuite extends StateStoreMetricsTest
with AlsoTestWithChangelogCheckpointingEnabled with LocalSparkContext {

import testImplicits._

test("test accumulator") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
SQLConf.SHUFFLE_PARTITIONS.key ->
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
val inputData = MemoryStream[String]
val result = inputData.toDS()
.groupByKey(x => x)
.transformWithState(new RunningCountStatefulProcessor(),
TimeMode.None(),
OutputMode.Update())

testStream(result, OutputMode.Update())(
AddData(inputData, "a"),
CheckNewAnswer(("a", "1")),
Execute { q =>
assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0)
assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0)
},
AddData(inputData, "a", "b"),
CheckNewAnswer(("a", "2"), ("b", "1")),
StopStream,
StartStream(),
AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
CheckNewAnswer(("b", "2")),
StopStream,
StartStream(),
AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and
CheckNewAnswer(("a", "1"), ("c", "1"))
)
}
}
}
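
RunningCountStatefulProcessor is reused from the existing transformWithState test helpers rather than defined in this file. Roughly, it keeps a per-key running count in a ValueState and clears that state on the third occurrence of a key, which is what the "remove state" / "recreate state" comments above refer to. A sketch of that shape (a hypothetical reconstruction, not the actual helper; signatures may differ by Spark version):

import org.apache.spark.sql.Encoders

class RunningCountStatefulProcessor
  extends StatefulProcessor[String, String, (String, String)] {
  @transient private var countState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    // Goes through StatefulProcessorHandleImpl, which increments the
    // numValueStateVars metric asserted in the test above.
    countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
    val count = countState.getOption().getOrElse(0L) + inputRows.size
    if (count == 3) {
      countState.clear() // third occurrence: drop the state, emit nothing
      Iterator.empty
    } else {
      countState.update(count)
      Iterator.single((key, count.toString))
    }
  }
}
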
