Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-33848][SQL][FOLLOWUP] Introduce allowList for push into (if / case) branches #30955

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -548,41 +548,68 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
foldables.nonEmpty && others.length < 2
}

// Not all UnaryExpression can be pushed into (if / case) branches, e.g. Alias.
private def supportedUnaryExpression(e: UnaryExpression): Boolean = e match {
case _: IsNull | _: IsNotNull => true
case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
true
case _: CastBase => true
case _: GetDateField | _: LastDay => true
case _: ExtractIntervalPart => true
case _: ArraySetLike => true
case _: ExtractValue => true
case _ => false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's include ExtractValue as well, which is common with nested fields.

}

// Not all BinaryExpression can be pushed into (if / case) branches.
private def supportedBinaryExpression(e: BinaryExpression): Boolean = e match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add comments as well.

case _: BinaryComparison | _: StringPredicate | _: StringRegexExpression => true
case _: BinaryArithmetic => true
case _: BinaryMathExpression => true
case _: AddMonths | _: DateAdd | _: DateAddInterval | _: DateDiff | _: DateSub => true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What the property should the expression have to be here? For example, can I add DateAddYMInterval, TimestampAddYMInterval and TimeAdd?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. I opened the JIRA for that: SPARK-34841

case _: FindInSet | _: RoundBase => true
case _ => false
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case a: Alias => a // Skip an alias.
case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
if atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = u.withNewChildren(Array(trueValue)),
falseValue = u.withNewChildren(Array(falseValue)))

case u @ UnaryExpression(c @ CaseWhen(branches, elseValue))
if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
if supportedUnaryExpression(u) && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))),
elseValue.map(e => u.withNewChildren(Array(e))))

case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
if supportedBinaryExpression(b) && right.foldable &&
atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = b.withNewChildren(Array(trueValue, right)),
falseValue = b.withNewChildren(Array(falseValue, right)))

case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue))
if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
if supportedBinaryExpression(b) && left.foldable &&
atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = b.withNewChildren(Array(left, trueValue)),
falseValue = b.withNewChildren(Array(left, falseValue)))

case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right)
if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
if supportedBinaryExpression(b) && right.foldable &&
atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))),
elseValue.map(e => b.withNewChildren(Array(e, right))))

case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
if supportedBinaryExpression(b) && left.foldable &&
atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))),
elseValue.map(e => b.withNewChildren(Array(left, e))))
Expand Down