Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-32282][SQL] Improve EnsureRquirement.reorderJoinKeys to handle more scenarios such as PartitioningCollection #29074

Closed
wants to merge 13 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
leftKeys: IndexedSeq[Expression],
rightKeys: IndexedSeq[Expression],
expectedOrderOfKeys: Seq[Expression],
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
currentOrderOfKeys: Seq[Expression]): Option[(Seq[Expression], Seq[Expression])] = {
if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
return (leftKeys, rightKeys)
return None
}

// Check if the current order already satisfies the expected order.
if (expectedOrderOfKeys.zip(currentOrderOfKeys).forall(p => p._1.semanticEquals(p._2))) {
return Some(leftKeys, rightKeys)
}

// Build a lookup between an expression and the positions its holds in the current key seq.
Expand All @@ -159,10 +164,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
rightKeysBuffer += rightKeys(index)
case _ =>
// The expression cannot be found, or we have exhausted all indices for that expression.
return (leftKeys, rightKeys)
return None
}
}
(leftKeysBuffer.toSeq, rightKeysBuffer.toSeq)
Some(leftKeysBuffer.toSeq, rightKeysBuffer.toSeq)
}

private def reorderJoinKeys(
Expand All @@ -171,19 +176,50 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
(leftPartitioning, rightPartitioning) match {
case (HashPartitioning(leftExpressions, _), _) =>
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
case (_, HashPartitioning(rightExpressions, _)) =>
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
case _ =>
(leftKeys, rightKeys)
}
reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, rightPartitioning)
.getOrElse((leftKeys, rightKeys))
} else {
(leftKeys, rightKeys)
}
}

/**
* Recursively reorders the join keys based on partitioning. It starts reordering the
* join keys to match HashPartitioning on either side, followed by PartitioningCollection.
*/
private def reorderJoinKeysRecursively(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): Option[(Seq[Expression], Seq[Expression])] = {
(leftPartitioning, rightPartitioning) match {
case (HashPartitioning(leftExpressions, _), _) =>
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning))
imback82 marked this conversation as resolved.
Show resolved Hide resolved
case (_, HashPartitioning(rightExpressions, _)) =>
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
.orElse(reorderJoinKeysRecursively(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be also implemented by looking at left partitioning first then move to the right partitionoing:

    (leftPartitioning, rightPartitioning) match {
      case (HashPartitioning(leftExpressions, _), _) =>
        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
          .orElse(reorderJoinKeysRecursively(
            leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning))
      case (PartitioningCollection(partitionings), _) =>
        partitionings.foreach { p =>
          reorderJoinKeysRecursively(leftKeys, rightKeys, p, rightPartitioning).map { k =>
            return Some(k)
          }
        }
        reorderJoinKeysRecursively(leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning)
      case (_, HashPartitioning(rightExpressions, _)) =>
        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
      case (_, PartitioningCollection(partitionings)) =>
        partitionings.foreach { p =>
          reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, p).map { k =>
            return Some(k)
          }
        }
        None
      case _ =>
        None
    }

However, I chose this way so that the behavior remains the same. If you have leftPartitioning = PartitioningCollection and rightPartitioning = HashPartitioning, it will match the rightPartitioning first, which is the existing behavior.

leftKeys, rightKeys, leftPartitioning, UnknownPartitioning(0)))
case (PartitioningCollection(partitionings), _) =>
partitionings.foreach { p =>
reorderJoinKeysRecursively(leftKeys, rightKeys, p, rightPartitioning).map { k =>
return Some(k)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

partitionings.foldLeft(None) { (res, p) =>
  res.orElse(reorderJoinKeysRecursively...)
}.getOrElse(reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated.

}
}
reorderJoinKeysRecursively(leftKeys, rightKeys, UnknownPartitioning(0), rightPartitioning)
case (_, PartitioningCollection(partitionings)) =>
partitionings.foreach { p =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you do the same refactor here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, p).map { k =>
return Some(k)
}
}
None
case _ =>
None
}
}

/**
* When the physical operators are created for JOIN, the ordering of join keys is based on order
* in which the join keys appear in the user query. That might not match with the output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,88 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}
}
}

test("EnsureRequirements.reorder should fallback to the right side HashPartitioning") {
imback82 marked this conversation as resolved.
Show resolved Hide resolved
val plan1 = DummySparkPlan(
outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5))
val plan2 = DummySparkPlan(
outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5))
// The left keys cannot be reordered to match the left partitioning, and it should
// fall back to reorder the right side.
imback82 marked this conversation as resolved.
Show resolved Hide resolved
val smjExec = SortMergeJoinExec(
exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2)
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
outputPlan match {
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
SortExec(_, _,
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
SortExec(_, _,
DummySparkPlan(_, _, HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) =>
assert(leftKeys !== smjExec.leftKeys)
assert(rightKeys !== smjExec.rightKeys)
assert(leftKeys === leftPartitioningExpressions)
assert(rightKeys === rightPartitioningExpressions)
case _ => fail(outputPlan.toString)
imback82 marked this conversation as resolved.
Show resolved Hide resolved
}
}

test("EnsureRequirements.reorder should handle PartitioningCollection") {
// PartitioningCollection on the left side of join.
val plan1 = DummySparkPlan(
outputPartitioning = PartitioningCollection(Seq(
HashPartitioning(exprA :: exprB :: Nil, 5),
HashPartitioning(exprA :: Nil, 5))))
val plan2 = DummySparkPlan()
val smjExec1 = SortMergeJoinExec(
exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2)
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec1)
outputPlan match {
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
SortExec(_, _,
DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _, _), _),
SortExec(_, _,
ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) =>
assert(leftKeys !== smjExec1.leftKeys)
assert(rightKeys !== smjExec1.rightKeys)
assert(leftKeys === leftPartitionings(0).asInstanceOf[HashPartitioning].expressions)
assert(rightKeys === rightPartitioningExpressions)
case _ => fail(outputPlan.toString)
}

// PartitioningCollection on the right side of join.
val smjExec2 = SortMergeJoinExec(
exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1)
val outputPlan2 = EnsureRequirements(spark.sessionState.conf).apply(smjExec2)
outputPlan2 match {
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
SortExec(_, _,
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
SortExec(_, _,
DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) =>
assert(leftKeys !== smjExec2.leftKeys)
assert(rightKeys !== smjExec2.rightKeys)
assert(leftKeys === leftPartitioningExpressions)
assert(rightKeys === rightPartitionings(0).asInstanceOf[HashPartitioning].expressions)
case _ => fail(outputPlan2.toString)
}

// Both sides are PartitioningCollection and falls back to the right side.
val smjExec3 = SortMergeJoinExec(
exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1)
val outputPlan3 = EnsureRequirements(spark.sessionState.conf).apply(smjExec2)
imback82 marked this conversation as resolved.
Show resolved Hide resolved
outputPlan3 match {
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
SortExec(_, _,
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
SortExec(_, _,
DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) =>
assert(leftKeys !== smjExec2.leftKeys)
assert(rightKeys !== smjExec2.rightKeys)
assert(leftKeys === leftPartitioningExpressions)
assert(rightKeys === rightPartitionings(0).asInstanceOf[HashPartitioning].expressions)
case _ => fail(outputPlan3.toString)
}
}
}

// Used for unit-testing EnsureRequirements
Expand Down