Skip to content

Commit

Permalink
[SPARK-16839][SQL] Simplify Struct creation code path
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Simplify struct creation, especially the aspect of `CleanupAliases` which missed some aliases when handling trees created by `CreateStruct`.

This PR includes:

1. A failing test (create struct with nested aliases, some of the aliases survive `CleanupAliases`).
2. A fix that transforms `CreateStruct` into a `CreateNamedStruct` constructor, effectively eliminating `CreateStruct` from all expression trees.
3. A `NamePlaceHolder` used by `CreateStruct` when column names cannot be extracted from unresolved `NamedExpression`.
4. A new Analyzer rule that resolves `NamePlaceHolder` into a string literal once the `NamedExpression` is resolved.
5. `CleanupAliases` code was simplified as it no longer has to deal with `CreateStruct`'s top level columns.

## How was this patch tested?
Running all tests-suits in package org.apache.spark.sql, especially including the analysis suite, making sure added test initially fails, after applying suggested fix rerun the entire analysis package successfully.

Modified few tests that expected `CreateStruct` which is now transformed into `CreateNamedStruct`.

Author: eyal farago <eyal farago>
Author: Herman van Hovell <[email protected]>
Author: eyal farago <[email protected]>
Author: Eyal Farago <[email protected]>
Author: Hyukjin Kwon <[email protected]>
Author: eyalfa <[email protected]>

Closes #15718 from hvanhovell/SPARK-16839-2.
  • Loading branch information
eyal farago authored and hvanhovell committed Nov 2, 2016
1 parent 9c8deef commit f151bd1
Show file tree
Hide file tree
Showing 14 changed files with 169 additions and 198 deletions.
12 changes: 6 additions & 6 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1222,16 +1222,16 @@ test_that("column functions", {
# Test struct()
df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)),
schema = c("a", "b", "c"))
result <- collect(select(df, struct("a", "c")))
result <- collect(select(df, alias(struct("a", "c"), "d")))
expected <- data.frame(row.names = 1:2)
expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)),
listToStruct(list(a = 4L, c = 6L)))
expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)),
listToStruct(list(a = 4L, c = 6L)))
expect_equal(result, expected)

result <- collect(select(df, struct(df$a, df$b)))
result <- collect(select(df, alias(struct(df$a, df$b), "d")))
expected <- data.frame(row.names = 1:2)
expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)),
listToStruct(list(a = 4L, b = 5L)))
expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)),
listToStruct(list(a = 4L, b = 5L)))
expect_equal(result, expected)

# Test encode(), decode()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.{TreeNodeRef}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -83,6 +83,7 @@ class Analyzer(
ResolveTableValuedFunctions ::
ResolveRelations ::
ResolveReferences ::
ResolveCreateNamedStruct ::
ResolveDeserializer ::
ResolveNewInstance ::
ResolveUpCast ::
Expand Down Expand Up @@ -653,11 +654,12 @@ class Analyzer(
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateStruct if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateNamedStruct if containsStar(c.valExprs) =>
val newChildren = c.children.grouped(2).flatMap {
case Seq(k, s : Star) => CreateStruct(s.expand(child, resolver)).children
case kv => kv
}
c.copy(children = newChildren.toList )
case c: CreateArray if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
Expand Down Expand Up @@ -1141,7 +1143,7 @@ class Analyzer(
case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved =>
// Get the left hand side expressions.
val expressions = e match {
case CreateStruct(exprs) => exprs
case cns : CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) =>
Expand Down Expand Up @@ -2072,18 +2074,8 @@ object EliminateUnions extends Rule[LogicalPlan] {
*/
object CleanupAliases extends Rule[LogicalPlan] {
private def trimAliases(e: Expression): Expression = {
var stop = false
e.transformDown {
// CreateStruct is a special case, we need to retain its top level Aliases as they decide the
// name of StructField. We also need to stop transform down this expression, or the Aliases
// under CreateStruct will be mistakenly trimmed.
case c: CreateStruct if !stop =>
stop = true
c.copy(children = c.children.map(trimNonTopLevelAliases))
case c: CreateStructUnsafe if !stop =>
stop = true
c.copy(children = c.children.map(trimNonTopLevelAliases))
case Alias(child, _) if !stop => child
case Alias(child, _) => child
}
}

Expand Down Expand Up @@ -2116,15 +2108,8 @@ object CleanupAliases extends Rule[LogicalPlan] {
case a: AppendColumns => a

case other =>
var stop = false
other transformExpressionsDown {
case c: CreateStruct if !stop =>
stop = true
c.copy(children = c.children.map(trimNonTopLevelAliases))
case c: CreateStructUnsafe if !stop =>
stop = true
c.copy(children = c.children.map(trimNonTopLevelAliases))
case Alias(child, _) if !stop => child
case Alias(child, _) => child
}
}
}
Expand Down Expand Up @@ -2217,3 +2202,19 @@ object TimeWindowing extends Rule[LogicalPlan] {
}
}
}

/**
* Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s.
*/
object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
case e: CreateNamedStruct if !e.resolved =>
val children = e.children.grouped(2).flatMap {
case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
Seq(Literal(e.name), e)
case kv =>
kv
}
CreateNamedStruct(children.toList)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
expression[CreateStruct]("struct"),
CreateStruct.registryEntry,

// misc functions
expression[AssertTrue]("assert_true"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ object UnsafeProjection {
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
val unsafeExprs = exprs.map(_ transform {
case CreateStruct(children) => CreateStructUnsafe(children)
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(unsafeExprs)
Expand All @@ -145,7 +144,6 @@ object UnsafeProjection {
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
val e = exprs.map(BindReferences.bindReference(_, inputSchema))
.map(_ transform {
case CreateStruct(children) => CreateStructUnsafe(children)
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
Expand Down
Loading

0 comments on commit f151bd1

Please sign in to comment.