From 76d63461d2c512f5a6519d25dcaa14cfa8ec6468 Mon Sep 17 00:00:00 2001 From: kaka1992 Date: Wed, 6 May 2015 11:20:01 +0800 Subject: [PATCH] [SPARK-7321][SQL] Add Column expression for conditional statements (if, case) --- python/pyspark/sql/dataframe.py | 18 +++++++++++++ .../scala/org/apache/spark/sql/Column.scala | 27 +++++++++++++++++++ .../spark/sql/ColumnExpressionSuite.scala | 18 +++++++++++++ 3 files changed, 63 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 24f370543def4..17eef7070eab2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1412,6 +1412,24 @@ def between(self, lowerBound, upperBound): """ return (self >= lowerBound) & (self <= upperBound) + @ignore_unicode_prefix + def when(self, whenExpr, thenExpr): + """ A case when otherwise expression. + >>> df.select(df.age.when(2, 3).otherwise(4).alias("age")).collect() + [Row(age=3), Row(age=4)] + >>> df.select(df.age.when(2, 3).alias("age")).collect() + [Row(age=3), Row(age=None)] + >>> df.select(df.age.otherwise(4).alias("age")).collect() + [Row(age=4), Row(age=4)] + """ + jc = self._jc.when(whenExpr, thenExpr) + return Column(jc) + + @ignore_unicode_prefix + def otherwise(self, elseExpr): + jc = self._jc.otherwise(elseExpr) + return Column(jc) + def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c0503bf047052..afe0193a56f1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -295,6 +295,33 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def eqNullSafe(other: Any): Column = this <=> other + /** + * Case When Otherwise.
 + * {{{ + * people.select( people("age").when(18, "SELECTED").otherwise("IGNORED") ) + * }}} + * + * @group expr_ops + */ + def when(whenExpr: Any, thenExpr: Any):Column = { + this.expr match { + case CaseWhen(branches: Seq[Expression]) => + val caseExpr = branches.head.asInstanceOf[EqualNullSafe].left + CaseWhen(branches ++ Seq((caseExpr <=> whenExpr).expr, lit(thenExpr).expr)) + case _ => + CaseWhen(Seq((this <=> whenExpr).expr, lit(thenExpr).expr)) + } + } + + def otherwise(elseExpr: Any):Column = { + this.expr match { + case CaseWhen(branches: Seq[Expression]) => + CaseWhen(branches :+ lit(elseExpr).expr) + case _ => + CaseWhen(Seq(lit(true).expr, lit(elseExpr).expr)) + } + } + /** * True if the current column is between the lower bound and upper bound, inclusive. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 3c1ad656fc855..26997c39224c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -257,6 +257,24 @@ class ColumnExpressionSuite extends QueryTest { Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil) } + test("SPARK-7321 case") { + val testData = (1 to 3).map(i => TestData(i, i.toString)).toDF() + checkAnswer( + testData.select($"key".when(1, -1).when(2, -2).otherwise(0)), + Seq(Row(-1), Row(-2), Row(0)) + ) + + checkAnswer( + testData.select($"key".when(1, -1).when(2, -2)), + Seq(Row(-1), Row(-2), Row(null)) + ) + + checkAnswer( + testData.select($"key".otherwise(0)), + Seq(Row(0), Row(0), Row(0)) + ) + } + test("sqrt") { checkAnswer( testData.select(sqrt('key)).orderBy('key.asc),