diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index e6e475bb82f82..4d50821620f5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -78,7 +78,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops */ - def apply(field: Any): Column = UnresolvedExtractValue(expr, Literal(field)) + def apply(extraction: Any): Column = UnresolvedExtractValue(expr, lit(extraction).expr) /** * Unary minus, i.e. negate the expression. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2ade955864b71..d58438e5d129c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -456,6 +456,7 @@ class DataFrameSuite extends QueryTest { assert(complexData.filter(complexData("a")(0) === 2).count() == 1) assert(complexData.filter(complexData("m")("1") === 1).count() == 1) assert(complexData.filter(complexData("s")("key") === 1).count() == 1) + assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1) } test("SPARK-7324 dropDuplicates") {