Skip to content

Commit

Permalink
polish
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Mar 16, 2024
1 parent 94d6d2d commit e04b48c
Showing 1 changed file with 113 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,30 @@

package org.apache.pekko.stream.testkit

import java.io.PrintWriter
import java.io.StringWriter
import java.util.concurrent.CountDownLatch

import scala.annotation.tailrec
import scala.collection.immutable
import scala.concurrent.duration._
import scala.reflect.ClassTag

import org.apache.pekko
import pekko.actor.{ ActorRef, ActorSystem, DeadLetterSuppression, NoSerializationVerificationNeeded }
import pekko.actor.ClassicActorSystemProvider
import org.reactivestreams.{ Publisher, Subscriber, Subscription }
import pekko.actor.{
ActorRef,
ActorSystem,
ClassicActorSystemProvider,
DeadLetterSuppression,
NoSerializationVerificationNeeded
}
import pekko.japi._
import pekko.stream._
import pekko.stream.impl._
import pekko.testkit.{ TestActor, TestProbe }
import pekko.testkit.TestActor.AutoPilot
import pekko.testkit.{ TestActor, TestProbe }
import pekko.util.JavaDurationConverters._
import pekko.util.ccompat.JavaConverters._
import pekko.util.ccompat._

import org.reactivestreams.{ Publisher, Subscriber, Subscription }
import java.io.{ PrintWriter, StringWriter }
import java.util.concurrent.CountDownLatch
import scala.annotation.tailrec
import scala.collection.immutable
import scala.concurrent.duration._
import scala.reflect.ClassTag

/**
* Provides factory methods for various Publishers.
Expand Down Expand Up @@ -93,6 +97,13 @@ object TestPublisher {
*/
def apply[T](autoOnSubscribe: Boolean = true)(implicit system: ClassicActorSystemProvider): ManualProbe[T] =
new ManualProbe(autoOnSubscribe)(system.classicSystem)

/**
* JAVA API
* Probe that implements [[org.reactivestreams.Publisher]] interface.
*/
def create[T](autoOnSubscribe: Boolean, system: ClassicActorSystemProvider): ManualProbe[T] =
new ManualProbe(autoOnSubscribe)(system.classicSystem)
}

/**
Expand Down Expand Up @@ -136,6 +147,13 @@ object TestPublisher {
f
}

/**
* JAVA API
*/
def executeAfterSubscription[T](f: function.Creator[T]): T = {
executeAfterSubscription(f.create())
}

/**
* Expect a subscription.
*/
Expand Down Expand Up @@ -190,24 +208,38 @@ object TestPublisher {
}

/**
* JAVA API
* Expect no messages for a given duration.
*/
def expectNoMessage(max: java.time.Duration): Self = executeAfterSubscription {
probe.expectNoMessage(max.asScala)
self
def expectNoMessage(max: java.time.Duration): Self = {
expectNoMessage(max.asScala)
}

/**
* Receive messages for a given duration or until one does not match a given partial function.
*/
def receiveWhile[T](
max: Duration = Duration.Undefined,
def receiveWhile[T](max: Duration = Duration.Undefined,
idle: Duration = Duration.Inf,
messages: Int = Int.MaxValue)(f: PartialFunction[PublisherEvent, T]): immutable.Seq[T] =
executeAfterSubscription { probe.receiveWhile(max, idle, messages)(f.asInstanceOf[PartialFunction[AnyRef, T]]) }
executeAfterSubscription {
probe.receiveWhile(max, idle, messages)(f.asInstanceOf[PartialFunction[AnyRef, T]])
}

/**
* JAVA API
* Receive messages for a given duration or until one does not match a given partial function.
*/
def receiveWhile[T](max: java.time.Duration,
idle: java.time.Duration,
messages: Int,
f: PartialFunction[PublisherEvent, T]): java.util.List[T] = {
receiveWhile(max.asScala, idle.asScala, messages)(f).asJava
}

def expectEventPF[T](f: PartialFunction[PublisherEvent, T]): T =
executeAfterSubscription { probe.expectMsgPF[T]()(f.asInstanceOf[PartialFunction[Any, T]]) }
executeAfterSubscription {
probe.expectMsgPF[T]()(f.asInstanceOf[PartialFunction[Any, T]])
}

def getPublisher: Publisher[I] = this

Expand All @@ -232,6 +264,7 @@ object TestPublisher {
}

/**
* JAVA API
* Execute code block while bounding its execution time between `min` and
* `max`. `within` blocks may be nested. All methods in this trait which
* take maximum wait times are available in a version which implicitly uses
Expand All @@ -247,24 +280,35 @@ object TestPublisher {
* }
* }}}
*/
def within[T](min: java.time.Duration, max: java.time.Duration)(f: => T): T = executeAfterSubscription {
probe.within(min.asScala, max.asScala)(f)
}
def within[T](min: java.time.Duration,
max: java.time.Duration,
creator: function.Creator[T]): T =
within(min.asScala, max.asScala)(creator.create())

/**
* Same as calling `within(0 seconds, max)(f)`.
*/
def within[T](max: FiniteDuration)(f: => T): T = executeAfterSubscription { probe.within(max)(f) }
def within[T](max: FiniteDuration)(f: => T): T = executeAfterSubscription {
probe.within(max)(f)
}

/**
* JAVA API
* Same as calling `within(Duration.ofSeconds(0), max)(f)`.
*/
def within[T](max: java.time.Duration)(f: => T): T = executeAfterSubscription { probe.within(max.asScala)(f) }
def within[T](max: java.time.Duration,
creator: function.Creator[T]): T = within(max.asScala)(creator.create())
}

object Probe {
def apply[T](initialPendingRequests: Long = 0)(implicit system: ClassicActorSystemProvider): Probe[T] =
new Probe(initialPendingRequests)(system.classicSystem)

/**
* JAVA API
*/
def create[T](initialPendingRequests: Long, system: ClassicActorSystemProvider): Probe[T] =
apply(initialPendingRequests)(system.classicSystem)
}

/**
Expand Down Expand Up @@ -324,6 +368,7 @@ object TestPublisher {
assert(cause == expectedCause, s"Expected cancellation cause to be $expectedCause but was $cause")
this
}

def expectCancellationWithCause[E <: Throwable: ClassTag](): E = subscription.expectCancellation() match {
case e: E => e
case cause =>
Expand Down Expand Up @@ -367,6 +412,12 @@ object TestSubscriber {
object ManualProbe {
def apply[T]()(implicit system: ClassicActorSystemProvider): ManualProbe[T] =
new ManualProbe()(system.classicSystem)

/**
* JAVA API
*/
def create[T]()(system: ClassicActorSystemProvider): ManualProbe[T] =
apply()(system.classicSystem)
}

/**
Expand Down Expand Up @@ -406,10 +457,10 @@ object TestSubscriber {
probe.expectMsgType[SubscriberEvent](max)

/**
* JAVA API
* Expect and return [[SubscriberEvent]] (any of: `OnSubscribe`, `OnNext`, `OnError` or `OnComplete`).
*/
def expectEvent(max: java.time.Duration): SubscriberEvent =
probe.expectMsgType[SubscriberEvent](max.asScala)
def expectEvent(max: java.time.Duration): SubscriberEvent = expectEvent(max.asScala)

/**
* Fluent DSL
Expand Down Expand Up @@ -441,10 +492,11 @@ object TestSubscriber {
}

/**
* JAVA API
*
* Expect and return a stream element during specified time or timeout.
*/
def expectNext(d: java.time.Duration): I =
expectNext(d.asScala)
def expectNext(d: java.time.Duration): I = expectNext(d.asScala)

/**
* Fluent DSL
Expand All @@ -467,14 +519,13 @@ object TestSubscriber {
}

/**
* JAVA PAI
*
* Fluent DSL
*
* Expect a stream element during specified time or timeout.
*/
def expectNext(d: java.time.Duration, element: I): Self = {
probe.expectMsg(d.asScala, OnNext(element))
self
}
def expectNext(d: java.time.Duration, element: I): Self = expectNext(d.asScala, element)

/**
* Fluent DSL
Expand Down Expand Up @@ -546,6 +597,14 @@ object TestSubscriber {
self
}

/**
* JAVA API
*
* Fluent DSL
* Expect the given elements to be signalled in any order.
*/
def expectNextUnorderedN(all: java.util.List[I]): Self = expectNextUnorderedN(Util.immutableSeq(all))

/**
* Fluent DSL
*
Expand Down Expand Up @@ -825,10 +884,13 @@ object TestSubscriber {
.flatten

/**
* JAVA API
*
* Drains a given number of messages
*/
def receiveWithin(max: java.time.Duration, messages: Int = Int.MaxValue): immutable.Seq[I] =
receiveWithin(max.asScala, messages)
def receiveWithin(max: java.time.Duration, messages: Int): java.util.List[I] = {
receiveWithin(max.asScala, messages).asJava
}

/**
* Attempt to drain the stream into a strict collection (by requesting `Long.MaxValue` elements).
Expand Down Expand Up @@ -887,6 +949,8 @@ object TestSubscriber {
def within[T](min: FiniteDuration, max: FiniteDuration)(f: => T): T = probe.within(min, max)(f)

/**
* JAVA API
*
* Execute code block while bounding its execution time between `min` and
* `max`. `within` blocks may be nested. All methods in this trait which
* take maximum wait times are available in a version which implicitly uses
Expand All @@ -902,18 +966,21 @@ object TestSubscriber {
* }
* }}}
*/
def within[T](min: java.time.Duration, max: java.time.Duration)(f: pekko.japi.function.Function[Unit, T]): T =
probe.within(min.asScala, max.asScala)(f.apply())
def within[T](min: java.time.Duration,
max: java.time.Duration,
creator: function.Creator[T]): T = within(min.asScala, max.asScala)(creator.create())

/**
* Same as calling `within(0 seconds, max)(f)`.
*/
def within[T](max: FiniteDuration)(f: => T): T = probe.within(max)(f)

/**
* JAVA API
*
* Same as calling `within(Duration.ofSeconds(0), max)(f)`.
*/
def within[T](max: java.time.Duration)(f: => T): T = probe.within(max.asScala)(f)
def within[T](max: java.time.Duration)(creator: function.Creator[T]): T = within(max.asScala)(creator.create())

def onSubscribe(subscription: Subscription): Unit = probe.ref ! OnSubscribe(subscription)
def onNext(element: I): Unit = probe.ref ! OnNext(element)
Expand All @@ -923,6 +990,11 @@ object TestSubscriber {

object Probe {
def apply[T]()(implicit system: ClassicActorSystemProvider): Probe[T] = new Probe()(system.classicSystem)

/**
* JAVA API
*/
def create[T]()(implicit system: ClassicActorSystemProvider): Probe[T] = apply()(system)
}

/**
Expand Down Expand Up @@ -987,10 +1059,7 @@ object TestSubscriber {
/**
* Request and expect a stream element during the specified time or timeout.
*/
def requestNext(d: java.time.Duration): T = {
subscription.request(1)
expectNext(d)
}
def requestNext(d: java.time.Duration): T = requestNext(d.asScala)
}
}

Expand Down Expand Up @@ -1035,7 +1104,7 @@ private[stream] object StreamTestKit {

final class ProbeSource[T](val attributes: Attributes, shape: SourceShape[T])(implicit system: ActorSystem)
extends SourceModule[T, TestPublisher.Probe[T]](shape) {
override def create(context: MaterializationContext) = {
override def create(context: MaterializationContext): (TestPublisher.Probe[T], TestPublisher.Probe[T]) = {
val probe = TestPublisher.probe[T]()
(probe, probe)
}
Expand All @@ -1045,7 +1114,7 @@ private[stream] object StreamTestKit {

final class ProbeSink[T](val attributes: Attributes, shape: SinkShape[T])(implicit system: ActorSystem)
extends SinkModule[T, TestSubscriber.Probe[T]](shape) {
override def create(context: MaterializationContext) = {
override def create(context: MaterializationContext): (TestSubscriber.Probe[T], TestSubscriber.Probe[T]) = {
val probe = TestSubscriber.probe[T]()
(probe, probe)
}
Expand Down

0 comments on commit e04b48c

Please sign in to comment.