diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fcc0c56db9c94..bc4e3c9d273ab 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1275,7 +1275,7 @@ def __init__(self, jc): # container operators __contains__ = _bin_op("contains") - __getitem__ = _bin_op("getItem") + __getitem__ = _bin_op("apply") # bitwise operators bitwiseOR = _bin_op("bitwiseOR") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 616f4b26ae098..5bcaefeea39b3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -941,6 +941,31 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(resolveGetField('c.struct(typeS).at(2).getField("a")), "aa", row) } + test("error message of GetField") { + val structType = StructType(StructField("a", StringType, true) :: Nil) + val arrayStructType = ArrayType(structType) + val arrayType = ArrayType(StringType) + val otherType = StringType + + def checkErrorMessage( + childDataType: DataType, + fieldDataType: DataType, + errorMesage: String): Unit = { + val e = intercept[org.apache.spark.sql.AnalysisException] { + GetField( + Literal.create(null, childDataType), + Literal.create(null, fieldDataType), + _ == _) + } + assert(e.getMessage().contains(errorMesage)) + } + + checkErrorMessage(structType, IntegerType, "Field name should be String Literal") + checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal") + checkErrorMessage(arrayType, StringType, "Array index should be integral type") + checkErrorMessage(otherType, StringType, "Can't get field on") + } + test("arithmetic") { val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 7acc06b15f4ef..5cbe2208734c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -67,6 +67,16 @@ class Column(protected[sql] val expr: Expression) extends Logging { override def hashCode: Int = this.expr.hashCode + /** + * An expression that gets an item at a position out of an [[ArrayType]], + * or gets a value by key in a [[MapType]], + * or gets a field by name in a [[StructType]], + * or gets an array of fields by name in an array of [[StructType]]. + * + * @group expr_ops + */ + def apply(field: Any): Column = UnresolvedGetField(expr, Literal(field)) + /** * Unary minus, i.e. negate the expression. * {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1515e9b843771..d2ca8dccae574 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -449,7 +449,7 @@ class DataFrameSuite extends QueryTest { testData.collect().map { case Row(key: Int, value: String) => Row(key, value, key + 1) }.toSeq) - assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol")) } test("replace column using withColumn") { @@ -484,7 +484,7 @@ class DataFrameSuite extends QueryTest { testData.collect().map { case Row(key: Int, value: String) => Row(key, value, key + 1) }.toSeq) - assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) + assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } test("randomSplit") { @@ -593,4 +593,10 @@ class DataFrameSuite extends QueryTest { Row(new java.math.BigDecimal(2.0))) TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } + + test("SPARK-7133: Implement struct, array, and map field accessor") { + assert(complexData.filter(complexData("a")(0) === 2).count() == 1) + assert(complexData.filter(complexData("m")("1") === 1).count() == 1) + assert(complexData.filter(complexData("s")("key") === 1).count() == 1) + } }