Commit

add apply method and test cases

cloud-fan committed May 8, 2015
1 parent 8df6199 commit f515d69

Showing 4 changed files with 44 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/sql/dataframe.py
@@ -1275,7 +1275,7 @@ def __init__(self, jc):

    # container operators
    __contains__ = _bin_op("contains")
-   __getitem__ = _bin_op("getItem")
+   __getitem__ = _bin_op("apply")

    # bitwise operators
    bitwiseOR = _bin_op("bitwiseOR")
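Rebinding PySpark's __getitem__ from "getItem" to "apply" means df['col'][key] now goes through the JVM-side Column.apply added later in this commit; _bin_op simply invokes the named method on the Java Column via py4j. As a purely illustrative, non-Spark sketch of why "apply" is the natural name on the Scala side (the Box class below is invented):

// Illustrative only (not Spark code): Scala rewrites b(1) as b.apply(1), so
// exposing the accessor as `apply` gives df("col")(0) syntax in Scala and a
// single JVM entry point for PySpark's _bin_op("apply") to call.
class Box(values: Vector[String]) {
  def apply(i: Int): String = values(i)
}

object BoxExample {
  def main(args: Array[String]): Unit = {
    val b = new Box(Vector("x", "y", "z"))
    println(b(1)) // desugars to b.apply(1) and prints "y"
  }
}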
@@ -941,6 +941,31 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
    checkEvaluation(resolveGetField('c.struct(typeS).at(2).getField("a")), "aa", row)
  }

test("error message of GetField") {
val structType = StructType(StructField("a", StringType, true) :: Nil)
val arrayStructType = ArrayType(structType)
val arrayType = ArrayType(StringType)
val otherType = StringType

def checkErrorMessage(
childDataType: DataType,
fieldDataType: DataType,
errorMesage: String): Unit = {
val e = intercept[org.apache.spark.sql.AnalysisException] {
GetField(
Literal.create(null, childDataType),
Literal.create(null, fieldDataType),
_ == _)
}
assert(e.getMessage().contains(errorMesage))
}

checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal")
checkErrorMessage(arrayType, StringType, "Array index should be integral type")
checkErrorMessage(otherType, StringType, "Can't get field on")
}
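The three messages asserted above come from dispatch on the child expression's data type. The sketch below is a simplified, self-contained model of that dispatch, not the Catalyst implementation; its DataType hierarchy and resolve function are stand-ins invented for illustration:

// Simplified model of the type dispatch behind the error messages above; the
// real logic lives in Catalyst's complex-type expressions and differs in detail.
object GetFieldDispatchSketch {
  sealed trait DataType
  case object StringType extends DataType
  case class StructType(fieldNames: Seq[String]) extends DataType
  case class ArrayType(elementType: DataType) extends DataType

  def resolve(child: DataType, extraction: Any): String = (child, extraction) match {
    case (StructType(_), name: String)            => s"struct field '$name'"
    case (ArrayType(StructType(_)), name: String) => s"field '$name' of each struct in the array"
    case (StructType(_) | ArrayType(StructType(_)), _) =>
      sys.error("Field name should be String Literal")
    case (ArrayType(_), _: Int)                   => "array element at the given index"
    case (ArrayType(_), _)                        => sys.error("Array index should be integral type")
    case (other, _)                               => sys.error(s"Can't get field on $other")
  }

  def main(args: Array[String]): Unit = {
    println(resolve(StructType(Seq("a")), "a"))  // ok: struct field 'a'
    println(resolve(ArrayType(StringType), 0))   // ok: array element at the given index
    // resolve(StringType, "a") would fail with "Can't get field on StringType"
  }
}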

test("arithmetic") {
val row = create_row(1, 2, 3, null)
val c1 = 'a.int.at(0)
10 changes: 10 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala

@@ -67,6 +67,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {

  override def hashCode: Int = this.expr.hashCode

  /**
   * An expression that gets an item at a position out of an [[ArrayType]],
   * or gets a value by key in a [[MapType]],
   * or gets a field by name in a [[StructType]],
   * or gets an array of fields by name in an array of [[StructType]].
   *
   * @group expr_ops
   */
  def apply(field: Any): Column = UnresolvedGetField(expr, Literal(field))

  /**
   * Unary minus, i.e. negate the expression.
   * {{{
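A hypothetical end-to-end usage sketch for the new Column.apply follows; it is not part of the commit. The case classes, column names, and local-mode setup are invented, and it assumes a Spark 1.x build on the classpath:

// Hypothetical usage of Column.apply: one accessor for arrays, maps and structs.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

case class Inner(key: Int)
case class Record(a: Seq[Int], m: Map[String, Int], s: Inner)

object ColumnApplyExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("apply-example").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(
      Record(Seq(2, 3), Map("1" -> 1), Inner(1)),
      Record(Seq(5, 6), Map("1" -> 9), Inner(7))
    )).toDF()

    df.filter(df("a")(0) === 2).show()     // element 0 of the array column
    df.filter(df("m")("1") === 1).show()   // value for key "1" in the map column
    df.filter(df("s")("key") === 1).show() // field "key" of the struct column

    sc.stop()
  }
}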
@@ -449,7 +449,7 @@ class DataFrameSuite extends QueryTest {
      testData.collect().map { case Row(key: Int, value: String) =>
        Row(key, value, key + 1)
      }.toSeq)
-   assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol"))
+   assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
  }
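The .toSeq dropped in this assert (and in the similar one further down) was redundant: df.schema is a StructType, which already behaves as a Seq[StructField], so mapping it yields a Seq[String]. A minimal non-Spark illustration of why the extra call is a no-op:

// Plain-collections illustration (not Spark): mapping a Seq already returns a
// Seq, so a trailing .toSeq changes nothing.
object ToSeqIsRedundant {
  def main(args: Array[String]): Unit = {
    val names: Seq[String] = Seq(1, 2, 3).map(_.toString)
    assert(names == names.toSeq) // .toSeq on something that is already a Seq is a no-op
  }
}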

test("replace column using withColumn") {
@@ -484,7 +484,7 @@ class DataFrameSuite extends QueryTest {
      testData.collect().map { case Row(key: Int, value: String) =>
        Row(key, value, key + 1)
      }.toSeq)
-   assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
+   assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
  }

test("randomSplit") {
Expand Down Expand Up @@ -593,4 +593,10 @@ class DataFrameSuite extends QueryTest {
      Row(new java.math.BigDecimal(2.0)))
    TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
  }

  test("SPARK-7133: Implement struct, array, and map field accessor") {
    assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
    assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
    assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
  }
}
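The complexData fixture exercised by the new SPARK-7133 test is defined in the shared test data and is not part of this diff. A hypothetical fixture with a compatible shape (field names and row values are invented so that each of the three filters matches exactly one row):

// Hypothetical stand-in for the complexData fixture; the real definition lives
// elsewhere in the test sources and may differ in shape and field names.
case class KV(key: Int, value: String)
case class ComplexData(a: Seq[Int], m: Map[String, Int], s: KV)

object ComplexDataFixture {
  val rows: Seq[ComplexData] = Seq(
    ComplexData(Seq(2, 3), Map("1" -> 1), KV(1, "one")),  // matched by all three filters
    ComplexData(Seq(4, 5), Map("1" -> 2), KV(2, "two"))   // matched by none
  )
}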
