diff --git a/src/main/scala/io/getclump/ClumpContext.scala b/src/main/scala/io/getclump/ClumpContext.scala index 85d7bb3..f8770e6 100644 --- a/src/main/scala/io/getclump/ClumpContext.scala +++ b/src/main/scala/io/getclump/ClumpContext.scala @@ -1,31 +1,49 @@ package io.getclump +import scala.collection.immutable.SortedMap import scala.collection.mutable import scala.concurrent.ExecutionContext private[getclump] final class ClumpContext { - private[this] val fetchers = new mutable.HashMap[ClumpSource[_, _], ClumpFetcher[_, _]]() + private[this] val fetchers = mutable.HashMap.empty[ClumpSource[_, _], ClumpFetcher[_, _]] def flush(clumps: List[Clump[_]])(implicit ec: ExecutionContext): Future[Unit] = { - // 1. Get a list of all visible clumps grouped by level of composition, starting at the highest level - val upstreamByLevel = getClumpsByLevel(clumps) + // 1. Get a list of all visible clumps + val upstream = getAllUpstream(clumps) // 2. Flush the fetches from all the visible clumps - flushFetchesInParallel(upstreamByLevel.flatten).flatMap { _ => + flushFetchesInParallel(upstream).flatMap { _ => // 3. Walk through the downstream clumps as well, starting at the deepest level - flushDownstreamByLevel(upstreamByLevel.reverse) + flushDownstreamByLevel(groupClumpsByLevel(upstream)) } } - // Unfold all visible (ie. upstream) clumps from lowest to highest level - private[this] def getClumpsByLevel(clumps: List[Clump[_]]): List[List[Clump[_]]] = { + // Unfold all visible (ie. 
upstream) clumps + private[this] def getAllUpstream(clumps: List[Clump[_]]): List[Clump[_]] = { clumps match { case Nil => Nil - case _ => clumps :: getClumpsByLevel(clumps.flatMap(_.upstream)) + case _ => clumps ::: getAllUpstream(clumps.flatMap(_.upstream)) } } + // Strip the leaves at the bottom of the clump tree one level at a time so that these two conditions are satisfied: + // - Clumps appear in later lists than all their upstream children + // - Clumps appear as early in the list as possible + private[this] def groupClumpsByLevel(clumps: List[Clump[_]]): List[List[Clump[_]]] = { + // 1. Get the longest distance from this Clump to the bottom of the tree (memoized function) + val m = mutable.HashMap.empty[Clump[_], Int] + def getDistanceFromBottom(clump: Clump[_]): Int = m.getOrElseUpdate(clump, { + clump.upstream match { + case Nil => 0 + case list => list.map(getDistanceFromBottom).max + 1 + } + }) + + // 2. Group clumps by these levels and return the deepest level first + SortedMap(clumps.groupBy(getDistanceFromBottom).toSeq:_*).values.toList + } + private[this] def flushDownstreamByLevel(levels: List[List[Clump[_]]])(implicit ec: ExecutionContext): Future[Unit] = { levels match { case Nil => Future.successful(()) diff --git a/src/main/scala/io/getclump/ClumpFetcher.scala b/src/main/scala/io/getclump/ClumpFetcher.scala index b6b4f84..b5ac025 100644 --- a/src/main/scala/io/getclump/ClumpFetcher.scala +++ b/src/main/scala/io/getclump/ClumpFetcher.scala @@ -5,7 +5,7 @@ import scala.concurrent.ExecutionContext private[getclump] final class ClumpFetcher[T, U](source: ClumpSource[T, U]) { - private[this] val fetches = mutable.LinkedHashMap[T, Promise[Option[U]]]() + private[this] val fetches = mutable.LinkedHashMap.empty[T, Promise[Option[U]]] def get(input: T): Future[Option[U]] = synchronized { diff --git a/src/test/scala/io/getclump/ClumpExecutionSpec.scala b/src/test/scala/io/getclump/ClumpExecutionSpec.scala index d43d526..bce90cf 100644 --- 
a/src/test/scala/io/getclump/ClumpExecutionSpec.scala +++ b/src/test/scala/io/getclump/ClumpExecutionSpec.scala @@ -48,9 +48,22 @@ class ClumpExecutionSpec extends Spec { source2Fetches mustEqual List(Set(3, 4)) } + // Implementation note: this test will fail if ClumpContext::groupClumpsByLevel does not satisfy the requirement that + // "Clumps appear in later lists than all their upstream children" "for clumps created inside nested flatmaps" in new Context { - val clump1 = Clump.value(1).flatMap(source1.get(_)).flatMap(source2.get(_)) - val clump2 = Clump.value(2).flatMap(source1.get(_)).flatMap(source2.get(_)) + val clump1 = Clump.value(1).flatMap(source1.get).flatMap(source2.get) + val clump2 = Clump.value(2).flatMap(source1.get).flatMap(source2.get) + + clumpResult(Clump.collect(clump1, clump2)) mustEqual Some(List(100, 200)) + source1Fetches mustEqual List(Set(1, 2)) + source2Fetches mustEqual List(Set(20, 10)) + } + + // Implementation note: this test will fail if ClumpContext::groupClumpsByLevel does not satisfy the requirement that + // "Clumps appear as early in the list as possible" + "for clumps created inside nested flatmaps at different levels of composition" in new Context { + val clump1 = Clump.value(1).flatMap(source1.get).flatMap(source2.get).map(identity) + val clump2 = Clump.value(2).flatMap(source1.get).flatMap(source2.get) clumpResult(Clump.collect(clump1, clump2)) mustEqual Some(List(100, 200)) source1Fetches mustEqual List(Set(1, 2))